diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 6e7d47a9..633b4ac0 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -1029,7 +1029,6 @@ def smart_batch_gen(self, batch): def on_predict_epoch_end(self, results) -> None: if self.out_writer is None: return - results = np.array(results, dtype=object).squeeze((0)) with open(self.out_writer.filename, "a") as out_f: csv_writer = csv.writer(out_f, delimiter="\t") # Write a header IF THE FILE IS BLANK @@ -1044,7 +1043,7 @@ def on_predict_epoch_end(self, results) -> None: ) ) # Write rows - for group in results: + for group in results[0]: for batch in group: for index, t_or_d, peptide, score, per_aa_scores in list( zip(*batch)