From 65252927bc7e49f2e3393d98c1f8183635eff802 Mon Sep 17 00:00:00 2001 From: justin-a-sanders <60298590+justin-a-sanders@users.noreply.github.com> Date: Wed, 13 Mar 2024 12:58:16 -0700 Subject: [PATCH] fix bug on demo data --- casanovo/denovo/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 6e7d47a9..633b4ac0 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -1029,7 +1029,6 @@ def smart_batch_gen(self, batch): def on_predict_epoch_end(self, results) -> None: if self.out_writer is None: return - results = np.array(results, dtype=object).squeeze((0)) with open(self.out_writer.filename, "a") as out_f: csv_writer = csv.writer(out_f, delimiter="\t") # Write a header IF THE FILE IS BLANK @@ -1044,7 +1043,7 @@ def on_predict_epoch_end(self, results) -> None: ) ) # Write rows - for group in results: + for group in results[0]: for batch in group: for index, t_or_d, peptide, score, per_aa_scores in list( zip(*batch)