From 3ecb3cb38a2ad5efa6fff5a61317b2a2d535a813 Mon Sep 17 00:00:00 2001 From: Varun Ananth Date: Sun, 24 Sep 2023 10:16:16 -0700 Subject: [PATCH] fixed bad prediction bug --- casanovo/denovo/model.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 7753dca8..00db9203 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -948,17 +948,23 @@ class DBSpec2Pep(Spec2Pep): Hijacks teacher-forcing implemented in Spec2Pep and uses it to predict scores between a spectra and associated peptide """ - num_pairs = 1024 + num_pairs = 2 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def predict_step(self, batch, *args): batch_res = [] - for indexes, t_or_d, new_batch in self.smart_batch_gen(batch): + for ( + indexes, + t_or_d, + peptides, + precursors, + encoded_ms, + ) in self.smart_batch_gen(batch): with torch.set_grad_enabled(True): # Fixes NaN!? - encoded_ms, precursors, peptides = new_batch pred, truth = self.decoder(peptides, precursors, *encoded_ms) + print(truth) sm = torch.nn.Softmax(dim=2) # dim=2 is very important! pred = sm(pred) score_result, per_aa_score = calc_match_score( @@ -994,7 +1000,7 @@ def smart_batch_gen(self, batch): ) ) ) - # continually grab n items from al_psm until list is exhausted + # continually grab n items from all_psm until list is exhausted while len(all_psm) > 0: batch = all_psm[: self.num_pairs] all_psm = all_psm[self.num_pairs :] @@ -1007,20 +1013,21 @@ def smart_batch_gen(self, batch): pep_str = list(batch[2]) indexes = [a[1] for a in batch[3]] t_or_d = batch[4] - yield (indexes, t_or_d, (encoded_ms, prec_data, pep_str)) + yield (indexes, t_or_d, pep_str, prec_data, encoded_ms) def on_predict_epoch_end(self, results) -> None: if self.out_writer is None: return - results = np.array(results, dtype=object).squeeze() + results = np.array(results, dtype=object).squeeze((0, 1)) with open(self.out_writer.filename, "a") as out_f: csv_writer = csv.writer(out_f) - for index, t_or_d, peptide, score, per_aa_scores in list( - zip(*results) - ): - csv_writer.writerow( - (index, peptide, bool(t_or_d), score, per_aa_scores) - ) + for batch in results: + for index, t_or_d, peptide, score, per_aa_scores in list( + zip(*batch) + ): + csv_writer.writerow( + (index, peptide, bool(t_or_d), score, per_aa_scores) + ) out_f.close() @@ -1063,10 +1070,9 @@ def calc_match_score( assert round(sum(preds).item()) == 1 aa_scores = [] for scores, true_index in zip(all_aa_pred, truth_indicies): - aa_scores.append(scores[true_index].item()) - normalized_score = sum(aa_scores) / len( - aa_scores - ) #! Convert to product and normalization by length, or alternatively use LogSoftmax + if true_index != 0: + aa_scores.append(scores[true_index].item()) + normalized_score = sum(aa_scores) / len(aa_scores) all_scores.append(normalized_score) per_aa_scores.append(aa_scores) return all_scores, per_aa_scores