From 7cb8e141ccab5b865a3af00711d290cd6cab788d Mon Sep 17 00:00:00 2001 From: VarunAnanth2003 Date: Mon, 12 Aug 2024 14:50:18 -0700 Subject: [PATCH] small fixes regarding documentation, import syntax, etc. --- casanovo/casanovo.py | 39 ++++++---- casanovo/data/db_utils.py | 71 +++++++++-------- casanovo/denovo/dataloaders.py | 10 +-- casanovo/denovo/model.py | 31 ++++---- casanovo/denovo/model_runner.py | 24 ++---- tests/conftest.py | 11 +-- tests/unit_tests/test_unit.py | 132 +++++++++++--------------------- 7 files changed, 137 insertions(+), 181 deletions(-) diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index 8ae9a81b..4b9b4e38 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -130,7 +130,7 @@ def sequence( ) -> None: """De novo sequence peptides from tandem mass spectra. - PEAK_PATH must be one or more mzMl, mzXML, or MGF files from which + PEAK_PATH must be one or more mzML, mzXML, or MGF files from which to sequence peptides. """ output = setup_logging(output, verbosity) @@ -205,7 +205,7 @@ def sequence( ) @click.option( "--digestion", - help="Digestion: full, partial", + help="Full: standard digestion. Semi: Include products of semi-specific cleavage", type=click.Choice( ["full", "partial"], case_sensitive=False, @@ -214,37 +214,41 @@ def sequence( ) @click.option( "--missed_cleavages", - help="Number of allowed missed cleavages", + help="Number of allowed missed cleavages when digesting protein", type=int, default=0, ) @click.option( "--max_mods", - help="Maximum number of modifications per peptide", + help="Maximum number of amino acid modifications per peptide", type=int, default=0, ) @click.option( - "--min_length", - help="Minimum peptide length", + "--min_peptide_length", + help="Minimum peptide length to consider", type=int, default=6, ) @click.option( - "--max_length", - help="Maximum peptide length", + "--max_peptide_length", + help="Maximum peptide length to consider", type=int, default=50, ) @click.option( "--precursor_tolerance", - help="Precursor tolerance window size (ppm)", - type=int, + help="Precursor tolerance window size (units: ppm)", + type=float, default=20, ) @click.option( "--isotope_error", - help="Isotope error levels to consider (list of ints, e.g: 1,2)", + help="Isotope error levels to consider. \ + Creates multiple mass windows to consider per spectrum \ + to account for observed mass not matching monoisotopic mass \ + due to the instrument assigning the 13C isotope \ + peak as the precursor (list of ints, e.g: 1,2)", type=str, default="0", ) @@ -255,9 +259,9 @@ def db_search( digestion: str, missed_cleavages: int, max_mods: int, - min_length: int, - max_length: int, - precursor_tolerance: int, + min_peptide_length: int, + max_peptide_length: int, + precursor_tolerance: float, isotope_error: str, model: Optional[str], config: Optional[str], @@ -266,7 +270,8 @@ def db_search( ) -> None: """Perform a database search on MS/MS data using Casanovo-DB. - PEAK_PATH must be one MGF file. FASTA_PATH must be one FASTA file. + PEAK_PATH must be one or more mzML, mzXML, or MGF files. + FASTA_PATH must be one FASTA file. """ output = setup_logging(output, verbosity) config, model = setup_model(model, config, output, False) @@ -284,8 +289,8 @@ def db_search( digestion, missed_cleavages, max_mods, - min_length, - max_length, + min_peptide_length, + max_peptide_length, precursor_tolerance, isotope_error, output, diff --git a/casanovo/data/db_utils.py b/casanovo/data/db_utils.py index 921c75bd..1af09a47 100644 --- a/casanovo/data/db_utils.py +++ b/casanovo/data/db_utils.py @@ -1,15 +1,16 @@ """Unique methods used within db-search mode""" -import os -import depthcharge.masses -from pyteomics import fasta, parser import bisect import logging - +import os from typing import List, Tuple +import depthcharge.masses +from pyteomics import fasta, parser + logger = logging.getLogger("casanovo") + # CONSTANTS HYDROGEN = 1.007825035 OXYGEN = 15.99491463 @@ -51,8 +52,8 @@ def digest_fasta( digestion: str, missed_cleavages: int, max_mods: int, - min_length: int, - max_length: int, + min_peptide_length: int, + max_peptide_length: int, ): """ Digests a FASTA file and returns the peptides, their masses, and associated protein. @@ -70,9 +71,9 @@ def digest_fasta( The number of missed cleavages to allow. max_mods : int The maximum number of modifications to allow per peptide. - min_length : int + min_peptide_length : int The minimum length of peptides to consider. - max_length : int + max_peptide_length : int The maximum length of peptides to consider. Returns @@ -81,35 +82,36 @@ def digest_fasta( A list of tuples containing the peptide sequence, mass, and associated protein. Sorted by neutral mass in ascending order. """ - - # Verify the eistence of the file: + # Verify the existence of the file: if not os.path.isfile(fasta_filename): - print(f"File {fasta_filename} does not exist.") + logger.error("File %s does not exist.", fasta_filename) raise FileNotFoundError(f"File {fasta_filename} does not exist.") fasta_data = fasta.read(fasta_filename) peptide_list = [] - if digestion in ["full", "partial"]: - semi = True if digestion == "partial" else False - for header, seq in fasta_data: - pep_set = parser.cleave( - seq, - rule=parser.expasy_rules[enzyme], - missed_cleavages=missed_cleavages, - semi=semi, - ) - protein = header.split()[0] - for pep in pep_set: - if len(pep) < min_length or len(pep) > max_length: - continue - if "X" in pep or "U" in pep: - logger.warn( - "Skipping peptide with ambiguous amino acids: %s", pep - ) - continue - peptide_list.append((pep, protein)) - else: + if digestion not in ["full", "partial"]: + logger.error("Digestion type %s not recognized.", digestion) raise ValueError(f"Digestion type {digestion} not recognized.") + semi = digestion == "partial" + for header, seq in fasta_data: + pep_set = parser.cleave( + seq, + rule=parser.expasy_rules[enzyme], + missed_cleavages=missed_cleavages, + semi=semi, + ) + protein = header.split()[0] + for pep in pep_set: + if len(pep) < min_peptide_length or len(pep) > max_peptide_length: + continue + if any( + aa in pep for aa in "BJOUXZ" + ): # Check for incorrect AA letters + logger.warn( + "Skipping peptide with ambiguous amino acids: %s", pep + ) + continue + peptide_list.append((pep, protein)) # Generate modified peptides mass_calculator = depthcharge.masses.PeptideMass(residues="massivekb") @@ -136,7 +138,7 @@ def get_candidates( precursor_mz: float, charge: int, peptide_list: List[Tuple[str, float, str]], - precursor_tolerance: int, + precursor_tolerance: float, isotope_error: str, ): """ @@ -156,7 +158,6 @@ def get_candidates( isotope_error : str The isotope error levels to consider. """ - candidates = set() isotope_error = [int(x) for x in isotope_error.split(",")] @@ -219,7 +220,9 @@ def _to_raw_mass(mz_mass, charge): def get_mass_indices(masses, m_low, m_high): - """Grabs mass indices from a list of mass values that fall within a specified range. + """Grabs mass indices that fall within a specified range. + + Pulls from masses, a list of mass values. Requires that the mass values are sorted in ascending order. Parameters diff --git a/casanovo/denovo/dataloaders.py b/casanovo/denovo/dataloaders.py index 80a4f7dc..14a0ff99 100644 --- a/casanovo/denovo/dataloaders.py +++ b/casanovo/denovo/dataloaders.py @@ -2,20 +2,20 @@ import functools import os -from typing import List, Optional, Tuple -from functools import partial import logging +from typing import List, Optional, Tuple +from depthcharge.data import AnnotatedSpectrumIndex import lightning.pytorch as pl import numpy as np import torch -from depthcharge.data import AnnotatedSpectrumIndex +from ..data import db_utils from ..data.datasets import ( AnnotatedSpectrumDataset, SpectrumDataset, ) -from ..data import db_utils + logger = logging.getLogger("casanovo") @@ -186,7 +186,7 @@ def db_dataloader(self) -> torch.utils.data.DataLoader: return torch.utils.data.DataLoader( self.test_dataset, batch_size=self.eval_batch_size, - collate_fn=partial( + collate_fn=functools.partial( prepare_psm_batch, digest=self.digest, precursor_tolerance=self.precursor_tolerance, diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 3a069dcd..79848682 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -16,7 +16,7 @@ from . import evaluate from .. import config -from ..data import ms_io, db_utils +from ..data import ms_io logger = logging.getLogger("casanovo") @@ -991,7 +991,8 @@ def configure_optimizers( class DbSpec2Pep(Spec2Pep): """ - Subclass of Spec2Pep for the use of Casanovo as an MS/MS database search score function. + Subclass of Spec2Pep for the use of Casanovo as an \ + MS/MS database search score function. Uses teacher forcing to 'query' Casanovo for its score for each AA within a candidate peptide, and takes the geometric average of these scores @@ -1008,7 +1009,6 @@ class DbSpec2Pep(Spec2Pep): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.total_psms = 0 self.psm_batch_size = 1024 def predict_step(self, batch, *args): @@ -1029,11 +1029,14 @@ def predict_step(self, batch, *args): scores, amino acid-level scores, and associated proteins. """ predictions = [] - while len(batch[0]) > 0: - next_batch = [b[self.psm_batch_size :] for b in batch] - batch = [b[: self.psm_batch_size] for b in batch] + for start_idx in range(0, len(batch[0]), self.psm_batch_size): + current_batch = [ + b[start_idx : start_idx + self.psm_batch_size] for b in batch + ] pred, truth = self.decoder( - batch[3], batch[1], *self.encoder(batch[0]) + current_batch[3], + current_batch[1], + *self.encoder(current_batch[0]), ) pred = self.softmax(pred) all_scores, per_aa_scores = _calc_match_score( @@ -1048,13 +1051,13 @@ def predict_step(self, batch, *args): peptide, protein, ) in zip( - batch[1][:, 1].cpu().detach().numpy(), - batch[1][:, 2].cpu().detach().numpy(), - batch[2], + current_batch[1][:, 1].cpu().detach().numpy(), + current_batch[1][:, 2].cpu().detach().numpy(), + current_batch[2], all_scores.cpu().detach().numpy(), per_aa_scores.cpu().detach().numpy(), - batch[3], - batch[4], + current_batch[3], + current_batch[4], ): predictions.append( ( @@ -1067,8 +1070,6 @@ def predict_step(self, batch, *args): protein, ) ) - batch = next_batch - self.total_psms += len(predictions) return predictions def on_predict_batch_end( @@ -1088,8 +1089,6 @@ def on_predict_batch_end( aa_scores, protein, ) in outputs: - if len(peptide) == 0: - continue self.out_writer.psms.append( ( peptide, diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index a6b59ed9..c2b71098 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -10,8 +10,6 @@ from pathlib import Path from typing import Iterable, List, Optional, Union -import time - import lightning.pytorch as pl import numpy as np import torch @@ -20,7 +18,7 @@ from lightning.pytorch.callbacks import ModelCheckpoint from ..config import Config -from ..data import ms_io, db_utils +from ..data import db_utils, ms_io from ..denovo.dataloaders import DeNovoDataModule from ..denovo.model import Spec2Pep, DbSpec2Pep @@ -89,8 +87,8 @@ def db_search( digestion: str, missed_cleavages: int, max_mods: int, - min_length: int, - max_length: int, + min_peptide_length: int, + max_peptide_length: int, precursor_tolerance: float, isotope_error: str, output: str, @@ -100,7 +98,7 @@ def db_search( Parameters ---------- peak_path : Iterable[str] - The path to the .mgf data file for database search. + The paths to the .mgf data files for database search. fasta_path : str The path to the FASTA file for database search. enzyme : str @@ -111,9 +109,9 @@ def db_search( The number of missed cleavages allowed. max_mods : int The maximum number of modifications allowed per peptide. - min_length : int + min_peptide_length : int The minimum peptide length. - max_length : int + max_peptide_length : int The maximum peptide length. precursor_tolerance : float The precursor mass tolerance in ppm. @@ -147,19 +145,13 @@ def db_search( digestion, missed_cleavages, max_mods, - min_length, - max_length, + min_peptide_length, + max_peptide_length, ) self.loaders.precursor_tolerance = precursor_tolerance self.loaders.isotope_error = isotope_error - t1 = time.time() self.trainer.predict(self.model, self.loaders.db_dataloader()) - t2 = time.time() - logger.info("Database search took %.3f seconds", t2 - t1) - logger.info("Scored %s PSMs", self.model.total_psms) - logger.info("%.3f PSMs per second", self.model.total_psms / (t2 - t1)) - logger.info("%s seconds per PSM", (t2 - t1) / self.model.total_psms) def train( self, diff --git a/tests/conftest.py b/tests/conftest.py index b2244308..60afcd83 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,19 +17,16 @@ def mgf_small(tmp_path): @pytest.fixture -def tiny_fasta_file(tmp_path, fasta_raw_data): +def tiny_fasta_file(tmp_path): fasta_file = tmp_path / "tiny_fasta.fasta" with fasta_file.open("w+") as fasta_ref: - fasta_ref.write(fasta_raw_data) + fasta_ref.write( + ">foo\nMEAPAQLLFLLLLWLPDTTREIVMTQSPPTLSLSPGERVTLSCRASQSVSSSYLTWYQQKPGQAPRLLIYGASTRATSIPARFSGSGSGTDFTLTISSLQPEDFAVYYCQQDYNLP" + ) return fasta_file -@pytest.fixture -def fasta_raw_data(): - return ">foo\nMEAPAQLLFLLLLWLPDTTREIVMTQSPPTLSLSPGERVTLSCRASQSVSSSYLTWYQQKPGQAPRLLIYGASTRATSIPARFSGSGSGTDFTLTISSLQPEDFAVYYCQQDYNLP" - - @pytest.fixture def mgf_db_search(tmp_path): """An MGF file with 7 spectra and scan numbers, C+57.021 mass modification considered""" diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index e3707917..419cf3ef 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -2,6 +2,7 @@ import heapq import os import platform +import re import shutil import tempfile @@ -10,11 +11,10 @@ import numpy as np import pytest import torch -import re from casanovo import casanovo from casanovo import utils -from casanovo.data import ms_io, db_utils +from casanovo.data import db_utils, ms_io from casanovo.data.datasets import SpectrumDataset, AnnotatedSpectrumDataset from casanovo.denovo.evaluate import aa_match_batch, aa_match_metrics from casanovo.denovo.model import Spec2Pep, _aa_pep_score, _calc_match_score @@ -220,10 +220,7 @@ def test_calc_match_score(): assert np.sum(masked_per_aa_scores.numpy()[3]) == 3 -def test_digest_fasta_cleave(fasta_raw_data): - - with open("temp_fasta", "w") as file: - file.write(fasta_raw_data) +def test_digest_fasta_cleave(tiny_fasta_file): # No missed cleavages expected_normal = [ @@ -275,49 +272,24 @@ def test_digest_fasta_cleave(fasta_raw_data): "EIVMTQSPPTLSLSPGERVTLSC+57.021RASQSVSSSYLTWYQQKPGQAPR", "LLIYGASTRATSIPARFSGSGSGTDFTLTISSLQPEDFAVYYC+57.021QQDYNLP", ] + for missed_cleavages, expected in zip( + (0, 1, 3), + (expected_normal, expected_1missedcleavage, expected_3missedcleavage), + ): + peptide_list = db_utils.digest_fasta( + fasta_filename=str(tiny_fasta_file), + enzyme="trypsin", + digestion="full", + missed_cleavages=missed_cleavages, + max_mods=0, + min_peptide_length=6, + max_peptide_length=50, + ) + peptide_list = [x[0] for x in peptide_list] + assert peptide_list == expected - peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", - enzyme="trypsin", - digestion="full", - missed_cleavages=0, - max_mods=0, - min_length=6, - max_length=50, - ) - peptide_list = [x[0] for x in peptide_list] - assert peptide_list == expected_normal - - peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", - enzyme="trypsin", - digestion="full", - missed_cleavages=1, - max_mods=0, - min_length=6, - max_length=50, - ) - peptide_list = [x[0] for x in peptide_list] - assert peptide_list == expected_1missedcleavage - - peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", - enzyme="trypsin", - digestion="full", - missed_cleavages=3, - max_mods=0, - min_length=6, - max_length=50, - ) - peptide_list = [x[0] for x in peptide_list] - assert peptide_list == expected_3missedcleavage - - -def test_digest_fasta_mods(fasta_raw_data): - - with open("temp_fasta", "w") as file: - file.write(fasta_raw_data) +def test_digest_fasta_mods(tiny_fasta_file): # 1 modification allowed # fixed: C+57.02146 # variable: 1M+15.994915,1N+0.984016,1Q+0.984016 @@ -373,13 +345,13 @@ def test_digest_fasta_mods(fasta_raw_data): ] peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="trypsin", digestion="full", missed_cleavages=0, max_mods=1, - min_length=6, - max_length=50, + min_peptide_length=6, + max_peptide_length=50, ) peptide_list = [x[0] for x in peptide_list] peptide_list = [ @@ -392,11 +364,7 @@ def test_digest_fasta_mods(fasta_raw_data): assert peptide_list == expected_1mod -def test_length_restrictions(fasta_raw_data): - - with open("temp_fasta", "w") as file: - file.write(fasta_raw_data) - +def test_length_restrictions(tiny_fasta_file): # length between 20 and 50 expected_long = [ "MEAPAQLLFLLLLWLPDTTR", @@ -408,35 +376,31 @@ def test_length_restrictions(fasta_raw_data): expected_short = ["ATSIPAR", "VTLSC+57.021R"] peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="trypsin", digestion="full", missed_cleavages=0, max_mods=0, - min_length=20, - max_length=50, + min_peptide_length=20, + max_peptide_length=50, ) peptide_list = [x[0] for x in peptide_list] assert peptide_list == expected_long peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="trypsin", digestion="full", missed_cleavages=0, max_mods=0, - min_length=6, - max_length=8, + min_peptide_length=6, + max_peptide_length=8, ) peptide_list = [x[0] for x in peptide_list] assert peptide_list == expected_short -def test_digest_fasta_enzyme(fasta_raw_data): - - with open("temp_fasta", "w") as file: - file.write(fasta_raw_data) - +def test_digest_fasta_enzyme(tiny_fasta_file): # arg-c enzyme expected_argc = [ "ATSIPAR", @@ -452,35 +416,31 @@ def test_digest_fasta_enzyme(fasta_raw_data): expected_aspn = ["DFAVYYC+57.021QQ", "DFTLTISSLQPE", "MEAPAQLLFLLLLWLP"] peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="arg-c", digestion="full", missed_cleavages=0, max_mods=0, - min_length=6, - max_length=50, + min_peptide_length=6, + max_peptide_length=50, ) peptide_list = [x[0] for x in peptide_list] assert peptide_list == expected_argc peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="asp-n", digestion="full", missed_cleavages=0, max_mods=0, - min_length=6, - max_length=50, + min_peptide_length=6, + max_peptide_length=50, ) peptide_list = [x[0] for x in peptide_list] assert peptide_list == expected_aspn -def test_get_candidates(fasta_raw_data): - - with open("temp_fasta", "w") as file: - file.write(fasta_raw_data) - +def test_get_candidates(tiny_fasta_file): # precursor_window is 10000 expected_smallwindow = ["LLIYGASTR"] @@ -491,13 +451,13 @@ def test_get_candidates(fasta_raw_data): expected_widewindow = ["ATSIPAR", "VTLSC+57.021R", "LLIYGASTR"] peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="trypsin", digestion="full", missed_cleavages=1, max_mods=0, - min_length=6, - max_length=50, + min_peptide_length=6, + max_peptide_length=50, ) candidates = db_utils.get_candidates( @@ -511,13 +471,13 @@ def test_get_candidates(fasta_raw_data): assert expected_smallwindow == candidates peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="trypsin", digestion="full", missed_cleavages=1, max_mods=0, - min_length=6, - max_length=50, + min_peptide_length=6, + max_peptide_length=50, ) candidates = db_utils.get_candidates( @@ -531,13 +491,13 @@ def test_get_candidates(fasta_raw_data): assert expected_midwindow == candidates peptide_list = db_utils.digest_fasta( - fasta_filename="temp_fasta", + fasta_filename=str(tiny_fasta_file), enzyme="trypsin", digestion="full", missed_cleavages=1, max_mods=0, - min_length=6, - max_length=50, + min_peptide_length=6, + max_peptide_length=50, ) candidates = db_utils.get_candidates(