Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eval metrics and circular import bug fix. #380

Merged
merged 33 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
617fcb8
eval metrics bug fix
Lilferrit Sep 12, 2024
a52ba83
better eval metrics bug fix
Lilferrit Sep 12, 2024
81f4515
eval metrics bug fix
Lilferrit Sep 12, 2024
e30b674
better eval metrics bug fix
Lilferrit Sep 12, 2024
7b6bab3
eval stats unit test, circular import fix
Lilferrit Sep 16, 2024
ddbc93a
log metrics unit test
Lilferrit Sep 17, 2024
00fd170
resolved upstream merge conflict
Lilferrit Sep 17, 2024
9d4109e
removed unused import
Lilferrit Sep 17, 2024
86747d9
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
c863b4a
aa_match_batch hanles none, additional skipped spectra test cases
Lilferrit Sep 20, 2024
34c456d
Log optimizer and training metrics to CSV file (#376)
Lilferrit Sep 20, 2024
8f21edb
aa_match_batch and aa_match handle None
Lilferrit Sep 23, 2024
217eeb8
top_match eval metrics warning
Lilferrit Sep 23, 2024
3b27582
removed unused import
Lilferrit Sep 17, 2024
4e89028
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
64a681f
aa_match_batch hanles none, additional skipped spectra test cases
Lilferrit Sep 20, 2024
a3d5763
aa_match_batch and aa_match handle None
Lilferrit Sep 23, 2024
8be20ab
top_match eval metrics warning
Lilferrit Sep 23, 2024
60d4159
Merge branch 'eval-metrics-fix' of github.com:Noble-Lab/casanovo into…
Lilferrit Sep 23, 2024
5f38ea8
eval metrics bug fix
Lilferrit Sep 12, 2024
8b6e925
better eval metrics bug fix
Lilferrit Sep 12, 2024
bacf243
eval stats unit test, circular import fix
Lilferrit Sep 16, 2024
5bbbe6f
log metrics unit test
Lilferrit Sep 17, 2024
4788fab
removed unused import
Lilferrit Sep 17, 2024
c473f20
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
63ac6ad
aa_match_batch hanles none, additional skipped spectra test cases
Lilferrit Sep 20, 2024
7b4b6e6
aa_match_batch and aa_match handle None
Lilferrit Sep 23, 2024
78bb897
top_match eval metrics warning
Lilferrit Sep 23, 2024
fb975b2
removed unused import
Lilferrit Sep 17, 2024
692cd7e
log metrics refactor, additional log metrics test case
Lilferrit Sep 19, 2024
7740a77
metrics file logging bug fix
Lilferrit Sep 23, 2024
e9bb5ec
merge conflicts
Lilferrit Sep 23, 2024
60524af
aa_match test cases, minor aa_match refactor
Lilferrit Sep 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

- During training, model checkpoints will be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run.
- Besides as a local file, model weights can be specified from a URL. Upon initial download, the weights file is cached for future re-use.
- Training and optimizer metrics can now be logged to a CSV file by setting the `log_metrics` config file option to true - the CSV file will be written to under a sub-directory of the output directory named `csv_logs`.

### Changed

- Removed the `evaluate` sub-command, and all model evaluation functionality has been moved to the `sequence` command using the new `--evaluate` flag.
- The `--output` option has been split into two options, `--output_dir` and `--output_root`.
- The `--validation_peak_path` is now optional when training; if `--validation_peak_path` is not set then the `train_peak_path` will also be used for validation.
- The `tb_summarywriter` config option is now a boolean config option, and if set to true the TensorBoard summary will be written to a sub-directory of the output directory named `tensorboard`.
bittremieux marked this conversation as resolved.
Show resolved Hide resolved

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@

runner.predict(
peak_path,
str((output_path / output_root).with_suffix(".mztab")),
str((output_path / output_root_name).with_suffix(".mztab")),
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
evaluate=evaluate,
)
psms = runner.writer.psms
Expand Down Expand Up @@ -253,7 +253,7 @@
logger.info(" %s", peak_file)

if len(validation_peak_path) == 0:
validation_peak_path = train_peak_path

Check warning on line 256 in casanovo/casanovo.py

View check run for this annotation

Codecov / codecov/patch

casanovo/casanovo.py#L256

Added line #L256 was not covered by tests

logger.info("Using the following validation files:")
for peak_file in validation_peak_path:
Expand Down
2 changes: 2 additions & 0 deletions casanovo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class Config:
residues=dict,
n_log=int,
tb_summarywriter=bool,
log_metrics=bool,
log_every_n_steps=int,
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
train_label_smoothing=float,
warmup_iters=int,
cosine_schedule_period_iters=int,
Expand Down
4 changes: 4 additions & 0 deletions casanovo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ random_seed: 454
n_log: 1
# Whether to create tensorboard directory
tb_summarywriter: false
# Whether to create csv_logs directory
log_metrics: false
# How often to log optimizer parameters in steps
log_every_n_steps: 50
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
# Model validation and checkpointing frequency in training steps.
val_check_interval: 50_000

Expand Down
41 changes: 2 additions & 39 deletions casanovo/data/ms_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,17 @@

import collections
import csv
import dataclasses
import operator
import os
import re
from pathlib import Path
from typing import List, Tuple, Iterable
from typing import List

import natsort

from .. import __version__
from ..config import Config


@dataclasses.dataclass
class PepSpecMatch:
"""
Peptide Spectrum Match (PSM) dataclass

Parameters
----------
sequence : str
The amino acid sequence of the peptide.
spectrum_id : Tuple[str, str]
A tuple containing the spectrum identifier in the form
(spectrum file name, spectrum file idx)
peptide_score : float
Score of the match between the full peptide sequence and the
spectrum.
charge : int
The precursor charge state of the peptide ion observed in the spectrum.
calc_mz : float
The calculated mass-to-charge ratio (m/z) of the peptide based on its
sequence and charge state.
exp_mz : float
The observed (experimental) precursor mass-to-charge ratio (m/z) of the
peptide as detected in the spectrum.
aa_scores : Iterable[float]
A list of scores for individual amino acids in the peptide
sequence, where len(aa_scores) == len(sequence)
"""

sequence: str
spectrum_id: Tuple[str, str]
peptide_score: float
charge: int
calc_mz: float
exp_mz: float
aa_scores: Iterable[float]
from .psm import PepSpecMatch


class MztabWriter:
Expand Down
41 changes: 41 additions & 0 deletions casanovo/data/psm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Peptide spectrum match dataclass"""

import dataclasses
from typing import Tuple, Iterable


@dataclasses.dataclass
class PepSpecMatch:
"""
Peptide Spectrum Match (PSM) dataclass

Parameters
----------
sequence : str
The amino acid sequence of the peptide.
spectrum_id : Tuple[str, str]
A tuple containing the spectrum identifier in the form
(spectrum file name, spectrum file idx)
peptide_score : float
Score of the match between the full peptide sequence and the
spectrum.
charge : int
The precursor charge state of the peptide ion observed in the spectrum.
calc_mz : float
The calculated mass-to-charge ratio (m/z) of the peptide based on its
sequence and charge state.
exp_mz : float
The observed (experimental) precursor mass-to-charge ratio (m/z) of the
peptide as detected in the spectrum.
aa_scores : Iterable[float]
A list of scores for individual amino acids in the peptide
sequence, where len(aa_scores) == len(sequence)
"""

sequence: str
spectrum_id: Tuple[str, str]
peptide_score: float
charge: int
calc_mz: float
exp_mz: float
aa_scores: Iterable[float]
20 changes: 14 additions & 6 deletions casanovo/denovo/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@


def aa_match(
peptide1: List[str],
peptide2: List[str],
peptide1: List[str] | None,
peptide2: List[str] | None,
aa_dict: Dict[str, float],
cum_mass_threshold: float = 0.5,
ind_mass_threshold: float = 0.1,
Expand All @@ -139,9 +139,9 @@

Parameters
----------
peptide1 : List[str]
peptide1 : List[str] | None,
The first tokenized peptide sequence to be compared.
peptide2 : List[str]
peptide2 : List[str] | None
The second tokenized peptide sequence to be compared.
aa_dict : Dict[str, float]
Mapping of amino acid tokens to their mass values.
Expand All @@ -161,7 +161,12 @@
pep_match : bool
Boolean flag to indicate whether the two peptide sequences fully match.
"""
if mode == "best":
if peptide1 is None and peptide2 is None:
return np.array([], dtype=bool), False

Check warning on line 165 in casanovo/denovo/evaluate.py

View check run for this annotation

Codecov / codecov/patch

casanovo/denovo/evaluate.py#L165

Added line #L165 was not covered by tests
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
elif (peptide1 is None) != (peptide2 is None):
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
peptide = peptide1 if peptide2 is None else peptide2
return np.array([False] * len(peptide)), False
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
elif mode == "best":
return aa_match_prefix_suffix(
peptide1, peptide2, aa_dict, cum_mass_threshold, ind_mass_threshold
)
Expand Down Expand Up @@ -225,9 +230,12 @@
# Split peptides into individual AAs if necessary.
if isinstance(peptide1, str):
peptide1 = re.split(r"(?<=.)(?=[A-Z])", peptide1)

if isinstance(peptide2, str):
peptide2 = re.split(r"(?<=.)(?=[A-Z])", peptide2)
n_aa1, n_aa2 = n_aa1 + len(peptide1), n_aa2 + len(peptide2)

n_aa1 += 0 if peptide1 is None else len(peptide1)
n_aa2 += 0 if peptide2 is None else len(peptide2)
aa_matches_batch.append(
aa_match(
peptide1,
Expand Down
4 changes: 2 additions & 2 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from . import evaluate
from .. import config
from ..data import ms_io
from ..data import ms_io, psm

logger = logging.getLogger("casanovo")

Expand Down Expand Up @@ -914,7 +914,7 @@ def on_predict_batch_end(
if len(peptide) == 0:
continue
self.out_writer.psms.append(
ms_io.PepSpecMatch(
psm.PepSpecMatch(
sequence=peptide,
spectrum_id=tuple(spectrum_i),
peptide_score=peptide_score,
Expand Down
65 changes: 57 additions & 8 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@

import depthcharge.masses
import lightning.pytorch as pl
import lightning.pytorch.loggers
import numpy as np
import torch
from depthcharge.data import AnnotatedSpectrumIndex, SpectrumIndex
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor

from .. import utils
from ..config import Config
Expand Down Expand Up @@ -63,6 +64,8 @@
self.config = config
self.model_filename = model_filename
self.output_dir = output_dir
self.output_rootname = output_rootname
self.overwrite_ckpt_check = overwrite_ckpt_check

# Initialized later:
self.tmp_dir = None
Expand Down Expand Up @@ -105,6 +108,7 @@
filename=best_filename,
enable_version_counter=False,
),
LearningRateMonitor(log_momentum=True, log_weight_decay=True),
]

def __enter__(self):
Expand Down Expand Up @@ -163,20 +167,38 @@
Index containing the annotated spectra used to generate model
predictions
"""
model_output = [psm.sequence for psm in self.writer.psms]
spectrum_annotations = [
test_index[i][4] for i in range(test_index.n_spectra)
]
aa_precision, _, pep_precision = aa_match_metrics(
seq_pred = []
seq_true = []
pred_idx = 0

with test_index as t_ind:
for true_idx in range(t_ind.n_spectra):
seq_true.append(t_ind[true_idx][4])
if pred_idx < len(self.writer.psms) and self.writer.psms[
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
pred_idx
].spectrum_id == t_ind.get_spectrum_id(true_idx):
seq_pred.append(self.writer.psms[pred_idx].sequence)
pred_idx += 1
else:
seq_pred.append(None)

aa_precision, aa_recall, pep_precision = aa_match_metrics(
*aa_match_batch(
spectrum_annotations,
model_output,
seq_true,
seq_pred,
depthcharge.masses.PeptideMass().masses,
)
)

if self.config["top_match"] > 1:
logger.warning(

Check warning on line 194 in casanovo/denovo/model_runner.py

View check run for this annotation

Codecov / codecov/patch

casanovo/denovo/model_runner.py#L194

Added line #L194 was not covered by tests
"The behavior for calculating evaluation metrics is undefined when "
"the 'top_match' configuration option is set to a value greater than 1."
)

logger.info("Peptide Precision: %.2f%%", 100 * pep_precision)
logger.info("Amino Acid Precision: %.2f%%", 100 * aa_precision)
logger.info("Amino Acid Recall: %.2f%%", 100 * aa_recall)

def predict(
self,
Expand Down Expand Up @@ -255,7 +277,34 @@
strategy=self._get_strategy(),
val_check_interval=self.config.val_check_interval,
check_val_every_n_epoch=None,
log_every_n_steps=self.config.log_every_n_steps,
)

if self.config.log_metrics:
if not self.output_dir:
logger.warning(

Check warning on line 285 in casanovo/denovo/model_runner.py

View check run for this annotation

Codecov / codecov/patch

casanovo/denovo/model_runner.py#L285

Added line #L285 was not covered by tests
"Output directory not set in model runner. "
"No loss file will be created."
)
else:
csv_log_dir = "csv_logs"
if self.overwrite_ckpt_check:
utils.check_dir_file_exists(
self.output_dir,
csv_log_dir,
)

additional_cfg.update(
{
"logger": lightning.pytorch.loggers.CSVLogger(
self.output_dir,
version=csv_log_dir,
name=None,
),
"log_every_n_steps": self.config.log_every_n_steps,
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
}
)

trainer_cfg.update(additional_cfg)

self.trainer = pl.Trainer(**trainer_cfg)
Expand All @@ -272,7 +321,7 @@
tb_summarywriter = None
if self.config.tb_summarywriter:
if self.output_dir is None:
logger.warning(

Check warning on line 324 in casanovo/denovo/model_runner.py

View check run for this annotation

Codecov / codecov/patch

casanovo/denovo/model_runner.py#L324

Added line #L324 was not covered by tests
"Can not create tensorboard because the output directory "
"is not set in the model runner."
)
Expand Down
2 changes: 1 addition & 1 deletion casanovo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import psutil
import torch

from .data.ms_io import PepSpecMatch
from .data.psm import PepSpecMatch


SCORE_BINS = [0.0, 0.5, 0.9, 0.95, 0.99]
Expand Down
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def tiny_config(tmp_path):
"devices": None,
"random_seed": 454,
"n_log": 1,
"tb_summarywriter": None,
"tb_summarywriter": False,
"log_metrics": False,
"log_every_n_steps": 50,
bittremieux marked this conversation as resolved.
Show resolved Hide resolved
"n_peaks": 150,
"min_mz": 50.0,
"max_mz": 2500.0,
Expand Down
Loading
Loading