Skip to content

Commit

Permalink
integration test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Dec 2, 2024
1 parent 3028cd2 commit ec20013
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
5 changes: 4 additions & 1 deletion casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,9 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None:
self.model = Model.load_from_checkpoint(
self.model_filename, map_location=device, **loaded_model_params
)

# Use tokenizer initialized from config file instead of loaded
# from checkpoint file
self.model.tokenizer = tokenizer
architecture_params = set(model_params.keys()) - set(
loaded_model_params.keys()
)
Expand All @@ -515,6 +517,7 @@ def initialize_model(self, train: bool, db_search: bool = False) -> None:
map_location=device,
**model_params,
)
self.model.tokenizer = tokenizer
except RuntimeError:
raise RuntimeError(
"Weights file incompatible with the current version of "
Expand Down
26 changes: 22 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,8 @@ def _create_mzml(peptides, mzml_file, random_state=42):
return mzml_file


@pytest.fixture
def tiny_config(tmp_path):
"""A config file for a tiny model."""
def get_config_file(file_path, file_name, additional_cfg=None):
"""Get Casanovo config yaml file"""
cfg = {
"n_head": 2,
"dim_feedforward": 10,
Expand Down Expand Up @@ -345,8 +344,27 @@ def tiny_config(tmp_path):
),
}

cfg_file = tmp_path / "config.yml"
if additional_cfg is not None:
cfg.update(additional_cfg)

cfg_file = file_path / file_name
with cfg_file.open("w+") as out_file:
yaml.dump(cfg, out_file)

return cfg_file


@pytest.fixture
def tiny_config(tmp_path):
"""A config file for a tiny model."""
return get_config_file(tmp_path, "config.yml")


@pytest.fixture
def tiny_config_db(tmp_path):
"""A config file for a db search."""
return get_config_file(
tmp_path,
"config_db.yml",
additional_cfg={"replace_isoleucine_with_leucine": False},
)
4 changes: 3 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from click.testing import CliRunner

from casanovo import casanovo
from casanovo.config import Config

TEST_DIR = Path(__file__).resolve().parent

Expand All @@ -14,6 +15,7 @@ def test_train_and_run(
mgf_small,
mzml_small,
tiny_config,
tiny_config_db,
tmp_path,
monkeypatch,
mgf_medium,
Expand Down Expand Up @@ -158,7 +160,7 @@ def test_train_and_run(
"--model",
str(model_file),
"--config",
tiny_config,
tiny_config_db,
"--output_dir",
str(tmp_path),
"--output_root",
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def test_save_final_model(tmp_path, mgf_small, tiny_config):
# Test checkpoint saving when val_check_interval is greater than training steps
config = Config(tiny_config)
config.val_check_interval = 50
model_file = tmp_path / "epoch=14-step=15.ckpt"
model_file = tmp_path / "epoch=19-step=20.ckpt"
with ModelRunner(config, output_dir=tmp_path) as runner:
runner.train([mgf_small], [mgf_small])

Expand All @@ -224,7 +224,7 @@ def test_save_final_model(tmp_path, mgf_small, tiny_config):
# Test checkpoint saving when val_check_interval is not a factor of training steps
config.val_check_interval = 15
validation_file = tmp_path / "foobar.best.ckpt"
model_file = tmp_path / "foobar.epoch=14-step=15.ckpt"
model_file = tmp_path / "foobar.epoch=19-step=20.ckpt"
with ModelRunner(
config, output_dir=tmp_path, output_rootname="foobar"
) as runner:
Expand Down

0 comments on commit ec20013

Please sign in to comment.