diff --git a/CHANGELOG.md b/CHANGELOG.md index 1824cc1f..9bd936a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,11 +8,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added -- During training, model checkpoints will now be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run. +- During training, model checkpoints will be saved at the end of each training epoch in addition to the checkpoints saved at the end of every validation run. +- Besides as a local file, model weights can be specified from a URL. Upon initial download, the weights file is cached for future re-use. ### Fixed -- Precursor charges are now exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification. +- Precursor charges are exported as integers instead of floats in the mzTab output file, in compliance with the mzTab specification. ## [4.2.1] - 2024-06-25 diff --git a/casanovo/casanovo.py b/casanovo/casanovo.py index d93b748c..cd2274a0 100644 --- a/casanovo/casanovo.py +++ b/casanovo/casanovo.py @@ -2,12 +2,14 @@ import datetime import functools +import hashlib import logging import os import re import shutil import sys import time +import urllib.parse import warnings from pathlib import Path from typing import Optional, Tuple @@ -60,10 +62,9 @@ def __init__(self, *args, **kwargs) -> None: click.Option( ("-m", "--model"), help=""" - The model weights (.ckpt file). If not provided, Casanovo - will try to download the latest release. + Either the model weights (.ckpt file) or a URL pointing to the model weights + file. If not provided, Casanovo will try to download the latest release automatically. """, - type=click.Path(exists=True, dir_okay=False), ), click.Option( ("-o", "--output"), @@ -365,22 +366,34 @@ def setup_model( seed_everything(seed=config["random_seed"], workers=True) # Download model weights if these were not specified (except when training). - if model is None and not is_train: - try: - model = _get_model_weights() - except github.RateLimitExceededException: - logger.error( - "GitHub API rate limit exceeded while trying to download the " - "model weights. Please download compatible model weights " - "manually from the official Casanovo code website " - "(https://github.com/Noble-Lab/casanovo) and specify these " - "explicitly using the `--model` parameter when running " - "Casanovo." + cache_dir = Path(appdirs.user_cache_dir("casanovo", False, opinion=False)) + if model is None: + if not is_train: + try: + model = _get_model_weights(cache_dir) + except github.RateLimitExceededException: + logger.error( + "GitHub API rate limit exceeded while trying to download the " + "model weights. Please download compatible model weights " + "manually from the official Casanovo code website " + "(https://github.com/Noble-Lab/casanovo) and specify these " + "explicitly using the `--model` parameter when running " + "Casanovo." + ) + raise PermissionError( + "GitHub API rate limit exceeded while trying to download the " + "model weights" + ) from None + else: + if _is_valid_url(model): + model = _get_weights_from_url(model, cache_dir) + elif not Path(model).is_file(): + error_msg = ( + f"{model} is not a valid URL or checkpoint file path, " + "--model argument must be a URL or checkpoint file path" ) - raise PermissionError( - "GitHub API rate limit exceeded while trying to download the " - "model weights" - ) from None + logger.error(error_msg) + raise ValueError(error_msg) # Log the active configuration. logger.info("Casanovo version %s", str(__version__)) @@ -393,7 +406,7 @@ def setup_model( return config, model -def _get_model_weights() -> str: +def _get_model_weights(cache_dir: Path) -> str: """ Use cached model weights or download them from GitHub. @@ -407,12 +420,16 @@ def _get_model_weights() -> str: Note that the GitHub API is limited to 60 requests from the same IP per hour. + Parameters + ---------- + cache_dir : Path + model weights cache directory path + Returns ------- str The name of the model weights file. """ - cache_dir = appdirs.user_cache_dir("casanovo", False, opinion=False) os.makedirs(cache_dir, exist_ok=True) version = utils.split_version(__version__) version_match: Tuple[Optional[str], Optional[str], int] = None, None, 0 @@ -436,7 +453,7 @@ def _get_model_weights() -> str: "Model weights file %s retrieved from local cache", version_match[0], ) - return version_match[0] + return Path(version_match[0]) # Otherwise try to find compatible model weights on GitHub. else: repo = github.Github().get_repo("Noble-Lab/casanovo") @@ -469,19 +486,9 @@ def _get_model_weights() -> str: # Download the model weights if a matching release was found. if version_match[2] > 0: filename, url, _ = version_match - logger.info( - "Downloading model weights file %s from %s", filename, url - ) - r = requests.get(url, stream=True, allow_redirects=True) - r.raise_for_status() - file_size = int(r.headers.get("Content-Length", 0)) - desc = "(Unknown total file size)" if file_size == 0 else "" - r.raw.read = functools.partial(r.raw.read, decode_content=True) - with tqdm.tqdm.wrapattr( - r.raw, "read", total=file_size, desc=desc - ) as r_raw, open(filename, "wb") as f: - shutil.copyfileobj(r_raw, f) - return filename + cache_file_path = cache_dir / filename + _download_weights(url, cache_file_path) + return cache_file_path else: logger.error( "No matching model weights for release v%s found, please " @@ -496,5 +503,130 @@ def _get_model_weights() -> str: ) +def _get_weights_from_url( + file_url: str, + cache_dir: Path, + force_download: Optional[bool] = False, +) -> Path: + """ + Resolve weight file from URL + + Attempt to download weight file from URL if weights are not already + cached - otherwise use cached weights. Downloaded weight files will be + cached. + + Parameters + ---------- + file_url : str + URL pointing to model weights file. + cache_dir : Path + Model weights cache directory path. + force_download : Optional[bool], default=False + If True, forces a new download of the weight file even if it exists in + the cache. + + Returns + ------- + Path + Path to the cached weights file. + """ + if not _is_valid_url(file_url): + raise ValueError("file_url must point to a valid URL") + + os.makedirs(cache_dir, exist_ok=True) + cache_file_name = Path(urllib.parse.urlparse(file_url).path).name + url_hash = hashlib.shake_256(file_url.encode("utf-8")).hexdigest(5) + cache_file_dir = cache_dir / url_hash + cache_file_path = cache_file_dir / cache_file_name + + if cache_file_path.is_file() and not force_download: + cache_time = cache_file_path.stat() + url_last_modified = 0 + + try: + file_response = requests.head(file_url) + if file_response.ok: + if "Last-Modified" in file_response.headers: + url_last_modified = datetime.datetime.strptime( + file_response.headers["Last-Modified"], + "%a, %d %b %Y %H:%M:%S %Z", + ).timestamp() + else: + logger.warning( + "Attempted HEAD request to %s yielded non-ok status code - using cached file", + file_url, + ) + except ( + requests.ConnectionError, + requests.Timeout, + requests.TooManyRedirects, + ): + logger.warning( + "Failed to reach %s to get remote last modified time - using cached file", + file_url, + ) + + if cache_time.st_mtime > url_last_modified: + logger.info( + "Model weights %s retrieved from local cache", file_url + ) + return cache_file_path + + _download_weights(file_url, cache_file_path) + return cache_file_path + + +def _download_weights(file_url: str, download_path: Path) -> None: + """ + Download weights file from URL + + Download the model weights file from the specified URL and save it to the + given path. Ensures the download directory exists, and uses a progress + bar to indicate download status. + + Parameters + ---------- + file_url : str + URL pointing to the model weights file. + download_path : Path + Path where the downloaded weights file will be saved. + """ + download_file_dir = download_path.parent + os.makedirs(download_file_dir, exist_ok=True) + response = requests.get(file_url, stream=True, allow_redirects=True) + response.raise_for_status() + file_size = int(response.headers.get("Content-Length", 0)) + desc = "(Unknown total file size)" if file_size == 0 else "" + response.raw.read = functools.partial( + response.raw.read, decode_content=True + ) + + with tqdm.tqdm.wrapattr( + response.raw, "read", total=file_size, desc=desc + ) as r_raw, open(download_path, "wb") as file: + shutil.copyfileobj(r_raw, file) + + +def _is_valid_url(file_url: str) -> bool: + """ + Determine whether file URL is a valid URL + + Parameters + ---------- + file_url : str + url to verify + + Return + ------ + is_url : bool + whether file_url is a valid url + """ + try: + result = urllib.parse.urlparse(file_url) + return all([result.scheme, result.netloc]) + except ValueError: + return False + + if __name__ == "__main__": main() diff --git a/docs/file_formats.md b/docs/file_formats.md index 7cde8c7a..b01e4c02 100644 --- a/docs/file_formats.md +++ b/docs/file_formats.md @@ -2,6 +2,8 @@ ## Input file formats for Casanovo +### MS/MS spectra + When you're ready to use Casanovo for *de novo* peptide sequencing, you can input your MS/MS spectra in one of the following formats: - **[mzML](https://doi.org/10.1074/mcp.R110.000133)**: XML-based mass spectrometry community standard file format developed by the Proteomics Standards Initiative (PSI). @@ -11,6 +13,19 @@ When you're ready to use Casanovo for *de novo* peptide sequencing, you can inpu All three of the above file formats can be used as input to Casanovo for *de novo* peptide sequencing. As the official PSI standard format containing the complete information from a mass spectrometry run, mzML should typically be preferred. +### Model weights + +In addition to MS/MS spectra, Casanovo also optionally accepts a model weights (.ckpt extension) input file when running in training, sequencing, or evaluating mode. +These weights define the functionality of the Casanovo neural network. + +If no input weights file is provided, Casanovo will automatically use the most recent compatible weights from the [official Casanovo GitHub repository](https://github.com/Noble-Lab/casanovo), which will be downloaded and cached locally if they are not already. +Model weights are retrieved by matching Casanovo release version, which is of the form (major, minor, patch). +If no model weights for an identical release are available, alternative releases with matching (i) major and minor, or (ii) major versions will be used. + +Alternatively, you can input custom model weights in the form of a local file system path or a URL pointing to a compatible Casanovo model weights file. +If a URL is provided, the upstream weights file will be downloaded and cached locally for later use. +See the [command line interface documentation](cli.rst) for more details. + ## Output: Understanding the mzTab format After Casanovo processes your input file(s), it provides the sequencing results in an **[mzTab]((https://doi.org/10.1074/mcp.O113.036681))** file. diff --git a/docs/images/evaluate-help.svg b/docs/images/evaluate-help.svg index 2f770e2e..bd8b258f 100644 --- a/docs/images/evaluate-help.svg +++ b/docs/images/evaluate-help.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + - + - + - - $ casanovo evaluate --help - -Usage:casanovo evaluate [OPTIONSANNOTATED_PEAK_PATH...                       - - Evaluate de novo peptide sequencing performance.                                - ANNOTATED_PEAK_PATH must be one or more annoated MGF files, such as those       - provided by MassIVE-KB.                                                         - -╭─ Arguments ──────────────────────────────────────────────────────────────────╮ -*  ANNOTATED_PEAK_PATH    FILE[required] -╰──────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ────────────────────────────────────────────────────────────────────╮ ---model-mFILE                        The model weights (.ckpt file).  -                                              If not provided, Casanovo will   -                                              try to download the latest       -                                              release.                         ---output-oFILE                        The mzTab file to which results  -                                              will be written.                 ---config-cFILE                        The YAML configuration file      -                                              overriding the default options.  ---verbosity-v[debug|info|warning|error]  Set the verbosity of console     -                                              logging messages. Log files are  -                                              always set to 'debug'.           ---help-h  Show this message and exit.      -╰──────────────────────────────────────────────────────────────────────────────╯ - + + $ casanovo evaluate --help + +Usage:casanovo evaluate [OPTIONSANNOTATED_PEAK_PATH...                       + + Evaluate de novo peptide sequencing performance.                                + ANNOTATED_PEAK_PATH must be one or more annoated MGF files, such as those       + provided by MassIVE-KB.                                                         + +╭─ Arguments ──────────────────────────────────────────────────────────────────╮ +*  ANNOTATED_PEAK_PATH    FILE[required] +╰──────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────╮ +--model-mTEXT                        Either the model weights (.ckpt  +                                              file) or a URL pointing to the   +                                              model weights file. If not       +                                              provided, Casanovo will try to   +                                              download the latest release      +                                              automatically.                   +--output-oFILE                        The mzTab file to which results  +                                              will be written.                 +--config-cFILE                        The YAML configuration file      +                                              overriding the default options.  +--verbosity-v[debug|info|warning|error]  Set the verbosity of console     +                                              logging messages. Log files are  +                                              always set to 'debug'.           +--help-h  Show this message and exit.      +╰──────────────────────────────────────────────────────────────────────────────╯ + diff --git a/docs/images/help.svg b/docs/images/help.svg index 80d63c7e..baf2e237 100644 --- a/docs/images/help.svg +++ b/docs/images/help.svg @@ -1,4 +1,4 @@ - + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - + - + - - $ casanovo --help - -Usage:casanovo [OPTIONSCOMMAND [ARGS]...                                     - - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓  - ┃                                  Casanovo                                  ┃  - ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛  - Casanovo de novo sequences peptides from tandem mass spectra using a            - Transformer model. Casanovo currently supports mzML, mzXML, and MGF files for   - de novo sequencing and annotated MGF files, such as those from MassIVE-KB, for  - training new models.                                                            - - Links:                                                                          - - • Documentation: https://casanovo.readthedocs.io                               - • Official code repository: https://github.com/Noble-Lab/casanovo              - - If you use Casanovo in your work, please cite:                                  - - • Yilmaz, M., Fondrie, W. E., Bittremieux, W., Oh, S. & Noble, W. S. De novo   -mass spectrometry peptide sequencing with a transformer model. Proceedings   -of the 39th International Conference on Machine Learning - ICML '22 (2022)   -doi:10.1101/2022.02.07.479481.                                               - -╭─ Options ────────────────────────────────────────────────────────────────────╮ ---help-h    Show this message and exit.                                     -╰──────────────────────────────────────────────────────────────────────────────╯ -╭─ Commands ───────────────────────────────────────────────────────────────────╮ -configure Generate a Casanovo configuration file to customize.               -evaluate  Evaluate de novo peptide sequencing performance.                   -sequence  De novo sequence peptides from tandem mass spectra.                -train     Train a Casanovo model on your own data.                           -version   Get the Casanovo version information                               -╰──────────────────────────────────────────────────────────────────────────────╯ - + + $ casanovo --help diff --git a/docs/images/sequence-help.svg b/docs/images/sequence-help.svg index 6635cfaa..5e75dfe4 100644 --- a/docs/images/sequence-help.svg +++ b/docs/images/sequence-help.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + - + - + - - $ casanovo sequence --help - -Usage:casanovo sequence [OPTIONSPEAK_PATH...                                 - - De novo sequence peptides from tandem mass spectra.                             - PEAK_PATH must be one or more mzMl, mzXML, or MGF files from which to sequence  - peptides.                                                                       - -╭─ Arguments ──────────────────────────────────────────────────────────────────╮ -*  PEAK_PATH    FILE[required] -╰──────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ────────────────────────────────────────────────────────────────────╮ ---model-mFILE                        The model weights (.ckpt file).  -                                              If not provided, Casanovo will   -                                              try to download the latest       -                                              release.                         ---output-oFILE                        The mzTab file to which results  -                                              will be written.                 ---config-cFILE                        The YAML configuration file      -                                              overriding the default options.  ---verbosity-v[debug|info|warning|error]  Set the verbosity of console     -                                              logging messages. Log files are  -                                              always set to 'debug'.           ---help-h  Show this message and exit.      -╰──────────────────────────────────────────────────────────────────────────────╯ - + + $ casanovo sequence --help + +Usage:casanovo sequence [OPTIONSPEAK_PATH...                                 + + De novo sequence peptides from tandem mass spectra.                             + PEAK_PATH must be one or more mzMl, mzXML, or MGF files from which to sequence  + peptides.                                                                       + +╭─ Arguments ──────────────────────────────────────────────────────────────────╮ +*  PEAK_PATH    FILE[required] +╰──────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────╮ +--model-mTEXT                        Either the model weights (.ckpt  +                                              file) or a URL pointing to the   +                                              model weights file. If not       +                                              provided, Casanovo will try to   +                                              download the latest release      +                                              automatically.                   +--output-oFILE                        The mzTab file to which results  +                                              will be written.                 +--config-cFILE                        The YAML configuration file      +                                              overriding the default options.  +--verbosity-v[debug|info|warning|error]  Set the verbosity of console     +                                              logging messages. Log files are  +                                              always set to 'debug'.           +--help-h  Show this message and exit.      +╰──────────────────────────────────────────────────────────────────────────────╯ + diff --git a/docs/images/train-help.svg b/docs/images/train-help.svg index 58251215..e27717e1 100644 --- a/docs/images/train-help.svg +++ b/docs/images/train-help.svg @@ -1,4 +1,4 @@ - + - - + + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + + + + + + + + - + - + - - $ casanovo train --help - -Usage:casanovo train [OPTIONSTRAIN_PEAK_PATH...                              - - Train a Casanovo model on your own data.                                        - TRAIN_PEAK_PATH must be one or more annoated MGF files, such as those provided  - by MassIVE-KB, from which to train a new Casnovo model.                         - -╭─ Arguments ──────────────────────────────────────────────────────────────────╮ -*  TRAIN_PEAK_PATH    FILE[required] -╰──────────────────────────────────────────────────────────────────────────────╯ -╭─ Options ────────────────────────────────────────────────────────────────────╮ -*--validation_peak_pa…-pFILE                    An annotated MGF file   -                                                       for validation, like    -                                                       from MassIVE-KB. Use    -                                                       this option multiple    -                                                       times to specify        -                                                       multiple files.         -[required]             ---model-mFILE                    The model weights       -                                                       (.ckpt file). If not    -                                                       provided, Casanovo      -                                                       will try to download    -                                                       the latest release.     ---output-oFILE                    The mzTab file to       -                                                       which results will be   -                                                       written.                ---config-cFILE                    The YAML configuration  -                                                       file overriding the     -                                                       default options.        ---verbosity-v[debug|info|warning|er  Set the verbosity of    -ror]  console logging         -                                                       messages. Log files     -                                                       are always set to       -                                                       'debug'.                ---help-h  Show this message and   -                                                       exit.                   -╰──────────────────────────────────────────────────────────────────────────────╯ - + + $ casanovo train --help + +Usage:casanovo train [OPTIONSTRAIN_PEAK_PATH...                              + + Train a Casanovo model on your own data.                                        + TRAIN_PEAK_PATH must be one or more annoated MGF files, such as those provided  + by MassIVE-KB, from which to train a new Casnovo model.                         + +╭─ Arguments ──────────────────────────────────────────────────────────────────╮ +*  TRAIN_PEAK_PATH    FILE[required] +╰──────────────────────────────────────────────────────────────────────────────╯ +╭─ Options ────────────────────────────────────────────────────────────────────╮ +*--validation_peak_pa…-pFILE                    An annotated MGF file   +                                                       for validation, like    +                                                       from MassIVE-KB. Use    +                                                       this option multiple    +                                                       times to specify        +                                                       multiple files.         +[required]             +--model-mTEXT                    Either the model        +                                                       weights (.ckpt file)    +                                                       or a URL pointing to    +                                                       the model weights       +                                                       file. If not provided,  +                                                       Casanovo will try to    +                                                       download the latest     +                                                       release automatically.  +--output-oFILE                    The mzTab file to       +                                                       which results will be   +                                                       written.                +--config-cFILE                    The YAML configuration  +                                                       file overriding the     +                                                       default options.        +--verbosity-v[debug|info|warning|er  Set the verbosity of    +ror]  console logging         +                                                       messages. Log files     +                                                       are always set to       +                                                       'debug'.                +--help-h  Show this message and   +                                                       exit.                   +╰──────────────────────────────────────────────────────────────────────────────╯ + diff --git a/tests/unit_tests/test_unit.py b/tests/unit_tests/test_unit.py index f615a099..120e341a 100644 --- a/tests/unit_tests/test_unit.py +++ b/tests/unit_tests/test_unit.py @@ -1,9 +1,17 @@ import collections +import datetime +import functools +import hashlib import heapq +import io import os +import pathlib import platform +import requests import shutil import tempfile +import unittest +import unittest.mock import einops import github @@ -69,7 +77,172 @@ def test_split_version(): assert version == ("3", "0", "1") -@pytest.mark.skip(reason="Hit rate limit during CI/CD") +class MockResponseGet: + file_content = b"fake model weights content" + + class MockRaw(io.BytesIO): + def read(self, *args, **kwargs): + return super().read(*args) + + def __init__(self): + self.request_counter = 0 + self.is_ok = True + + def raise_for_status(self): + if not self.is_ok: + raise requests.HTTPError + + def __call__(self, url, stream=True, allow_redirects=True): + self.request_counter += 1 + response = unittest.mock.MagicMock() + response.raise_for_status = self.raise_for_status + response.headers = {"Content-Length": str(len(self.file_content))} + response.raw = MockResponseGet.MockRaw(self.file_content) + return response + + +class MockAsset: + def __init__(self, file_name): + self.name = file_name + self.browser_download_url = f"http://example.com/{file_name}" + + +class MockRelease: + def __init__(self, tag_name, assets): + self.tag_name = tag_name + self.assets = [MockAsset(asset) for asset in assets] + + def get_assets(self): + return self.assets + + +class MockRepo: + def __init__( + self, + release_dict={ + "v3.0.0": [ + "casanovo_massivekb.ckpt", + "casanovo_non-enzy.checkpt", + "v3.0.0.zip", + "v3.0.0.tar.gz", + ], + "v3.1.0": ["v3.1.0.zip", "v3.1.0.tar.gz"], + "v3.2.0": ["v3.2.0.zip", "v3.2.0.tar.gz"], + "v3.3.0": ["v3.3.0.zip", "v3.3.0.tar.gz"], + "v4.0.0": [ + "casanovo_massivekb.ckpt", + "casanovo_nontryptic.ckpt", + "v4.0.0.zip", + "v4.0.0.tar.gz", + ], + }, + ): + self.releases = [ + MockRelease(tag_name, assets) + for tag_name, assets in release_dict.items() + ] + + def get_releases(self): + return self.releases + + +class MockGithub: + def __init__(self, releases): + self.releases = releases + + def get_repo(self, repo_name): + return MockRepo() + + +def test_setup_model(monkeypatch): + test_releases = ["3.0.0", "3.0.999", "3.999.999"] + mock_get = MockResponseGet() + mock_github = functools.partial(MockGithub, test_releases) + version = "3.0.0" + + # Test model is none when not training + with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: + mnk.setattr(casanovo, "__version__", version) + mnk.setattr("appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir) + mnk.setattr(github, "Github", mock_github) + mnk.setattr(requests, "get", mock_get) + filename = pathlib.Path(tmp_dir) / "casanovo_massivekb_v3_0_0.ckpt" + + assert not filename.is_file() + _, result_path = casanovo.setup_model(None, None, None, False) + assert result_path.resolve() == filename.resolve() + assert filename.is_file() + assert mock_get.request_counter == 1 + os.remove(result_path) + + assert not filename.is_file() + _, result = casanovo.setup_model(None, None, None, True) + assert result is None + assert not filename.is_file() + assert mock_get.request_counter == 1 + + with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: + mnk.setattr(casanovo, "__version__", version) + mnk.setattr("appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir) + mnk.setattr(github, "Github", mock_github) + mnk.setattr(requests, "get", mock_get) + + cache_file_name = "model_weights.ckpt" + file_url = f"http://www.example.com/{cache_file_name}" + url_hash = hashlib.shake_256(file_url.encode("utf-8")).hexdigest(5) + cache_dir = pathlib.Path(tmp_dir) + cache_file_dir = cache_dir / url_hash + cache_file_path = cache_file_dir / cache_file_name + + assert not cache_file_path.is_file() + _, result_path = casanovo.setup_model(file_url, None, None, False) + assert cache_file_path.is_file() + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 2 + os.remove(result_path) + + assert not cache_file_path.is_file() + _, result_path = casanovo.setup_model(file_url, None, None, False) + assert cache_file_path.is_file() + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 3 + + # Test model is file + with monkeypatch.context() as mnk, tempfile.NamedTemporaryFile( + suffix=".ckpt" + ) as temp_file, tempfile.TemporaryDirectory() as tmp_dir: + mnk.setattr(casanovo, "__version__", version) + mnk.setattr("appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir) + mnk.setattr(github, "Github", mock_github) + mnk.setattr(requests, "get", mock_get) + + temp_file_path = temp_file.name + _, result = casanovo.setup_model(temp_file_path, None, None, False) + assert mock_get.request_counter == 3 + assert result == temp_file_path + + _, result = casanovo.setup_model(temp_file_path, None, None, True) + assert mock_get.request_counter == 3 + assert result == temp_file_path + + # Test model is neither a URL or File + with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: + mnk.setattr(casanovo, "__version__", version) + mnk.setattr("appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir) + mnk.setattr(github, "Github", mock_github) + mnk.setattr(requests, "get", mock_get) + + with pytest.raises(ValueError): + casanovo.setup_model("FooBar", None, None, False) + + assert mock_get.request_counter == 3 + + with pytest.raises(ValueError): + casanovo.setup_model("FooBar", None, None, False) + + assert mock_get.request_counter == 3 + + def test_get_model_weights(monkeypatch): """ Test that model weights can be downloaded from GitHub or used from the @@ -77,26 +250,37 @@ def test_get_model_weights(monkeypatch): """ # Model weights for fully matching version, minor matching version, major # matching version. - for version in ["3.0.0", "3.0.999", "3.999.999"]: + test_releases = ["3.0.0", "3.0.999", "3.999.999"] + mock_get = MockResponseGet() + mock_github = functools.partial(MockGithub, test_releases) + + for version in test_releases: with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: mnk.setattr(casanovo, "__version__", version) mnk.setattr( "appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir ) - - filename = os.path.join(tmp_dir, "casanovo_massivekb_v3_0_0.ckpt") - assert not os.path.isfile(filename) - assert casanovo._get_model_weights() == filename - assert os.path.isfile(filename) - assert casanovo._get_model_weights() == filename + mnk.setattr(github, "Github", mock_github) + mnk.setattr(requests, "get", mock_get) + + tmp_path = pathlib.Path(tmp_dir) + filename = tmp_path / "casanovo_massivekb_v3_0_0.ckpt" + assert not filename.is_file() + result_path = casanovo._get_model_weights(tmp_path) + assert result_path == filename + assert filename.is_file() + result_path = casanovo._get_model_weights(tmp_path) + assert result_path == filename # Impossible to find model weights for (i) full version mismatch and (ii) # major version mismatch. for version in ["999.999.999", "999.0.0"]: - with monkeypatch.context() as mnk: + with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: mnk.setattr(casanovo, "__version__", version) + mnk.setattr(github, "Github", mock_github) + mnk.setattr(requests, "get", mock_get) with pytest.raises(ValueError): - casanovo._get_model_weights() + casanovo._get_model_weights(pathlib.Path(tmp_dir)) # Test GitHub API rate limit. def request(self, *args, **kwargs): @@ -107,8 +291,122 @@ def request(self, *args, **kwargs): with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: mnk.setattr("appdirs.user_cache_dir", lambda n, a, opinion: tmp_dir) mnk.setattr("github.Requester.Requester.requestJsonAndCheck", request) + mnk.setattr(requests, "get", mock_get) + mock_get.request_counter = 0 with pytest.raises(github.RateLimitExceededException): - casanovo._get_model_weights() + casanovo._get_model_weights(pathlib.Path(tmp_dir)) + + assert mock_get.request_counter == 0 + + +class MockResponseHead: + def __init__(self): + self.last_modified = None + self.is_ok = True + self.fail = False + + def __call__(self, url): + if self.fail: + raise requests.ConnectionError + + response = unittest.mock.MagicMock() + response.headers = dict() + response.ok = self.is_ok + if self.last_modified is not None: + response.headers["Last-Modified"] = self.last_modified + + return response + + +def test_get_weights_from_url(monkeypatch): + file_url = "http://example.com/model_weights.ckpt" + + with monkeypatch.context() as mnk, tempfile.TemporaryDirectory() as tmp_dir: + mock_get = MockResponseGet() + mock_head = MockResponseHead() + mnk.setattr(requests, "get", mock_get) + mnk.setattr(requests, "head", mock_head) + + cache_dir = pathlib.Path(tmp_dir) + url_hash = hashlib.shake_256(file_url.encode("utf-8")).hexdigest(5) + cache_file_name = "model_weights.ckpt" + cache_file_dir = cache_dir / url_hash + cache_file_path = cache_file_dir / cache_file_name + + # Test downloading and caching the file + assert not cache_file_path.is_file() + result_path = casanovo._get_weights_from_url(file_url, cache_dir) + assert cache_file_path.is_file() + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 1 + + # Test that cached file is used + result_path = casanovo._get_weights_from_url(file_url, cache_dir) + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 1 + + # Test force downloading the file + result_path = casanovo._get_weights_from_url( + file_url, cache_dir, force_download=True + ) + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 2 + + # Test that file is re-downloaded if last modified is newer than + # file last modified + # NOTE: Assuming test takes < 1 year to run + curr_utc = datetime.datetime.now().astimezone(datetime.timezone.utc) + mock_head.last_modified = ( + curr_utc + datetime.timedelta(days=365.0) + ).strftime("%a, %d %b %Y %H:%M:%S GMT") + result_path = casanovo._get_weights_from_url(file_url, cache_dir) + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 3 + + # Test file is not redownloaded if its newer than upstream file + mock_head.last_modified = ( + curr_utc - datetime.timedelta(days=365.0) + ).strftime("%a, %d %b %Y %H:%M:%S GMT") + result_path = casanovo._get_weights_from_url(file_url, cache_dir) + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 3 + + # Test that error is raised if file get response is not OK + mock_get.is_ok = False + with pytest.raises(requests.HTTPError): + casanovo._get_weights_from_url( + file_url, cache_dir, force_download=True + ) + mock_get.is_ok = True + assert mock_get.request_counter == 4 + + # Test that cached file is used if head requests yields non-ok status + # code, even if upstream file is newer + mock_head.is_ok = False + mock_head.last_modified = ( + curr_utc + datetime.timedelta(days=365.0) + ).strftime("%a, %d %b %Y %H:%M:%S GMT") + result_path = casanovo._get_weights_from_url(file_url, cache_dir) + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 4 + mock_head.is_ok = True + + # Test that cached file is used if head request fails + mock_head.fail = True + result_path = casanovo._get_weights_from_url(file_url, cache_dir) + assert result_path.resolve() == cache_file_path.resolve() + assert mock_get.request_counter == 4 + mock_head.fail = False + + # Test invalid URL + with pytest.raises(ValueError): + bad_url = "foobar" + casanovo._get_weights_from_url(bad_url, cache_dir) + + +def test_is_valid_url(): + assert casanovo._is_valid_url("https://www.washington.edu/") + assert not casanovo._is_valid_url("foobar") def test_tensorboard():