diff --git a/casanovo/data/ms_io.py b/casanovo/data/ms_io.py index d1e937f9..50932399 100644 --- a/casanovo/data/ms_io.py +++ b/casanovo/data/ms_io.py @@ -2,54 +2,17 @@ import collections import csv -import dataclasses import operator import os import re from pathlib import Path -from typing import List, Tuple, Iterable +from typing import List import natsort from .. import __version__ from ..config import Config - - -@dataclasses.dataclass -class PepSpecMatch: - """ - Peptide Spectrum Match (PSM) dataclass - - Parameters - ---------- - sequence : str - The amino acid sequence of the peptide. - spectrum_id : Tuple[str, str] - A tuple containing the spectrum identifier in the form - (spectrum file name, spectrum file idx) - peptide_score : float - Score of the match between the full peptide sequence and the - spectrum. - charge : int - The precursor charge state of the peptide ion observed in the spectrum. - calc_mz : float - The calculated mass-to-charge ratio (m/z) of the peptide based on its - sequence and charge state. - exp_mz : float - The observed (experimental) precursor mass-to-charge ratio (m/z) of the - peptide as detected in the spectrum. - aa_scores : Iterable[float] - A list of scores for individual amino acids in the peptide - sequence, where len(aa_scores) == len(sequence) - """ - - sequence: str - spectrum_id: Tuple[str, str] - peptide_score: float - charge: int - calc_mz: float - exp_mz: float - aa_scores: Iterable[float] +from .psm import PepSpecMatch class MztabWriter: diff --git a/casanovo/data/psm.py b/casanovo/data/psm.py new file mode 100644 index 00000000..0dc3c48b --- /dev/null +++ b/casanovo/data/psm.py @@ -0,0 +1,41 @@ +"""Peptide spectrum match dataclass""" + +import dataclasses +from typing import Tuple, Iterable + + +@dataclasses.dataclass +class PepSpecMatch: + """ + Peptide Spectrum Match (PSM) dataclass + + Parameters + ---------- + sequence : str + The amino acid sequence of the peptide. + spectrum_id : Tuple[str, str] + A tuple containing the spectrum identifier in the form + (spectrum file name, spectrum file idx) + peptide_score : float + Score of the match between the full peptide sequence and the + spectrum. + charge : int + The precursor charge state of the peptide ion observed in the spectrum. + calc_mz : float + The calculated mass-to-charge ratio (m/z) of the peptide based on its + sequence and charge state. + exp_mz : float + The observed (experimental) precursor mass-to-charge ratio (m/z) of the + peptide as detected in the spectrum. + aa_scores : Iterable[float] + A list of scores for individual amino acids in the peptide + sequence, where len(aa_scores) == len(sequence) + """ + + sequence: str + spectrum_id: Tuple[str, str] + peptide_score: float + charge: int + calc_mz: float + exp_mz: float + aa_scores: Iterable[float] diff --git a/casanovo/denovo/evaluate.py b/casanovo/denovo/evaluate.py index cbf9e74f..6bc1ff2e 100644 --- a/casanovo/denovo/evaluate.py +++ b/casanovo/denovo/evaluate.py @@ -127,8 +127,8 @@ def aa_match_prefix_suffix( def aa_match( - peptide1: List[str], - peptide2: List[str], + peptide1: List[str] | None, + peptide2: List[str] | None, aa_dict: Dict[str, float], cum_mass_threshold: float = 0.5, ind_mass_threshold: float = 0.1, @@ -139,9 +139,9 @@ def aa_match( Parameters ---------- - peptide1 : List[str] + peptide1 : List[str] | None, The first tokenized peptide sequence to be compared. - peptide2 : List[str] + peptide2 : List[str] | None The second tokenized peptide sequence to be compared. aa_dict : Dict[str, float] Mapping of amino acid tokens to their mass values. @@ -161,7 +161,12 @@ def aa_match( pep_match : bool Boolean flag to indicate whether the two peptide sequences fully match. """ - if mode == "best": + if peptide1 is None and peptide2 is None: + return np.empty(0, dtype=bool), False + elif peptide1 is None or peptide2 is None: + peptide = peptide1 if peptide2 is None else peptide2 + return np.zeros(len(peptide), dtype=bool), False + elif mode == "best": return aa_match_prefix_suffix( peptide1, peptide2, aa_dict, cum_mass_threshold, ind_mass_threshold ) @@ -225,9 +230,12 @@ def aa_match_batch( # Split peptides into individual AAs if necessary. if isinstance(peptide1, str): peptide1 = re.split(r"(?<=.)(?=[A-Z])", peptide1) + if isinstance(peptide2, str): peptide2 = re.split(r"(?<=.)(?=[A-Z])", peptide2) - n_aa1, n_aa2 = n_aa1 + len(peptide1), n_aa2 + len(peptide2) + + n_aa1 += 0 if peptide1 is None else len(peptide1) + n_aa2 += 0 if peptide2 is None else len(peptide2) aa_matches_batch.append( aa_match( peptide1, diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 6e984a1d..ce8621d8 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -16,7 +16,7 @@ from . import evaluate from .. import config -from ..data import ms_io +from ..data import ms_io, psm logger = logging.getLogger("casanovo") @@ -914,7 +914,7 @@ def on_predict_batch_end( if len(peptide) == 0: continue self.out_writer.psms.append( - ms_io.PepSpecMatch( + psm.PepSpecMatch( sequence=peptide, spectrum_id=tuple(spectrum_i), peptide_score=peptide_score, diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 6c3271d0..1b4ade13 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -167,20 +167,38 @@ def log_metrics(self, test_index: AnnotatedSpectrumIndex) -> None: Index containing the annotated spectra used to generate model predictions """ - model_output = [psm.sequence for psm in self.writer.psms] - spectrum_annotations = [ - test_index[i][4] for i in range(test_index.n_spectra) - ] - aa_precision, _, pep_precision = aa_match_metrics( + seq_pred = [] + seq_true = [] + pred_idx = 0 + + with test_index as t_ind: + for true_idx in range(t_ind.n_spectra): + seq_true.append(t_ind[true_idx][4]) + if pred_idx < len(self.writer.psms) and self.writer.psms[ + pred_idx + ].spectrum_id == t_ind.get_spectrum_id(true_idx): + seq_pred.append(self.writer.psms[pred_idx].sequence) + pred_idx += 1 + else: + seq_pred.append(None) + + aa_precision, aa_recall, pep_precision = aa_match_metrics( *aa_match_batch( - spectrum_annotations, - model_output, + seq_true, + seq_pred, depthcharge.masses.PeptideMass().masses, ) ) + if self.config["top_match"] > 1: + logger.warning( + "The behavior for calculating evaluation metrics is undefined when " + "the 'top_match' configuration option is set to a value greater than 1." + ) + logger.info("Peptide Precision: %.2f%%", 100 * pep_precision) logger.info("Amino Acid Precision: %.2f%%", 100 * aa_precision) + logger.info("Amino Acid Recall: %.2f%%", 100 * aa_recall) def predict( self, @@ -259,10 +277,10 @@ def initialize_trainer(self, train: bool) -> None: strategy=self._get_strategy(), val_check_interval=self.config.val_check_interval, check_val_every_n_epoch=None, - log_every_n_steps=self.config.get("log_every_n_steps"), + log_every_n_steps=self.config.log_every_n_steps, ) - if self.config.get("log_metrics"): + if self.config.log_metrics: if not self.output_dir: logger.warning( "Output directory not set in model runner. " @@ -283,9 +301,7 @@ def initialize_trainer(self, train: bool) -> None: version=csv_log_dir, name=None, ), - "log_every_n_steps": self.config.get( - "log_every_n_steps" - ), + "log_every_n_steps": self.config.log_every_n_steps, } ) diff --git a/casanovo/utils.py b/casanovo/utils.py index 1161b5eb..43b1cb7d 100644 --- a/casanovo/utils.py +++ b/casanovo/utils.py @@ -15,7 +15,7 @@ import psutil import torch -from .data.ms_io import PepSpecMatch +from .data.psm import PepSpecMatch SCORE_BINS = [0.0, 0.5, 0.9, 0.95, 0.99] diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py index 812ad19c..cf04cf83 100644 --- a/tests/unit_tests/test_runner.py +++ b/tests/unit_tests/test_runner.py @@ -1,12 +1,14 @@ """Unit tests specifically for the model_runner module.""" import shutil +import unittest.mock from pathlib import Path import pytest import torch from casanovo.config import Config +from casanovo.data.psm import PepSpecMatch from casanovo.denovo.model_runner import ModelRunner @@ -287,8 +289,8 @@ def test_evaluate( def test_metrics_logging(tmp_path, mgf_small, tiny_config): config = Config(tiny_config) - config._user_config["log_metrics"] = True - config._user_config["log_every_n_steps"] = 1 + config.log_metrics = True + config.log_every_n_steps = 1 config.tb_summarywriter = True config.max_epochs = 1 @@ -321,3 +323,185 @@ def test_metrics_logging(tmp_path, mgf_small, tiny_config): assert not best_model_path.is_file() assert not tb_path.is_dir() assert csv_path.is_dir() + + +def test_log_metrics(monkeypatch, tiny_config): + def get_mock_index(psm_list): + mock_test_index = unittest.mock.MagicMock() + mock_test_index.__enter__.return_value = mock_test_index + mock_test_index.__exit__.return_value = False + mock_test_index.n_spectra = len(psm_list) + mock_test_index.get_spectrum_id = lambda idx: psm_list[idx].spectrum_id + + mock_spectra = [ + (None, None, None, None, curr_psm.sequence) + for curr_psm in psm_list + ] + mock_test_index.__getitem__.side_effect = lambda idx: mock_spectra[idx] + return mock_test_index + + def get_mock_psm(sequence, spectrum_id): + return PepSpecMatch( + sequence=sequence, + spectrum_id=spectrum_id, + peptide_score=None, + charge=None, + exp_mz=None, + aa_scores=None, + calc_mz=None, + ) + + with monkeypatch.context() as ctx: + mock_logger = unittest.mock.MagicMock() + ctx.setattr("casanovo.denovo.model_runner.logger", mock_logger) + + with ModelRunner(Config(tiny_config)) as runner: + runner.writer = unittest.mock.MagicMock() + + # Test 100% peptide precision + infer_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PET", ("foo", "index=2")), + ] + + act_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PET", ("foo", "index=2")), + ] + + runner.writer.psms = infer_psms + mock_index = get_mock_index(act_psms) + runner.log_metrics(mock_index) + + pep_precision = mock_logger.info.call_args_list[-3][0][1] + aa_precision = mock_logger.info.call_args_list[-2][0][1] + aa_recall = mock_logger.info.call_args_list[-1][0][1] + assert pep_precision == pytest.approx(100) + assert aa_precision == pytest.approx(100) + assert aa_recall == pytest.approx(100) + + # Test 50% peptide precision (one wrong) + infer_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PET", ("foo", "index=2")), + ] + + act_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PEP", ("foo", "index=2")), + ] + + runner.writer.psms = infer_psms + mock_index = get_mock_index(act_psms) + runner.log_metrics(mock_index) + + pep_precision = mock_logger.info.call_args_list[-3][0][1] + aa_precision = mock_logger.info.call_args_list[-2][0][1] + aa_recall = mock_logger.info.call_args_list[-1][0][1] + assert pep_precision == pytest.approx(100 * (1 / 2)) + assert aa_precision == pytest.approx(100 * (5 / 6)) + assert aa_recall == pytest.approx(100 * (5 / 6)) + + # Test skipped spectra + act_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PET", ("foo", "index=2")), + get_mock_psm("PEI", ("foo", "index=3")), + get_mock_psm("PEG", ("foo", "index=4")), + get_mock_psm("PEA", ("foo", "index=5")), + ] + + infer_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PET", ("foo", "index=2")), + get_mock_psm("PEI", ("foo", "index=3")), + get_mock_psm("PEA", ("foo", "index=5")), + ] + + runner.writer.psms = infer_psms + mock_index = get_mock_index(act_psms) + runner.log_metrics(mock_index) + + pep_precision = mock_logger.info.call_args_list[-3][0][1] + aa_precision = mock_logger.info.call_args_list[-2][0][1] + aa_recall = mock_logger.info.call_args_list[-1][0][1] + assert pep_precision == pytest.approx(100 * (4 / 5)) + assert aa_precision == pytest.approx(100) + assert aa_recall == pytest.approx(100 * (4 / 5)) + + infer_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PET", ("foo", "index=2")), + get_mock_psm("PEI", ("foo", "index=3")), + get_mock_psm("PEG", ("foo", "index=4")), + ] + + runner.writer.psms = infer_psms + mock_index = get_mock_index(act_psms) + runner.log_metrics(mock_index) + + pep_precision = mock_logger.info.call_args_list[-3][0][1] + aa_precision = mock_logger.info.call_args_list[-2][0][1] + aa_recall = mock_logger.info.call_args_list[-1][0][1] + assert pep_precision == pytest.approx(100 * (4 / 5)) + assert aa_precision == pytest.approx(100) + assert aa_recall == pytest.approx(100 * (4 / 5)) + + infer_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PEI", ("foo", "index=3")), + ] + + runner.writer.psms = infer_psms + mock_index = get_mock_index(act_psms) + runner.log_metrics(mock_index) + + pep_precision = mock_logger.info.call_args_list[-3][0][1] + aa_precision = mock_logger.info.call_args_list[-2][0][1] + aa_recall = mock_logger.info.call_args_list[-1][0][1] + assert pep_precision == pytest.approx(100 * (2 / 5)) + assert aa_precision == pytest.approx(100) + assert aa_recall == pytest.approx(100 * (2 / 5)) + + infer_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PEA", ("foo", "index=5")), + ] + + runner.writer.psms = infer_psms + mock_index = get_mock_index(act_psms) + runner.log_metrics(mock_index) + + pep_precision = mock_logger.info.call_args_list[-3][0][1] + aa_precision = mock_logger.info.call_args_list[-2][0][1] + aa_recall = mock_logger.info.call_args_list[-1][0][1] + assert pep_precision == pytest.approx(100 * (2 / 5)) + assert aa_precision == pytest.approx(100) + assert aa_recall == pytest.approx(100 * (2 / 5)) + + # Test un-inferred spectra + act_psms = [ + get_mock_psm("PEP", ("foo", "index=1")), + get_mock_psm("PET", ("foo", "index=2")), + get_mock_psm("PEI", ("foo", "index=3")), + get_mock_psm("PEG", ("foo", "index=4")), + ] + + infer_psms = [ + get_mock_psm("PE", ("foo", "index=1")), + get_mock_psm("PE", ("foo", "index=2")), + get_mock_psm("PE", ("foo", "index=3")), + get_mock_psm("PE", ("foo", "index=4")), + get_mock_psm("PE", ("foo", "index=5")), + ] + + runner.writer.psms = infer_psms + mock_index = get_mock_index(act_psms) + runner.log_metrics(mock_index) + + pep_precision = mock_logger.info.call_args_list[-3][0][1] + aa_precision = mock_logger.info.call_args_list[-2][0][1] + aa_recall = mock_logger.info.call_args_list[-1][0][1] + assert pep_precision == pytest.approx(0) + assert aa_precision == pytest.approx(100) + assert aa_recall == pytest.approx(100 * (2 / 3)) diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index 18136ab2..c2c5b628 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -14,6 +14,7 @@ import unittest import unittest.mock +import depthcharge.masses import einops import github import numpy as np @@ -24,7 +25,7 @@ from casanovo import utils from casanovo.data import ms_io from casanovo.data.datasets import SpectrumDataset, AnnotatedSpectrumDataset -from casanovo.denovo.evaluate import aa_match_batch, aa_match_metrics +from casanovo.denovo.evaluate import aa_match_batch, aa_match_metrics, aa_match from casanovo.denovo.model import Spec2Pep, _aa_pep_score from depthcharge.data import SpectrumIndex, AnnotatedSpectrumIndex @@ -846,6 +847,20 @@ def test_eval_metrics(): assert 26 / 40 == pytest.approx(aa_recall) assert 26 / 41 == pytest.approx(aa_precision) + aa_matches, pep_match = aa_match( + None, None, depthcharge.masses.PeptideMass().masses + ) + + assert aa_matches.shape == (0,) + assert not pep_match + + aa_matches, pep_match = aa_match( + "PEPTIDE", None, depthcharge.masses.PeptideMass().masses + ) + + assert np.array_equal(aa_matches, np.zeros(len("PEPTIDE"), dtype=bool)) + assert not pep_match + def test_spectrum_id_mgf(mgf_small, tmp_path): """Test that spectra from MGF files are specified by their index."""