Skip to content

Commit

Permalink
psm batch generator unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Dec 2, 2024
1 parent ec20013 commit 2233839
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 11 deletions.
4 changes: 2 additions & 2 deletions casanovo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ max_mods: 1
# where aa is a standard amino acid (or "nterm" for an N-terminal mod)
# and mod_residue is a key from the "residues" dictionary.
# Example: "M:M[Oxidation],nterm:[Carbamyl]-"
allowed_fixed_mods: "C:C+57.021"
allowed_var_mods: "M:M+15.995,N:N+0.984,Q:Q+0.984,nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027"
allowed_fixed_mods: "C:C[Carbamidomethyl]"
allowed_var_mods: "M:M[Oxidation],N:N[Deamidated],Q:Q[Deamidated],nterm:[Acetyl]-,nterm:[Carbamyl]-,nterm:[Ammonia-loss]-,nterm:[+25.980265]-"


###
Expand Down
4 changes: 2 additions & 2 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,8 +1288,8 @@ def _pep_batch_ready(self, num_candidate_psms: int) -> bool:
True if the batch is ready, False otherwise.
"""
return (
num_candidate_psms % self.psm_batch_size
) == self.psm_batch_size - 1
num_candidate_psms % self.psm_batch_size == 0
) and num_candidate_psms != 0

def _initialize_psm_batch(self, batch: Dict[str, Any]) -> Dict[str, List]:
"""
Expand Down
1 change: 0 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from click.testing import CliRunner

from casanovo import casanovo
from casanovo.config import Config

TEST_DIR = Path(__file__).resolve().parent

Expand Down
110 changes: 104 additions & 6 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hashlib
import heapq
import io
import math
import os
import pathlib
import platform
Expand All @@ -26,10 +27,16 @@
import torch

from casanovo import casanovo, utils
from casanovo.config import Config
from casanovo.data import db_utils, ms_io
from casanovo.denovo.dataloaders import DeNovoDataModule
from casanovo.denovo.evaluate import aa_match, aa_match_batch, aa_match_metrics
from casanovo.denovo.model import Spec2Pep, _aa_pep_score, _calc_match_score
from casanovo.denovo.model import (
DbSpec2Pep,
Spec2Pep,
_aa_pep_score,
_calc_match_score,
)


def test_version():
Expand Down Expand Up @@ -1008,13 +1015,104 @@ def test_digest_fasta_enzyme(tiny_fasta_file):
),
tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(),
)
peptide_list = pdb.db_peptides.index.to_list()
assert pdb.db_peptides.index.to_list() == expected_nonspecific

first = peptide_list[:50]
second = peptide_list[50:100]
third = peptide_list[100:]

assert pdb.db_peptides.index.to_list() == expected_nonspecific
def test_psm_batches(tiny_config):
    """Check that ``DbSpec2Pep._psm_batches`` chunks candidate PSMs correctly.

    A mocked protein database returns 12 candidate peptides for the
    charge-1 spectrum, 12 for the charge-2 spectrum and none for the
    charge-3 spectrum. For several ``psm_batch_size`` values the generator
    must yield ``ceil(total / psm_batch_size)`` batches whose concatenation
    equals the flat list of expected PSMs, in order, with the candidate-less
    spectrum contributing nothing.
    """
    peptides_one = [
        "SGSGSG",
        "GSGSGT",
        "SGSGTD",
        "FSGSGS",
        "ATSIPA",
        "GASTRA",
        "LSLSPG",
        "ASQSVS",
        "GSGTDF",
        "SLSPGE",
        "AQLLFL",
        "QPEDFA",
    ]

    peptides_two = [
        "SQSVSS",
        "KPGQAP",
        "SPPTLS",
        "ASTRAT",
        "RFSGSG",
        "IYGAST",
        "APAQLL",
        "PTLSLS",
        "TLSLSP",
        "TLTISS",
        "WYQQKP",
        "TWYQQK",
    ]

    num_one = len(peptides_one)
    num_two = len(peptides_two)
    num_psms = num_one + num_two

    def mock_get_candidates(precursor_mz, precursor_charge):
        # Candidates are selected purely by charge; the m/z is ignored.
        if precursor_charge == 1:
            return pd.Series(peptides_one)
        elif precursor_charge == 2:
            return pd.Series(peptides_two)
        else:
            # Explicit dtype: a bare ``pd.Series()`` warns on pandas 1.x and
            # its default dtype differs between pandas versions.
            return pd.Series(dtype=object)

    tokenizer = depthcharge.tokenizers.peptides.PeptideTokenizer(
        residues=Config(tiny_config).residues
    )
    db_model = DbSpec2Pep(tokenizer=tokenizer)
    db_model.protein_database = unittest.mock.MagicMock()
    db_model.protein_database.get_candidates = mock_get_candidates

    # Three spectra; the third (charge 3) has no candidates in the mock.
    mock_batch = {
        "precursor_mz": torch.Tensor([42.0, 84.0, 126.0]),
        "precursor_charge": torch.Tensor([1, 2, 3]),
        "peak_file": ["one.mgf", "two.mgf", "three.mgf"],
        "scan_id": [1, 2, 3],
    }

    # Every expected PSM, flattened in the order the batches should emit them.
    expected_batch_all = {
        "precursor_mz": torch.Tensor([42.0] * num_one + [84.0] * num_two),
        "precursor_charge": torch.Tensor([1] * num_one + [2] * num_two),
        "seq": tokenizer.tokenize(peptides_one + peptides_two),
        "peak_file": ["one.mgf"] * num_one + ["two.mgf"] * num_two,
        "scan_id": [1] * num_one + [2] * num_two,
    }

    # 24: everything in one batch; 12: batch edge on the spectrum boundary;
    # 8: a batch straddles two spectra; 10: the final batch is short.
    for psm_batch_size in [24, 12, 8, 10]:
        db_model.psm_batch_size = psm_batch_size
        psm_batches = list(db_model._psm_batches(mock_batch))
        assert len(psm_batches) == math.ceil(num_psms / psm_batch_size)
        num_psms_seen = 0

        for psm_batch in psm_batches:
            end_idx = min(num_psms_seen + psm_batch_size, num_psms)
            expected = slice(num_psms_seen, end_idx)
            assert torch.allclose(
                psm_batch["precursor_mz"],
                expected_batch_all["precursor_mz"][expected],
            )
            assert torch.equal(
                psm_batch["precursor_charge"],
                expected_batch_all["precursor_charge"][expected],
            )
            assert torch.equal(
                psm_batch["seq"],
                expected_batch_all["seq"][expected],
            )
            assert (
                psm_batch["peak_file"]
                == expected_batch_all["peak_file"][expected]
            )
            assert (
                psm_batch["scan_id"]
                == expected_batch_all["scan_id"][expected]
            )
            num_psms_seen += len(psm_batch["peak_file"])

        # Every candidate PSM was emitted exactly once across the batches.
        assert num_psms_seen == num_psms


def test_get_candidates(tiny_fasta_file):
Expand Down

0 comments on commit 2233839

Please sign in to comment.