From ec20013dc51496b972f3c0d0edbac0209cc89d30 Mon Sep 17 00:00:00 2001 From: Lilferrit Date: Mon, 2 Dec 2024 13:15:15 -0800 Subject: [PATCH] integration test fix --- casanovo/denovo/model_runner.py | 5 ++++- tests/conftest.py | 26 ++++++++++++++++++++++---- tests/test_integration.py | 4 +++- tests/unit_tests/test_runner.py | 4 ++-- 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/casanovo/denovo/model_runner.py b/casanovo/denovo/model_runner.py index 10e15cdf..c8fc7125 100644 --- a/casanovo/denovo/model_runner.py +++ b/casanovo/denovo/model_runner.py @@ -494,7 +494,9 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: self.model = Model.load_from_checkpoint( self.model_filename, map_location=device, **loaded_model_params ) - + # Use tokenizer initialized from config file instead of loaded + # from checkpoint file + self.model.tokenizer = tokenizer architecture_params = set(model_params.keys()) - set( loaded_model_params.keys() ) @@ -515,6 +517,7 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None: map_location=device, **model_params, ) + self.model.tokenizer = tokenizer except RuntimeError: raise RuntimeError( "Weights file incompatible with the current version of " diff --git a/tests/conftest.py b/tests/conftest.py index e23e9d39..4cc02aed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -253,9 +253,8 @@ def _create_mzml(peptides, mzml_file, random_state=42): return mzml_file -@pytest.fixture -def tiny_config(tmp_path): - """A config file for a tiny model.""" +def get_config_file(file_path, file_name, additional_cfg=None): + """Get Casanovo config yaml file""" cfg = { "n_head": 2, "dim_feedforward": 10, @@ -345,8 +344,27 @@ def tiny_config(tmp_path): ), } - cfg_file = tmp_path / "config.yml" + if additional_cfg is not None: + cfg.update(additional_cfg) + + cfg_file = file_path / file_name with cfg_file.open("w+") as out_file: yaml.dump(cfg, out_file) return cfg_file + + +@pytest.fixture +def tiny_config(tmp_path): + """A config file for a tiny model.""" + return get_config_file(tmp_path, "config.yml") + + +@pytest.fixture +def tiny_config_db(tmp_path): + """A config file for a db search.""" + return get_config_file( + tmp_path, + "config_db.yml", + additional_cfg={"replace_isoleucine_with_leucine": False}, + ) diff --git a/tests/test_integration.py b/tests/test_integration.py index 9eb7e092..b5adfa96 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -6,6 +6,7 @@ from click.testing import CliRunner from casanovo import casanovo +from casanovo.config import Config TEST_DIR = Path(__file__).resolve().parent @@ -14,6 +15,7 @@ def test_train_and_run( mgf_small, mzml_small, tiny_config, + tiny_config_db, tmp_path, monkeypatch, mgf_medium, @@ -158,7 +160,7 @@ def test_train_and_run( "--model", str(model_file), "--config", - tiny_config, + tiny_config_db, "--output_dir", str(tmp_path), "--output_root", diff --git a/tests/unit_tests/test_runner.py b/tests/unit_tests/test_runner.py index 958f1984..10a8d4ef 100644 --- a/tests/unit_tests/test_runner.py +++ b/tests/unit_tests/test_runner.py @@ -207,7 +207,7 @@ def test_save_final_model(tmp_path, mgf_small, tiny_config): # Test checkpoint saving when val_check_interval is greater than training steps config = Config(tiny_config) config.val_check_interval = 50 - model_file = tmp_path / "epoch=14-step=15.ckpt" + model_file = tmp_path / "epoch=19-step=20.ckpt" with ModelRunner(config, output_dir=tmp_path) as runner: runner.train([mgf_small], [mgf_small]) @@ -224,7 +224,7 @@ def test_save_final_model(tmp_path, mgf_small, tiny_config): # Test checkpoint saving when val_check_interval is not a factor of training steps config.val_check_interval = 15 validation_file = tmp_path / "foobar.best.ckpt" - model_file = tmp_path / "foobar.epoch=14-step=15.ckpt" + model_file = tmp_path / "foobar.epoch=19-step=20.ckpt" with ModelRunner( config, output_dir=tmp_path, output_rootname="foobar" ) as runner: