Skip to content

Commit

Permalink
Use config options and auto-downloaded weights (#246)
Browse files Browse the repository at this point in the history
* Use auto-downloaded weights

* Use config options in model init

* Fix linting

* Fix missing return value

* Override mismatching params with ckpt

* Handle corrupt ckpt

* Add test case for parameter mismatch

* Add Python version to screenshot action

* Generate new screengrabs with rich-codex

* Fix import order

* Minor reformatting

---------

Co-authored-by: William Fondrie <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Wout Bittremieux <[email protected]>
  • Loading branch information
4 people authored Dec 12, 2023
1 parent 3b688e8 commit 2aed9e5
Show file tree
Hide file tree
Showing 9 changed files with 403 additions and 352 deletions.
10 changes: 7 additions & 3 deletions .github/workflows/screenshots.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check out the repo
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}

- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: Install your custom tools
run: pip install .
run: |
python -m pip install --upgrade pip
pip install .
- name: Generate terminal images with rich-codex
uses: ewels/rich-codex@v1
Expand Down
8 changes: 4 additions & 4 deletions casanovo/casanovo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def sequence(
to sequence peptides.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, False)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing peptides from:")
for peak_file in peak_path:
Expand Down Expand Up @@ -164,7 +164,7 @@ def evaluate(
such as those provided by MassIVE-KB.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, False)
config, model = setup_model(model, config, output, False)
with ModelRunner(config, model) as runner:
logger.info("Sequencing and evaluating peptides from:")
for peak_file in annotated_peak_path:
Expand Down Expand Up @@ -207,7 +207,7 @@ def train(
provided by MassIVE-KB, from which to train a new Casanovo model.
"""
output = setup_logging(output, verbosity)
config = setup_model(model, config, output, True)
config, model = setup_model(model, config, output, True)
with ModelRunner(config, model) as runner:
logger.info("Training a model from:")
for peak_file in train_peak_path:
Expand Down Expand Up @@ -378,7 +378,7 @@ def setup_model(
for key, value in config.items():
logger.debug("%s = %s", str(key), str(value))

return config
return config, model


def _get_model_weights() -> str:
Expand Down
56 changes: 47 additions & 9 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import tempfile
import uuid
import warnings
from pathlib import Path
from typing import Iterable, List, Optional, Union

Expand Down Expand Up @@ -217,6 +218,7 @@ def initialize_model(self, train: bool) -> None:
max_charge=self.config.max_charge,
precursor_mass_tol=self.config.precursor_mass_tol,
isotope_error_range=self.config.isotope_error_range,
min_peptide_len=self.config.min_peptide_len,
n_beams=self.config.n_beams,
top_match=self.config.top_match,
n_log=self.config.n_log,
Expand All @@ -230,6 +232,24 @@ def initialize_model(self, train: bool) -> None:
calculate_precision=self.config.calculate_precision,
)

# Reconfigurable non-architecture related parameters for a loaded model
loaded_model_params = dict(
max_length=self.config.max_length,
precursor_mass_tol=self.config.precursor_mass_tol,
isotope_error_range=self.config.isotope_error_range,
n_beams=self.config.n_beams,
min_peptide_len=self.config.min_peptide_len,
top_match=self.config.top_match,
n_log=self.config.n_log,
tb_summarywriter=self.config.tb_summarywriter,
warmup_iters=self.config.warmup_iters,
max_iters=self.config.max_iters,
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay,
out_writer=self.writer,
calculate_precision=self.config.calculate_precision,
)

from_scratch = (
self.config.train_from_scratch,
self.model_filename is None,
Expand All @@ -248,20 +268,38 @@ def initialize_model(self, train: bool) -> None:
)
raise FileNotFoundError("Could not find the model weights file")

# First try loading model details from the weithgs file,
# otherwise use the provided configuration.
# First try loading model details from the weights file, otherwise use
# the provided configuration.
device = torch.empty(1).device # Use the default device.
try:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
self.model_filename, map_location=device, **loaded_model_params
)
except RuntimeError:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
**model_params,

architecture_params = set(model_params.keys()) - set(
loaded_model_params.keys()
)
for param in architecture_params:
if model_params[param] != self.model.hparams[param]:
warnings.warn(
f"Mismatching {param} parameter in "
f"model checkpoint ({self.model.hparams[param]}) "
f"vs config file ({model_params[param]}); "
"using the checkpoint."
)
except RuntimeError:
# This only doesn't work if the weights are from an older version
try:
self.model = Spec2Pep.load_from_checkpoint(
self.model_filename,
map_location=device,
**model_params,
)
except RuntimeError:
raise RuntimeError(
"Weights file incompatible with the current version of "
"Casanovo. "
)

def initialize_data_module(
self,
Expand Down
Loading

0 comments on commit 2aed9e5

Please sign in to comment.