Skip to content

Commit

Permalink
pull proper number of spectra
Browse files Browse the repository at this point in the history
to saturate batch size
  • Loading branch information
VarunAnanth2003 committed Oct 17, 2023
1 parent b873425 commit e895a5d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
7 changes: 6 additions & 1 deletion casanovo/denovo/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,14 @@ def _make_db_loader(
torch.utils.data.DataLoader
A PyTorch DataLoader.
"""
# Calculate new batch size to saturate previous batch size with PSMs
pep_per_spec = []
for i in range(min(10, len(dataset))):
pep_per_spec.append(len(dataset[i][3].split(",")))
new_batch_size = int(self.batch_size // np.mean(pep_per_spec))
return torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
batch_size=new_batch_size,
collate_fn=prepare_db_batch,
pin_memory=True,
num_workers=self.n_workers,
Expand Down
35 changes: 18 additions & 17 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,11 +962,9 @@ def predict_step(self, batch, *args):
precursors,
encoded_ms,
) in self.smart_batch_gen(batch):
with torch.set_grad_enabled(True): # Fixes NaN!?
pred, truth = self.decoder(
peptides, precursors, *encoded_ms
) #! get detailed breakdown of this speed
sm = torch.nn.Softmax(dim=2) # dim=2 is very important!
with torch.set_grad_enabled(True):
pred, truth = self.decoder(peptides, precursors, *encoded_ms)
sm = torch.nn.Softmax(dim=2)
pred = sm(pred)
score_result, per_aa_score = calc_match_score(
pred, truth
Expand Down Expand Up @@ -1026,8 +1024,9 @@ def on_predict_epoch_end(self, results) -> None:
for index, t_or_d, peptide, score, per_aa_scores in list(
zip(*batch)
):
# Remove scores of 0 (padding)
per_aa_scores = per_aa_scores.numpy()
per_aa_scores = per_aa_scores[per_aa_scores != 0]
per_aa_scores = list(per_aa_scores[per_aa_scores != 0])
score = score.numpy()
csv_writer.writerow(
(
Expand Down Expand Up @@ -1061,17 +1060,19 @@ def calc_match_score(
The score between the input spectra and associated peptide (for an entire batch)
a list of lists of per amino acid scores (for an entire batch)
"""
batch_all_aa_scores = batch_all_aa_scores[
:, :-1, : # -2
] # Remove trailing tokens from predictions, change to -1 to keep stop token
truth_aa_indicies = truth_aa_indicies[
:, : # -1
] # Remove trailing tokens from label, remove -1 to keep stop token
per_aa_scores = batch_all_aa_scores[
torch.arange(batch_all_aa_scores.shape[0])[:, None],
torch.arange(0, batch_all_aa_scores.shape[1]),
truth_aa_indicies,
]
# Remove trailing tokens from predictions,
batch_all_aa_scores = batch_all_aa_scores[:, :-1]

# Vectorized scoring using efficient indexing.
rows = (
torch.arange(batch_all_aa_scores.shape[0])
.unsqueeze(-1)
.expand(-1, batch_all_aa_scores.shape[1])
)
cols = torch.arange(0, batch_all_aa_scores.shape[1]).expand_as(rows)

per_aa_scores = batch_all_aa_scores[rows, cols, truth_aa_indicies]

score_mask = truth_aa_indicies != 0
masked_per_aa_scores = per_aa_scores * score_mask
all_scores = masked_per_aa_scores.sum(dim=1) / score_mask.sum(dim=1)
Expand Down
2 changes: 1 addition & 1 deletion casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def db_search(
abs_experiment_dirpath = "/net/noble/vol2/home/vananth3/2023_vananth_denovo-dbsearch/results/2023-08-21_speedup"
profiler = AdvancedProfiler(
dirpath=abs_experiment_dirpath,
filename="test_remove_cms",
filename="placeholder",
)
trainer = pl.Trainer(
accelerator="auto",
Expand Down

0 comments on commit e895a5d

Please sign in to comment.