
Stabilize torch.topk() behavior #290

Merged: 11 commits merged on Feb 14, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

- Instead of having to specify `train_from_scratch` in the config file, training will proceed from an existing model weights file if this is given as an argument to `casanovo train`.

### Fixed

- Fixed beam search decoding error due to non-deterministic selection of beams with equal scores.

## [4.0.0] - 2023-12-22

### Added
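For context on the "Fixed" entry above: torch.topk() makes no guarantee about which index wins when several candidates share the same score, and tie-breaking can differ between CPU and GPU. A minimal sketch of the failure mode (illustrative only, not code from this PR): with the old masking, every candidate of a terminated beam was multiplied by 0, so its scores all tied at exactly 0.0.

import torch

# All-zero candidates tie with each other (and with the padding column), so
# which index torch.topk() returns after the clear winner is unspecified.
scores = torch.tensor([0.0, 0.8, 0.0, 0.0])  # indices 0, 2, and 3 are tied
_, top_idx = torch.topk(scores, k=2)
print(top_idx)  # index 1 first; the second pick is an arbitrary tied index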
1 change: 1 addition & 0 deletions casanovo/casanovo.py
@@ -1,4 +1,5 @@
"""The command line entry point for Casanovo."""

import datetime
import functools
import logging
1 change: 1 addition & 0 deletions casanovo/config.py
@@ -1,4 +1,5 @@
"""Parse the YAML configuration."""

import logging
import shutil
from pathlib import Path
1 change: 1 addition & 0 deletions casanovo/data/datasets.py
@@ -1,4 +1,5 @@
"""A PyTorch Dataset class for annotated spectra."""

from typing import Optional, Tuple

import depthcharge
1 change: 1 addition & 0 deletions casanovo/data/ms_io.py
@@ -1,4 +1,5 @@
"""Mass spectrometry file type input/output operations."""

import collections
import csv
import operator
1 change: 1 addition & 0 deletions casanovo/denovo/dataloaders.py
@@ -1,4 +1,5 @@
"""Data loaders for the de novo sequencing task."""

import functools
import os
from typing import List, Optional, Tuple
1 change: 1 addition & 0 deletions casanovo/denovo/evaluate.py
@@ -1,4 +1,5 @@
"""Methods to evaluate peptide-spectrum predictions."""

import re
from typing import Dict, Iterable, List, Tuple

21 changes: 9 additions & 12 deletions casanovo/denovo/model.py
@@ -1,4 +1,5 @@
"""A de novo peptide sequencing model."""

import collections
import heapq
import logging
@@ -606,21 +607,17 @@ def _get_topk_beams(
scores[:, step, :, :], "B V S -> B (V S)"
)

- # Mask out terminated beams. Include precursor m/z tolerance induced
- # termination.
- # TODO: `clone()` is necessary to get the correct output with n_beams=1.
- # An alternative implementation using base PyTorch instead of einops
- # might be more efficient.
- finished_mask = einops.repeat(
-     finished_beams, "(B S) -> B (V S)", S=beam, V=vocab
- ).clone()
+ # Find all still active beams by masking out terminated beams.
+ active_mask = (
+     ~finished_beams.reshape(batch, beam).repeat(1, vocab)
+ ).float()
# Mask out the index '0', i.e. padding token, by default.
- finished_mask[:, :beam] = True
+ # FIXME: Set this to a very small, yet non-zero value, to only
+ # get padding after stop token.
+ active_mask[:, :beam] = 1e-8

# Figure out the top K decodings.
- _, top_idx = torch.topk(
-     step_scores.nanmean(dim=1) * (~finished_mask).float(), beam
- )
+ _, top_idx = torch.topk(step_scores.nanmean(dim=1) * active_mask, beam)
v_idx, s_idx = np.unravel_index(top_idx.cpu(), (vocab, beam))
s_idx = einops.rearrange(s_idx, "B S -> (B S)")
b_idx = einops.repeat(torch.arange(batch), "B -> (B S)", S=beam)
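To illustrate the new masking, here is a self-contained sketch (simplified shapes and hypothetical values, not code from the repository): a terminated beam keeps exactly one nonzero column, the padding token, so torch.topk() extends it with padding deterministically instead of tie-breaking among all-zero scores.

import torch

# Minimal sketch: B=2 spectra, S=1 beam each, V=5 vocabulary tokens.
batch, beam, vocab = 2, 1, 5
step_scores = torch.rand(batch, vocab * beam)  # mean AA score per candidate
finished_beams = torch.tensor([True, False])  # the first beam has terminated

# Active beams keep weight 1; terminated beams are zeroed out.
active_mask = (~finished_beams.reshape(batch, beam).repeat(1, vocab)).float()
# The padding columns get a tiny nonzero weight instead of 0, so a finished
# beam's only nonzero candidate is padding and the top-k choice is unique.
active_mask[:, :beam] = 1e-8

_, top_idx = torch.topk(step_scores * active_mask, beam)
print(top_idx)  # row 0 -> 0 (padding); row 1 -> the highest-scoring token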
7 changes: 4 additions & 3 deletions casanovo/denovo/model_runner.py
@@ -1,5 +1,6 @@
"""Training and testing functionality for the de novo peptide sequencing
model."""

import glob
import logging
import os
@@ -306,9 +307,9 @@ def initialize_data_module(
self,
train_index: Optional[AnnotatedSpectrumIndex] = None,
valid_index: Optional[AnnotatedSpectrumIndex] = None,
- test_index: (
-     Optional[Union[AnnotatedSpectrumIndex, SpectrumIndex]]
- ) = None,
+ test_index: Optional[
+     Union[AnnotatedSpectrumIndex, SpectrumIndex]
+ ] = None,
) -> None:
"""Initialize the data module

1 change: 1 addition & 0 deletions casanovo/utils.py
@@ -1,4 +1,5 @@
"""Small utility functions"""

import logging
import os
import platform
1 change: 1 addition & 0 deletions casanovo/version.py
@@ -1,4 +1,5 @@
"""Package version information."""

from typing import Optional


164 changes: 11 additions & 153 deletions docs/images/help.svg
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -1,4 +1,5 @@
"""Fixtures used for testing."""

import numpy as np
import psims
import pytest
1 change: 1 addition & 0 deletions tests/unit_tests/test_config.py
@@ -1,4 +1,5 @@
"""Test configuration loading"""

import pytest
import yaml

1 change: 1 addition & 0 deletions tests/unit_tests/test_runner.py
@@ -1,4 +1,5 @@
"""Unit tests specifically for the model_runner module."""

import pytest
import torch

43 changes: 43 additions & 0 deletions tests/unit_tests/test_unit.py
@@ -424,6 +424,49 @@ def test_beam_search_decode():
)
assert torch.equal(discarded_beams, torch.tensor([False, True, True]))

# Test _get_topk_beams() with finished beams in the batch.
model = Spec2Pep(n_beams=1, residues="massivekb", min_peptide_len=3)

# Sizes and other variables.
batch = 2 # B
beam = model.n_beams # S
model.decoder.reverse = True
length = model.max_length + 1 # L
vocab = model.decoder.vocab_size + 1 # V
step = 4

# Initialize dummy scores and tokens.
scores = torch.full(
size=(batch, length, vocab, beam), fill_value=torch.nan
)
scores = einops.rearrange(scores, "B L V S -> (B S) L V")
tokens = torch.zeros(batch * beam, length, dtype=torch.int64)

# Simulate non-zero amino acid-level probability scores.
scores[:, : step + 1, :] = torch.rand(batch, step + 1, vocab)
scores[:, step, range(1, 4)] = torch.tensor([1.0, 2.0, 3.0])

# Simulate one finished and one unfinished beam in the same batch.
tokens[0, :step] = torch.tensor([4, 14, 4, 28])
tokens[1, :step] = torch.tensor([4, 14, 4, 1])

# Mark the first beam as finished so that only the second can be extended.
test_finished_beams = torch.tensor([True, False])

new_tokens, new_scores = model._get_topk_beams(
tokens, scores, test_finished_beams, batch, step
)

# Only the second peptide should have a new token predicted: the finished
# beam is extended with the padding token (0), while the active beam is
# extended with token index 3, which received the highest score.
expected_tokens = torch.tensor(
[
[4, 14, 4, 28, 0],
[4, 14, 4, 1, 3],
]
)

assert torch.equal(new_tokens[:, : step + 1], expected_tokens)


def test_eval_metrics():
"""
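To exercise just the updated test locally, standard pytest node-id selection should work, e.g. `pytest tests/unit_tests/test_unit.py::test_beam_search_decode`.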