diff --git a/casanovo/denovo/dataloaders.py b/casanovo/denovo/dataloaders.py index 7d825b82..21fe44fb 100644 --- a/casanovo/denovo/dataloaders.py +++ b/casanovo/denovo/dataloaders.py @@ -158,9 +158,14 @@ def _make_db_loader( torch.utils.data.DataLoader A PyTorch DataLoader. """ + # Calculate new batch size to saturate previous batch size with PSMs + pep_per_spec = [] + for i in range(min(10, len(dataset))): + pep_per_spec.append(len(dataset[i][3].split(","))) + new_batch_size = int(self.batch_size // np.mean(pep_per_spec)) return torch.utils.data.DataLoader( dataset, - batch_size=self.batch_size, + batch_size=new_batch_size, collate_fn=prepare_db_batch, pin_memory=True, num_workers=self.n_workers, diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 8c2d9ded..1c033484 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -962,11 +962,9 @@ def predict_step(self, batch, *args): precursors, encoded_ms, ) in self.smart_batch_gen(batch): - with torch.set_grad_enabled(True): # Fixes NaN!? - pred, truth = self.decoder( - peptides, precursors, *encoded_ms - ) #! get detailed breakdown of this speed - sm = torch.nn.Softmax(dim=2) # dim=2 is very important! + with torch.set_grad_enabled(True): + pred, truth = self.decoder(peptides, precursors, *encoded_ms) + sm = torch.nn.Softmax(dim=2) pred = sm(pred) score_result, per_aa_score = calc_match_score( pred, truth @@ -1026,8 +1024,9 @@ def on_predict_epoch_end(self, results) -> None: for index, t_or_d, peptide, score, per_aa_scores in list( zip(*batch) ): + # Remove scores of 0 (padding) per_aa_scores = per_aa_scores.numpy() - per_aa_scores = per_aa_scores[per_aa_scores != 0] + per_aa_scores = list(per_aa_scores[per_aa_scores != 0]) score = score.numpy() csv_writer.writerow( ( @@ -1061,17 +1060,19 @@ def calc_match_score( The score between the input spectra and associated peptide (for an entire batch) a list of lists of per amino acid scores (for an entire batch) """ - batch_all_aa_scores = batch_all_aa_scores[ - :, :-1, : # -2 - ] # Remove trailing tokens from predictions, change to -1 to keep stop token - truth_aa_indicies = truth_aa_indicies[ - :, : # -1 - ] # Remove trailing tokens from label, remove -1 to keep stop token - per_aa_scores = batch_all_aa_scores[ - torch.arange(batch_all_aa_scores.shape[0])[:, None], - torch.arange(0, batch_all_aa_scores.shape[1]), - truth_aa_indicies, - ] + # Remove trailing tokens from predictions, + batch_all_aa_scores = batch_all_aa_scores[:, :-1] + + # Vectorized scoring using efficient indexing. + rows = ( + torch.arange(batch_all_aa_scores.shape[0]) + .unsqueeze(-1) + .expand(-1, batch_all_aa_scores.shape[1]) + ) + cols = torch.arange(0, batch_all_aa_scores.shape[1]).expand_as(rows) + + per_aa_scores = batch_all_aa_scores[rows, cols, truth_aa_indicies] + score_mask = truth_aa_indicies != 0 masked_per_aa_scores = per_aa_scores * score_mask all_scores = masked_per_aa_scores.sum(dim=1) / score_mask.sum(dim=1) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index cdb2f273..96408d9e 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -394,7 +394,7 @@ def db_search( abs_experiment_dirpath = "/net/noble/vol2/home/vananth3/2023_vananth_denovo-dbsearch/results/2023-08-21_speedup" profiler = AdvancedProfiler( dirpath=abs_experiment_dirpath, - filename="test_remove_cms", + filename="placeholder", ) trainer = pl.Trainer( accelerator="auto",