diff --git a/casanovo/config.yaml b/casanovo/config.yaml index ffb9bf45..74d6b782 100644 --- a/casanovo/config.yaml +++ b/casanovo/config.yaml @@ -63,8 +63,8 @@ max_mods: 1 # where aa is a standard amino acid (or "nterm" for an N-terminal mod) # and mod_residue is a key from the "residues" dictionary. # Example: "M:M+15.995,nterm:+43.006" -allowed_fixed_mods: "C:C+57.021" -allowed_var_mods: "M:M+15.995,N:N+0.984,Q:Q+0.984,nterm:+42.011,nterm:+43.006,nterm:-17.027,nterm:+43.006-17.027" +allowed_fixed_mods: "C:C[Carbamidomethyl]" +allowed_var_mods: "M:M[Oxidation],N:N[Deamidated],Q:Q[Deamidated],nterm:[Acetyl]-,nterm:[Carbamyl]-,nterm:[Ammonia-loss]-,nterm:[+25.980265]-" ### diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 69730ed2..53c6a9a0 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -1288,8 +1288,8 @@ def _pep_batch_ready(self, num_candidate_psms: int) -> bool: True if the batch is ready, False otherwise. """ return ( - num_candidate_psms % self.psm_batch_size - ) == self.psm_batch_size - 1 + num_candidate_psms % self.psm_batch_size == 0 + ) and num_candidate_psms != 0 def _initialize_psm_batch(self, batch: Dict[str, Any]) -> Dict[str, List]: """ diff --git a/tests/test_integration.py b/tests/test_integration.py index b5adfa96..948cff63 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -6,7 +6,6 @@ from click.testing import CliRunner from casanovo import casanovo -from casanovo.config import Config TEST_DIR = Path(__file__).resolve().parent diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 05fe5a11..d5458d84 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -5,6 +5,7 @@ import hashlib import heapq import io +import math import os import pathlib import platform @@ -26,10 +27,16 @@ import torch from casanovo import casanovo, utils +from casanovo.config import Config from casanovo.data import db_utils, ms_io from casanovo.denovo.dataloaders import DeNovoDataModule from casanovo.denovo.evaluate import aa_match, aa_match_batch, aa_match_metrics -from casanovo.denovo.model import Spec2Pep, _aa_pep_score, _calc_match_score +from casanovo.denovo.model import ( + DbSpec2Pep, + Spec2Pep, + _aa_pep_score, + _calc_match_score, +) def test_version(): @@ -1008,13 +1015,104 @@ def test_digest_fasta_enzyme(tiny_fasta_file): ), tokenizer=depthcharge.tokenizers.PeptideTokenizer.from_massivekb(), ) - peptide_list = pdb.db_peptides.index.to_list() + assert pdb.db_peptides.index.to_list() == expected_nonspecific - first = peptide_list[:50] - second = peptide_list[50:100] - third = peptide_list[100:] - assert pdb.db_peptides.index.to_list() == expected_nonspecific +def test_psm_batches(tiny_config): + peptides_one = [ + "SGSGSG", + "GSGSGT", + "SGSGTD", + "FSGSGS", + "ATSIPA", + "GASTRA", + "LSLSPG", + "ASQSVS", + "GSGTDF", + "SLSPGE", + "AQLLFL", + "QPEDFA", + ] + + peptides_two = [ + "SQSVSS", + "KPGQAP", + "SPPTLS", + "ASTRAT", + "RFSGSG", + "IYGAST", + "APAQLL", + "PTLSLS", + "TLSLSP", + "TLTISS", + "WYQQKP", + "TWYQQK", + ] + + def mock_get_candidates(precursor_mz, precorsor_charge): + if precorsor_charge == 1: + return pd.Series(peptides_one) + elif precorsor_charge == 2: + return pd.Series(peptides_two) + else: + return pd.Series() + + tokenizer = depthcharge.tokenizers.peptides.PeptideTokenizer( + residues=Config(tiny_config).residues + ) + db_model = DbSpec2Pep(tokenizer=tokenizer) + db_model.protein_database = unittest.mock.MagicMock() + db_model.protein_database.get_candidates = mock_get_candidates + + mock_batch = { + "precursor_mz": torch.Tensor([42.0, 84.0, 126.0]), + "precursor_charge": torch.Tensor([1, 2, 3]), + "peak_file": ["one.mgf", "two.mgf", "three.mgf"], + "scan_id": [1, 2, 3], + } + + expected_batch_all = { + "precursor_mz": torch.Tensor([42.0] * 12 + [84.0] * 12), + "precursor_charge": torch.Tensor([1] * 12 + [2] * 12), + "seq": tokenizer.tokenize(peptides_one + peptides_two), + "peak_file": ["one.mgf"] * 12 + ["two.mgf"] * 12, + "scan_id": [1] * 12 + [2] * 12, + } + + for psm_batch_size in [24, 12, 8, 10]: + db_model.psm_batch_size = psm_batch_size + psm_batches = list(db_model._psm_batches(mock_batch)) + assert len(psm_batches) == math.ceil(24 / psm_batch_size) + num_spectra = 0 + + for psm_batch in psm_batches: + end_idx = min( + num_spectra + psm_batch_size, + len(expected_batch_all["peak_file"]), + ) + assert torch.allclose( + psm_batch["precursor_mz"], + expected_batch_all["precursor_mz"][num_spectra:end_idx], + ) + assert torch.equal( + psm_batch["precursor_charge"], + expected_batch_all["precursor_charge"][num_spectra:end_idx], + ) + assert torch.equal( + psm_batch["seq"], + expected_batch_all["seq"][num_spectra:end_idx], + ) + assert ( + psm_batch["peak_file"] + == expected_batch_all["peak_file"][num_spectra:end_idx] + ) + assert ( + psm_batch["scan_id"] + == expected_batch_all["scan_id"][num_spectra:end_idx] + ) + num_spectra += len(psm_batch["peak_file"]) + + assert num_spectra == 24 def test_get_candidates(tiny_fasta_file):