Skip to content

Commit

Permalink
fixed bad prediction bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Ananth committed Sep 24, 2023
1 parent bcceff5 commit 3ecb3cb
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions casanovo/denovo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,17 +948,23 @@ class DBSpec2Pep(Spec2Pep):
Hijacks the teacher forcing implemented in Spec2Pep and uses it to predict a score between a spectrum and its associated peptide.
"""

num_pairs = 1024
num_pairs = 2

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def predict_step(self, batch, *args):
batch_res = []
for indexes, t_or_d, new_batch in self.smart_batch_gen(batch):
for (
indexes,
t_or_d,
peptides,
precursors,
encoded_ms,
) in self.smart_batch_gen(batch):
with torch.set_grad_enabled(True): # Fixes NaN!?
encoded_ms, precursors, peptides = new_batch
pred, truth = self.decoder(peptides, precursors, *encoded_ms)
print(truth)
sm = torch.nn.Softmax(dim=2) # dim=2 is very important!
pred = sm(pred)
score_result, per_aa_score = calc_match_score(
Expand Down Expand Up @@ -994,7 +1000,7 @@ def smart_batch_gen(self, batch):
)
)
)
# continually grab n items from al_psm until list is exhausted
# continually grab n items from all_psm until list is exhausted
while len(all_psm) > 0:
batch = all_psm[: self.num_pairs]
all_psm = all_psm[self.num_pairs :]
Expand All @@ -1007,20 +1013,21 @@ def smart_batch_gen(self, batch):
pep_str = list(batch[2])
indexes = [a[1] for a in batch[3]]
t_or_d = batch[4]
yield (indexes, t_or_d, (encoded_ms, prec_data, pep_str))
yield (indexes, t_or_d, pep_str, prec_data, encoded_ms)

def on_predict_epoch_end(self, results) -> None:
if self.out_writer is None:
return
results = np.array(results, dtype=object).squeeze()
results = np.array(results, dtype=object).squeeze((0, 1))
with open(self.out_writer.filename, "a") as out_f:
csv_writer = csv.writer(out_f)
for index, t_or_d, peptide, score, per_aa_scores in list(
zip(*results)
):
csv_writer.writerow(
(index, peptide, bool(t_or_d), score, per_aa_scores)
)
for batch in results:
for index, t_or_d, peptide, score, per_aa_scores in list(
zip(*batch)
):
csv_writer.writerow(
(index, peptide, bool(t_or_d), score, per_aa_scores)
)
out_f.close()


Expand Down Expand Up @@ -1063,10 +1070,9 @@ def calc_match_score(
assert round(sum(preds).item()) == 1
aa_scores = []
for scores, true_index in zip(all_aa_pred, truth_indicies):
aa_scores.append(scores[true_index].item())
normalized_score = sum(aa_scores) / len(
aa_scores
) #! Convert to product and normalization by length, or alternatively use LogSoftmax
if true_index != 0:
aa_scores.append(scores[true_index].item())
normalized_score = sum(aa_scores) / len(aa_scores)
all_scores.append(normalized_score)
per_aa_scores.append(aa_scores)
return all_scores, per_aa_scores
Expand Down

0 comments on commit 3ecb3cb

Please sign in to comment.