diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index a4f4052a..29f1f6b9 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -172,11 +172,9 @@ def main( logger.info("Train the Casanovo model.") model_runner.train(peak_path, peak_path_val, model, config) elif mode == "db": - logger.info("Database seach with casanovo") + logger.info("Database seach with casanovo.") writer = ms_io.MztabWriter(f"{output}.mztab") - #!writer.set_metadata(config, model=model, config_filename=config_fn) model_runner.db_search(peak_path, model, config, writer) - #!writer.save() def _get_model_weights() -> str: diff --git a/casanovo/denovo/model.py b/casanovo/denovo/model.py index 155f0661..9e74e414 100644 --- a/casanovo/denovo/model.py +++ b/casanovo/denovo/model.py @@ -949,6 +949,7 @@ class DBSpec2Pep(Spec2Pep): """ num_pairs = 1024 + decoy_prefix = "decoy_" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -982,8 +983,18 @@ def smart_batch_gen(self, batch): enc = list(zip(*enc)) for idx, _ in enumerate(batch[0]): spec_peptides = batch[2][idx].split(",") - t_or_d = np.zeros(len(spec_peptides)) - t_or_d[range(0, len(t_or_d), 2)] = 1 + # Check for decoy prefixes and create a bit-vector indicating targets (1) or decoys (0) + t_or_d = [ + 0 if p.startswith(self.decoy_prefix) else 1 + for p in spec_peptides + ] + # Remove decoy prefix + spec_peptides = [ + s[len(self.decoy_prefix) :] + if s.startswith(self.decoy_prefix) + else s + for s in spec_peptides + ] spec_precursors = [precursors[idx]] * len(spec_peptides) spec_enc = [enc[idx]] * len(spec_peptides) spec_idx = [indexes[idx]] * len(spec_peptides)