diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 77df6df5..73c93448 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -684,7 +684,7 @@ def _get_top_peptide( yield [ ( pep_score, - aa_scores, + aa_scores[::-1] if self.decoder.reverse else aa_scores, "".join(self.decoder.detokenize(pred_tokens)), ) for pep_score, _, aa_scores, pred_tokens in heapq.nlargest( diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index f615a099..460c279c 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -240,6 +240,25 @@ def test_beam_search_decode(): [pep[-1] for pep in list(model._get_top_peptide(test_cache))[0]] ) == {"PEPK", "PEPP"} + # Test reverse aa scores when decoder is reversed + pred_cache = { + 0: [(1.0, 0.42, np.array([1.0, 0.0]), torch.Tensor([4, 14]))] + } + + model.decoder.reverse = True + top_peptides = list(model._get_top_peptide(pred_cache)) + assert len(top_peptides) == 1 + assert len(top_peptides[0]) == 1 + assert np.allclose(top_peptides[0][0][1], np.array([0.0, 1.0])) + assert top_peptides[0][0][2] == "EP" + + model.decoder.reverse = False + top_peptides = list(model._get_top_peptide(pred_cache)) + assert len(top_peptides) == 1 + assert len(top_peptides[0]) == 1 + assert np.allclose(top_peptides[0][0][1], np.array([1.0, 0.0])) + assert top_peptides[0][0][2] == "PE" + # Test _get_topk_beams(). # Set scores to proceed generating the unfinished beam. step = 4