Commit d01dd7f
updates to dataloaders, model_runner, and model.py
VarunAnanth2003 committed Nov 3, 2024
1 parent 68e67e8 commit d01dd7f
Showing 4 changed files with 72 additions and 201 deletions.
12 changes: 6 additions & 6 deletions casanovo/denovo/dataloaders.py
@@ -137,6 +137,7 @@ def _make_loader(
         dataset: torch.utils.data.Dataset,
         batch_size: int,
         shuffle: bool = False,
+        collate_fn: Optional[callable] = None,
     ) -> torch.utils.data.DataLoader:
         """
         Create a PyTorch DataLoader.
@@ -149,6 +150,8 @@ def _make_loader(
             The batch size to use.
         shuffle : bool
             Option to shuffle the batches.
+        collate_fn : Optional[callable]
+            A function to collate the data into a batch.

         Returns
         -------
@@ -158,7 +161,7 @@ def _make_loader(
         return torch.utils.data.DataLoader(
             dataset,
             batch_size=batch_size,
-            collate_fn=prepare_batch,
+            collate_fn=prepare_batch if collate_fn is None else collate_fn,
             pin_memory=True,
             num_workers=self.n_workers,
             shuffle=shuffle,
@@ -184,15 +187,12 @@ def predict_dataloader(self) -> torch.utils.data.DataLoader:

     def db_dataloader(self) -> torch.utils.data.DataLoader:
         """Get a special dataloader for DB search"""
-        return torch.utils.data.DataLoader(
+        return self._make_loader(
             self.test_dataset,
-            batch_size=self.eval_batch_size,
+            self.eval_batch_size,
             collate_fn=functools.partial(
                 prepare_psm_batch, protein_database=self.protein_database
             ),
-            pin_memory=True,
-            num_workers=self.n_workers,
-            shuffle=False,
         )
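A note on the pattern used here: db_dataloader now reuses _make_loader, overriding its default prepare_batch collate function through the new collate_fn parameter, and functools.partial pre-binds the protein database so the DataLoader still receives the single-argument callable it expects. A minimal, self-contained sketch of that pattern (the toy dataset and collate_with_db function below are hypothetical, not the repository's code):

    import functools

    import torch

    def collate_with_db(items, protein_database):
        # Toy stand-in for prepare_psm_batch: stack the sampled tensors
        # and pass the pre-bound database through alongside the batch.
        return torch.stack(items), protein_database

    dataset = [torch.tensor([float(i)]) for i in range(8)]
    db = {"PEPTIDE": "protein_1"}  # hypothetical protein database

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        # functools.partial binds the extra argument, leaving the
        # one-argument collate_fn that the DataLoader API expects.
        collate_fn=functools.partial(collate_with_db, protein_database=db),
    )

    for batch, attached_db in loader:
        print(batch.shape, attached_db)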


108 changes: 43 additions & 65 deletions casanovo/denovo/model.py
@@ -994,17 +994,12 @@ class DbSpec2Pep(Spec2Pep):
     Subclass of Spec2Pep for the use of Casanovo as an
     MS/MS database search score function.

-    Uses teacher forcing to 'query' Casanovo for its score for each AA
-    within a candidate peptide, and takes the geometric average of these scores
-    and reports this as the score for the spectrum-peptide pair. Note that the
-    geometric mean of the AA scores is actually calculated by a
-    summation and average of the log of the scores, to preserve numerical
-    stability. This does not affect PSM ranking.
+    Uses teacher forcing to 'query' Casanovo to score a peptide-spectrum
+    pair. Higher scores indicate a better match between the peptide and
+    spectrum. The amino acid-level scores are also returned.

     Also note that although teacher-forcing is used within this method,
     there is *no training* involved. This is a prediction-only method.

     Output is provided in .mztab format.
     """

     def __init__(self, *args, **kwargs):
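The scoring approach described in the docstring above feeds the known candidate peptide to the decoder as its input (teacher forcing) and reads off the probability the model assigns to each true amino acid. A minimal sketch of that lookup, assuming hypothetical tensor shapes of (batch, sequence, vocabulary); illustrative only, not the repository's implementation:

    import torch

    def teacher_forced_aa_scores(
        logits: torch.Tensor,  # (batch, seq_len, vocab) raw decoder outputs
        truth: torch.Tensor,   # (batch, seq_len) known peptide token indices
    ) -> torch.Tensor:
        # Convert logits to per-position probability distributions ...
        probs = logits.softmax(dim=-1)
        # ... then gather the probability assigned to each true amino acid.
        return probs.gather(-1, truth.unsqueeze(-1)).squeeze(-1)

    logits = torch.randn(2, 5, 28)        # 2 PSMs, 5 positions, 28 tokens
    truth = torch.randint(0, 28, (2, 5))  # ground-truth amino acid indices
    print(teacher_forced_aa_scores(logits, truth).shape)  # torch.Size([2, 5])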
@@ -1034,17 +1029,15 @@ def predict_step(self, batch, *args):
             current_batch = [
                 b[start_idx : start_idx + self.psm_batch_size] for b in batch
             ]
-            pred, truth = self.decoder(
-                current_batch[3],
-                current_batch[1],
-                *self.encoder(current_batch[0]),
+            pred, truth = self._forward_step(
+                current_batch[0], current_batch[1], current_batch[3]
             )
             pred = self.softmax(pred)
-            all_scores, per_aa_scores = _calc_match_score(
+            all_peptide_scores, all_aa_scores = _calc_match_score(
                 pred, truth, self.decoder.reverse
             )
             for (
-                precursor_charge,
+                charge,
                 precursor_mz,
                 spectrum_i,
                 peptide_score,
@@ -1054,27 +1047,32 @@ def predict_step(self, batch, *args):
                 current_batch[1][:, 1].cpu().detach().numpy(),
                 current_batch[1][:, 2].cpu().detach().numpy(),
                 current_batch[2],
-                all_scores.cpu().detach().numpy(),
-                per_aa_scores.cpu().detach().numpy(),
+                all_peptide_scores,
+                all_aa_scores,
                 current_batch[3],
             ):
-                store_dict[str(spectrum_i)].append(
-                    (
-                        spectrum_i,
-                        precursor_charge,
-                        precursor_mz,
-                        peptide,
-                        peptide_score,
-                        aa_scores,
-                        self.protein_database.get_associated_protein(peptide),
+                store_dict[spectrum_i].append(
+                    ms_io.PepSpecMatch(
+                        sequence=peptide,
+                        spectrum_id=tuple(spectrum_i),
+                        peptide_score=peptide_score,
+                        charge=int(charge),
+                        calc_mz=precursor_mz,
+                        exp_mz=self.peptide_mass_calculator.mass(
+                            peptide, charge
+                        ),
+                        aa_scores=aa_scores,
+                        protein=self.protein_database.get_associated_protein(
+                            peptide
+                        ),
                     )
                 )
         predictions = []
         for spectrum_i in store_dict:
             predictions.extend(
                 sorted(
-                    store_dict[str(spectrum_i)],
-                    key=lambda x: x[4],
+                    store_dict[spectrum_i],
+                    key=lambda x: x.peptide_score,
                     reverse=True,
                 )[: self.top_match]
             )
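The bookkeeping in this hunk groups candidate PSMs by spectrum, sorts each group by peptide_score in descending order, and keeps the top_match best candidates per spectrum. The same logic in isolation (the PSM dataclass below is a hypothetical stand-in for ms_io.PepSpecMatch):

    from collections import defaultdict
    from dataclasses import dataclass

    @dataclass
    class PSM:  # hypothetical stand-in for ms_io.PepSpecMatch
        spectrum_id: str
        sequence: str
        peptide_score: float

    def top_matches(psms: list[PSM], top_match: int = 2) -> list[PSM]:
        # Group candidates by the spectrum they were scored against ...
        by_spectrum = defaultdict(list)
        for psm in psms:
            by_spectrum[psm.spectrum_id].append(psm)
        # ... then keep only the highest-scoring candidates per spectrum.
        best = []
        for matches in by_spectrum.values():
            matches.sort(key=lambda p: p.peptide_score, reverse=True)
            best.extend(matches[:top_match])
        return best

    candidates = [
        PSM("s1", "PEPTIDE", 0.9),
        PSM("s1", "PEPTIDF", 0.4),
        PSM("s1", "PEPSIDE", 0.7),
        PSM("s2", "LESSER", 0.2),
    ]
    print([p.sequence for p in top_matches(candidates, top_match=2)])
    # ['PEPTIDE', 'PEPSIDE', 'LESSER']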
@@ -1090,27 +1088,7 @@ def on_predict_batch_end(
"""
Write the database search results to the output file.
"""
for (
spectrum_i,
charge,
precursor_mz,
peptide,
peptide_score,
aa_scores,
protein,
) in outputs:
self.out_writer.psms.append(
ms_io.PepSpecMatch(
sequence=peptide,
spectrum_id=tuple(spectrum_i),
peptide_score=peptide_score,
charge=int(charge),
calc_mz=precursor_mz,
exp_mz=self.peptide_mass_calculator.mass(peptide, charge),
aa_scores=aa_scores,
protein=protein,
)
)
self.out_writer.psms.extend(outputs)


def _calc_match_score(
@@ -1124,8 +1102,7 @@ def _calc_match_score(
     Take in teacher-forced scoring of amino acids
     of the peptides (in a batch) and use the truth labels
     to calculate a score between the input spectra and
-    associated peptide. The score is the geometric
-    mean of the AA probabilities
+    associated peptide.

     Parameters
     ----------
@@ -1134,18 +1111,19 @@ def _calc_match_score(
         the vocabulary for every prediction made to generate
         the associated peptide (for an entire batch)
     truth_aa_indices : torch.Tensor
-        Indicies of the score for each actual amino acid
+        Indices of the score for each actual amino acid
         in the peptide (for an entire batch)
     decoder_reverse : bool
         Whether the decoder is reversed.

     Returns
     -------
-    (all_scores, per_aa_scores) : Tuple[torch.Tensor, torch.Tensor]
+    all_peptide_scores : List[float]
         The score between the input spectra and associated peptide
-        (for an entire batch)
-        a list of lists of per amino acid scores
-        (for an entire batch)
+        for each PSM in the batch.
+    all_aa_scores : List[List[float]]
+        A list of lists of per amino acid scores
+        for each PSM in the batch.
     """
     # Remove trailing tokens from predictions based on decoder reversal
     if not decoder_reverse:
@@ -1162,19 +1140,19 @@ def _calc_match_score(
     cols = torch.arange(0, batch_all_aa_scores.shape[1]).expand_as(rows)

     per_aa_scores = batch_all_aa_scores[rows, cols, truth_aa_indices]
-
+    per_aa_scores = per_aa_scores.cpu().detach().numpy()
+    per_aa_scores[per_aa_scores == 0] += 1e-10
     score_mask = truth_aa_indices != 0
     per_aa_scores[~score_mask] = 0
-    log_per_aa_scores = torch.log(per_aa_scores)
-    all_scores = torch.where(
-        log_per_aa_scores == float("-inf"),
-        torch.tensor(0.0),
-        log_per_aa_scores,
-    ).sum(dim=1) / score_mask.sum(
-        dim=1
-    )  # Calculates geometric score
-    return all_scores, per_aa_scores
+    all_peptide_scores = []
+    all_aa_scores = []
+    for psm_score in per_aa_scores:
+        psm_score = np.trim_zeros(psm_score)
+        aa_scores, peptide_score = _aa_pep_score(psm_score, True)
+        all_peptide_scores.append(peptide_score)
+        all_aa_scores.append(aa_scores)

+    return all_peptide_scores, all_aa_scores


 class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
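For reference, the code removed in the hunk above scored a peptide as the geometric mean of its amino acid probabilities, computed in log space for numerical stability; since exp(mean(log p)) = (p1 · p2 · ... · pn)^(1/n), this is monotonic and does not affect PSM ranking. A standalone sketch of that trick (the new code delegates to _aa_pep_score instead):

    import numpy as np

    def geometric_mean_score(aa_scores: np.ndarray) -> float:
        # Clip zeros so log() stays finite, mirroring the 1e-10 floor above.
        aa_scores = np.clip(aa_scores, 1e-10, None)
        # exp(mean(log p)) equals the n-th root of the product of the
        # probabilities, but summing logs avoids underflow when many
        # small probabilities would be multiplied together.
        return float(np.exp(np.mean(np.log(aa_scores))))

    print(geometric_mean_score(np.array([0.9, 0.8, 0.95])))  # ~0.88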
51 changes: 19 additions & 32 deletions casanovo/denovo/model_runner.py
@@ -127,24 +127,24 @@ def db_search(
         self,
         peak_path: Iterable[str],
         fasta_path: str,
-        output: str,
+        results_path: str,
     ) -> None:
         """Perform database search with Casanovo.

         Parameters
         ----------
         peak_path : Iterable[str]
-            The paths to the .mgf data files for database search.
+            The path with the MS data files for database search.
         fasta_path : str
-            The path to the FASTA file for database search.
-        output : str
-            Where should the output be saved?
+            The path with the FASTA file for database search.
+        results_path : str
+            Sequencing results file path

         Returns
         -------
         self
         """
-        self.writer = ms_io.MztabWriter(Path(output).with_suffix(".mztab"))
+        self.writer = ms_io.MztabWriter(results_path)
         self.writer.set_metadata(
             self.config,
             model=str(self.model_filename),
@@ -266,7 +266,7 @@ def predict(
         Parameters
         ----------
-        peak_path : iterable of str
+        peak_path : Iterable[str]
             The path with the MS data files for predicting peptide sequences.
         results_path : str
             Sequencing results file path
@@ -431,12 +431,12 @@ def initialize_model(
         )

         if self.model_filename is None:
-            # Train a model from scratch if no model file is provided.
             if db_search:
                 logger.error("DB search mode requires a model file")
                 raise ValueError(
                     "A model file must be provided for DB search mode"
                 )
+            # Train a model from scratch if no model file is provided.
             if train:
                 self.model = Spec2Pep(**model_params)
                 return
@@ -456,19 +456,13 @@ def initialize_model(
         # First try loading model details from the weights file, otherwise use
         # the provided configuration.
         device = torch.empty(1).device  # Use the default device.
+        Model = DbSpec2Pep if db_search else Spec2Pep
         try:
-            if db_search:
-                self.model = DbSpec2Pep.load_from_checkpoint(
-                    self.model_filename,
-                    map_location=device,
-                    **loaded_model_params,
-                )
-            else:
-                self.model = Spec2Pep.load_from_checkpoint(
-                    self.model_filename,
-                    map_location=device,
-                    **loaded_model_params,
-                )
+            self.model = Model.load_from_checkpoint(
+                self.model_filename,
+                map_location=device,
+                **loaded_model_params,
+            )

             architecture_params = set(model_params.keys()) - set(
                 loaded_model_params.keys()
@@ -484,18 +478,11 @@ def initialize_model(
         except RuntimeError:
             # This only doesn't work if the weights are from an older version
             try:
-                if db_search:
-                    self.model = DbSpec2Pep.load_from_checkpoint(
-                        self.model_filename,
-                        map_location=device,
-                        **model_params,
-                    )
-                else:
-                    self.model = Spec2Pep.load_from_checkpoint(
-                        self.model_filename,
-                        map_location=device,
-                        **model_params,
-                    )
+                self.model = Model.load_from_checkpoint(
+                    self.model_filename,
+                    map_location=device,
+                    **model_params,
+                )
             except RuntimeError:
                 raise RuntimeError(
                     "Weights file incompatible with the current version of "
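The refactor above works because Python classes are first-class objects: binding DbSpec2Pep or Spec2Pep to a single name lets one load_from_checkpoint call site serve both modes. A minimal sketch of the pattern with toy classes (not the actual Lightning API):

    class Spec2PepStub:
        @classmethod
        def load(cls, path: str) -> "Spec2PepStub":
            # cls is whichever class the call was dispatched through.
            print(f"loading {cls.__name__} from {path}")
            return cls()

    class DbSpec2PepStub(Spec2PepStub):
        pass

    def initialize(db_search: bool) -> Spec2PepStub:
        # Bind the class once, then use a single call site for both modes.
        Model = DbSpec2PepStub if db_search else Spec2PepStub
        return Model.load("weights.ckpt")

    initialize(db_search=True)   # loading DbSpec2PepStub from weights.ckpt
    initialize(db_search=False)  # loading Spec2PepStub from weights.ckpt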
