Skip to content

Commit

Permalink
speed up calc_match_score
Browse files (browse the repository at this point in the history)
  • Loading branch information
VarunAnanth2003 committed Oct 15, 2023
1 parent 3ecb3cb commit b873425
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 37 deletions.
66 changes: 33 additions & 33 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ class DBSpec2Pep(Spec2Pep):
Hijacks the teacher forcing implemented in Spec2Pep and uses it to predict a match score between a spectrum and its associated peptide
"""

num_pairs = 2
num_pairs = 1024

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -963,8 +963,9 @@ def predict_step(self, batch, *args):
encoded_ms,
) in self.smart_batch_gen(batch):
with torch.set_grad_enabled(True): # Fixes NaN!?
pred, truth = self.decoder(peptides, precursors, *encoded_ms)
print(truth)
pred, truth = self.decoder(
peptides, precursors, *encoded_ms
) #! get detailed breakdown of this speed
sm = torch.nn.Softmax(dim=2) # dim=2 is very important!
pred = sm(pred)
score_result, per_aa_score = calc_match_score(
Expand All @@ -975,7 +976,6 @@ def predict_step(self, batch, *args):
)
return batch_res

@torch.compile
def smart_batch_gen(self, batch):
all_psm = []
enc = self.encoder(batch[0])
Expand Down Expand Up @@ -1018,16 +1018,26 @@ def smart_batch_gen(self, batch):
def on_predict_epoch_end(self, results) -> None:
if self.out_writer is None:
return
results = np.array(results, dtype=object).squeeze((0, 1))
results = np.array(results, dtype=object).squeeze((0))
with open(self.out_writer.filename, "a") as out_f:
csv_writer = csv.writer(out_f)
for batch in results:
for index, t_or_d, peptide, score, per_aa_scores in list(
zip(*batch)
):
csv_writer.writerow(
(index, peptide, bool(t_or_d), score, per_aa_scores)
)
for group in results:
for batch in group:
for index, t_or_d, peptide, score, per_aa_scores in list(
zip(*batch)
):
per_aa_scores = per_aa_scores.numpy()
per_aa_scores = per_aa_scores[per_aa_scores != 0]
score = score.numpy()
csv_writer.writerow(
(
index,
peptide,
bool(t_or_d),
score,
per_aa_scores,
)
)
out_f.close()


Expand All @@ -1047,35 +1057,25 @@ def calc_match_score(
Returns
-------
score : list[float]
score : list[float], list[list[float]]
The score between each input spectrum and its associated peptide (for an entire batch)
A list of lists of per-amino-acid scores (for an entire batch)
"""
batch_all_aa_scores = batch_all_aa_scores[
:, :-1, : # -2
] # Remove trailing tokens from predictions, change to -1 to keep stop token
truth_aa_indicies = truth_aa_indicies[
:, : # -1
] # Remove trailing tokens from label, remove -1 to keep stop token
all_scores = []
per_aa_scores = []
for all_aa_pred, truth_indicies in zip(
batch_all_aa_scores, truth_aa_indicies
):
assert len(all_aa_pred) == len(
truth_indicies
) # Ensure that length of score list and indexes to pull from are the same length
for (
preds
) in all_aa_pred: # Ensure softmax distribution along correct axis
assert round(sum(preds).item()) == 1
aa_scores = []
for scores, true_index in zip(all_aa_pred, truth_indicies):
if true_index != 0:
aa_scores.append(scores[true_index].item())
normalized_score = sum(aa_scores) / len(aa_scores)
all_scores.append(normalized_score)
per_aa_scores.append(aa_scores)
return all_scores, per_aa_scores
per_aa_scores = batch_all_aa_scores[
torch.arange(batch_all_aa_scores.shape[0])[:, None],
torch.arange(0, batch_all_aa_scores.shape[1]),
truth_aa_indicies,
]
score_mask = truth_aa_indicies != 0
masked_per_aa_scores = per_aa_scores * score_mask
all_scores = masked_per_aa_scores.sum(dim=1) / score_mask.sum(dim=1)
return all_scores, masked_per_aa_scores


class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
Expand Down
7 changes: 3 additions & 4 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ..denovo.model import Spec2Pep
from ..denovo.model import DBSpec2Pep

from pytorch_lightning.profiler import SimpleProfiler
from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler


logger = logging.getLogger("casanovo")
Expand Down Expand Up @@ -392,10 +392,9 @@ def db_search(

# Create the Trainer object.
abs_experiment_dirpath = "/net/noble/vol2/home/vananth3/2023_vananth_denovo-dbsearch/results/2023-08-21_speedup"
profiler = SimpleProfiler(
profiler = AdvancedProfiler(
dirpath=abs_experiment_dirpath,
filename="casanovo_plasmodium_batch",
extended=True,
filename="test_remove_cms",
)
trainer = pl.Trainer(
accelerator="auto",
Expand Down

0 comments on commit b873425

Please sign in to comment.