Skip to content

Commit

Permalink
create branch to profile base db-search
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Ananth committed Sep 22, 2023
1 parent 3118d9b commit f0c3014
Showing 1 changed file with 77 additions and 1 deletion.
78 changes: 77 additions & 1 deletion casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,11 +948,14 @@ class DBSpec2Pep(Spec2Pep):
Hijacks teacher-forcing implemented in Spec2Pep and uses it to predict scores between a spectra and associated peptide
"""

num_pairs = 64

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def predict_step(self, batch, *args):
batch_res = []
self.batch_gen(batch)
for new_batch, index in new_batch_generator(batch):
with torch.set_grad_enabled(True): # Fixes NaN!?
pred, truth = self._forward_step(*new_batch)
Expand All @@ -965,6 +968,46 @@ def predict_step(self, batch, *args):
batch_res.append((index, peptides, score_result, per_aa_score))
return batch_res

def batch_gen(self, batch):
all_psm = []
enc = self.encoder(batch[0])
precursors = batch[1]
indexes = batch[3]
enc = list(zip(*enc))
for idx, ms_spectra in enumerate(batch[0]):
spec_peptides = batch[2][idx].split(",")
spec_precursors = [precursors[idx]] * len(spec_peptides)
spec_enc = [enc[idx]] * len(spec_peptides)
spec_idx = [indexes[idx]] * len(spec_peptides)
all_psm.extend(
list(zip(spec_enc, spec_precursors, spec_peptides, spec_idx))
)
# continually grab n items from al_psm until list is exhausted
while len(all_psm) > 0:
batch = all_psm[: self.num_pairs]
all_psm = all_psm[self.num_pairs :]
batch = list(zip(*batch))
print("------------------")
print(type(batch))
print("------------------")
for b in batch:
print(type(b))
print("------------------")
for b in batch:
for t in b:
print(type(t))
print("------------------")
"""
Organization of batch:
batch is a list
each element in batch is a tuple of 32 items
for batch[0] it is a tuple of 2 items, each one a tensor (encoded ms)
for batch[1] it is a bunch of tensors (precursor data)
for batch[2] it is a bunch of srings (peptides)
for batch[3] it is a bunch of [filename, index] lists
"""
quit()

def on_predict_epoch_end(
self, results: List[List[Tuple[str, List[float], List[List[float]]]]]
) -> None:
Expand All @@ -990,6 +1033,38 @@ def on_predict_epoch_end(
)
out_f.close()

def _forward_step(
self,
spectra: torch.Tensor,
precursors: torch.Tensor,
sequences: List[str],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
The forward learning step.
Parameters
----------
spectra : torch.Tensor of shape (n_spectra, n_peaks, 2)
The spectra for which to predict peptide sequences.
Axis 0 represents an MS/MS spectrum, axis 1 contains the peaks in
the MS/MS spectrum, and axis 2 is essentially a 2-tuple specifying
the m/z-intensity pair for each peak. These should be zero-padded,
such that all the spectra in the batch are the same length.
precursors : torch.Tensor of size (n_spectra, 3)
The measured precursor mass (axis 0), precursor charge (axis 1), and
precursor m/z (axis 2) of each MS/MS spectrum.
sequences : List[str] of length n_spectra
The partial peptide sequences to predict.
Returns
-------
scores : torch.Tensor of shape (n_spectra, length, n_amino_acids)
The individual amino acid scores for each prediction.
tokens : torch.Tensor of shape (n_spectra, length)
The predicted tokens for each spectrum.
"""
return self.decoder(sequences, precursors, *self.encoder(spectra))


def new_batch_generator(batch):
"""
Expand All @@ -1006,6 +1081,7 @@ def new_batch_generator(batch):
new_batch : (torch.Tensor, torch.Tensor, array)
A new batch that shares one spectra but has different ptptides to score against
"""
batch_dict = {}
for idx, ms_spectra in enumerate(batch[0]):
# Batch by ms spectra and comma-separated peptides
# Split peptides into list
Expand All @@ -1028,7 +1104,7 @@ def calc_match_score(
) -> List[float]:
"""
Take in teacher-forced scoring of amino acids of the peptides (in a batch) and use the truth labels
to calculate a score between the input spectra and associated peptide. The score will be [DESCRIBE SCORE NORMALIZATION]
to calculate a score between the input spectra and associated peptide.
Parameters
----------
Expand Down

0 comments on commit f0c3014

Please sign in to comment.