From 20799039dc75eb761777cb1f258d01a93ee6b99e Mon Sep 17 00:00:00 2001
From: Will Fondrie
Date: Mon, 9 Oct 2023 20:45:42 -0700
Subject: [PATCH] Replace data backend and increase flexibility. (#39)

---
 CHANGELOG.md                                 |  17 +
 CODE_OF_CONDUCT.md                           |   1 -
 data/README.md                               |   8 +
 data/TMT10-Trial-8.mzML                      |   4 +-
 depthcharge/__init__.py                      |  39 +-
 depthcharge/data/__init__.py                 |  12 +-
 depthcharge/data/arrow.py                    | 302 ++++++
 depthcharge/data/fields.py                   |  31 +
 depthcharge/data/parsers.py                  | 386 +++++--
 depthcharge/data/spectrum_datasets.py        | 960 +++++++-----------
 depthcharge/primitives.py                    |   2 +
 depthcharge/testing.py                       |  64 ++
 depthcharge/tokenizers/peptides.py           |  15 +
 depthcharge/transformers/spectra.py          | 130 ++-
 depthcharge/utils.py                         |  57 +-
 depthcharge/version.py                       |   2 +-
 pyproject.toml                               |   6 +-
 tests/conftest.py                            |   1 -
 tests/unit_tests/test_data/test_arrow.py     | 121 +++
 tests/unit_tests/test_data/test_datasets.py  | 357 +++----
 tests/unit_tests/test_data/test_loaders.py   |  66 +-
 tests/unit_tests/test_data/test_parsers.py   | 275 +++--
 tests/unit_tests/test_testing.py             |  27 +
 .../test_peptide_transformers.py             |   2 +-
 .../test_spectrum_transformers.py            |  65 +-
 25 files changed, 1836 insertions(+), 1114 deletions(-)
 create mode 100644 data/README.md
 create mode 100644 depthcharge/data/arrow.py
 create mode 100644 depthcharge/data/fields.py
 create mode 100644 depthcharge/testing.py
 create mode 100644 tests/unit_tests/test_data/test_arrow.py
 create mode 100644 tests/unit_tests/test_testing.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8ed4fa6..622aef0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+We have completely reworked the data module.
+Depthcharge now uses Apache Arrow-based formats instead of HDF5; spectra are either converted to Parquet or streamed with PyArrow, optionally into Lance datasets.
+
+### Breaking Changes
+- Mass spectrometry data parsers now function as iterators, yielding batches of spectra as `pyarrow.RecordBatch` objects.
+- Parsers can now be told to read arbitrary fields from their respective file formats with the `custom_fields` parameter.
+- The parsing functionality of `SpectrumDataset` and its subclasses has been moved to the `spectra_to_*` functions in the data module.
+- `SpectrumDataset` and its subclasses now return dictionaries of data rather than a tuple of data. This allows us to incorporate arbitrary additional data.
+
+### Added
+- Added the `StreamingSpectrumDataset` for fast inference.
+- Added `spectra_to_df`, `spectra_to_parquet`, and `spectra_to_stream` to the `depthcharge.data` module.
+
+### Changed
+- Determining the mass spectrometry data file format is now less fragile.
+  It now looks for known line contents, rather than relying on the extension.
+
 ## [v0.3.1] - 2023-08-18
 ### Added
 - Support for fine-tuning the wavelengths used for encoding floating point numbers like m/z and intensity to the `FloatEncoder` and `PeakEncoder`.
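The reworked API is easiest to see in miniature. The sketch below is an editor-added illustration, not part of the patch: the input file name and the `retention_time` metadata column are hypothetical, while the functions and signatures mirror those added in `depthcharge/data/arrow.py` and `depthcharge/data/spectrum_datasets.py` below.

```python
import polars as pl

from depthcharge.data import SpectrumDataset, spectra_to_df, spectra_to_parquet

# Read MS2 spectra straight into a Polars DataFrame:
df = spectra_to_df("example.mzML", ms_level=2)

# Or stream them to Parquet, merging extra metadata on scan_id:
metadata = pl.DataFrame({"scan_id": [501], "retention_time": [12.3]})
parquet_path = spectra_to_parquet("example.mzML", metadata_df=metadata)

# Index the Parquet file as a Lance dataset and serve it to PyTorch:
dataset = SpectrumDataset(parquet_path)
loader = dataset.loader(batch_size=128, shuffle=True)
```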
diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index 874564e..0f355d6 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -120,4 +120,3 @@ version 2.0, available at
 [homepage]: https://www.contributor-covenant.org
 [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html
-
diff --git a/data/README.md b/data/README.md
new file mode 100644
index 0000000..aaaeaee
--- /dev/null
+++ b/data/README.md
@@ -0,0 +1,8 @@
+# Data for testing and examples
+
+This directory contains real data to be used in tests and examples within the Depthcharge documentation.
+Currently, they all originate from [PXD000001](http://central.proteomexchange.org/cgi/GetDataset?ID=PXD000001).
+
+## Notes
+
+- [TMT10-Trial-8.mzML](TMT10-Trial-8.mzML) was modified manually such that one "charge state" CV accession was changed to an "assumed charge state" CV accession.
diff --git a/data/TMT10-Trial-8.mzML b/data/TMT10-Trial-8.mzML
index d595e07..82cca9a 100644
--- a/data/TMT10-Trial-8.mzML
+++ b/data/TMT10-Trial-8.mzML
@@ -83,7 +83,7 @@
-
+
@@ -157,7 +157,7 @@
-
+
diff --git a/depthcharge/__init__.py b/depthcharge/__init__.py
index 295c778..b7a1c97 100644
--- a/depthcharge/__init__.py
+++ b/depthcharge/__init__.py
@@ -1,17 +1,28 @@
 """Initialize the depthcharge package."""
-from . import (
-    data,
-    encoders,
-    feedforward,
-    tokenizers,
-    transformers,
-)
-from .primitives import (
-    MassSpectrum,
-    Molecule,
-    Peptide,
-    PeptideIons,
-)
-from .version import _get_version
+# Ignore a bunch of pkg_resources warnings from dependencies:
+import warnings
+
+with warnings.catch_warnings():
+    for module in ["psims", "pkg_resources"]:
+        warnings.filterwarnings(
+            "ignore",
+            category=DeprecationWarning,
+            module=module,
+        )
+
+    from . import (
+        data,
+        encoders,
+        feedforward,
+        tokenizers,
+        transformers,
+    )
+    from .primitives import (
+        MassSpectrum,
+        Molecule,
+        Peptide,
+        PeptideIons,
+    )
+    from .version import _get_version

 __version__ = _get_version()
diff --git a/depthcharge/data/__init__.py b/depthcharge/data/__init__.py
index 304ed7a..14a189e 100644
--- a/depthcharge/data/__init__.py
+++ b/depthcharge/data/__init__.py
@@ -1,4 +1,14 @@
 """The Pytorch Datasets."""
 from . import preprocessing
+from .arrow import (
+    spectra_to_df,
+    spectra_to_parquet,
+    spectra_to_stream,
+)
+from .fields import CustomField
 from .peptide_datasets import PeptideDataset
-from .spectrum_datasets import AnnotatedSpectrumDataset, SpectrumDataset
+from .spectrum_datasets import (
+    AnnotatedSpectrumDataset,
+    SpectrumDataset,
+    StreamingSpectrumDataset,
+)
diff --git a/depthcharge/data/arrow.py b/depthcharge/data/arrow.py
new file mode 100644
index 0000000..456b024
--- /dev/null
+++ b/depthcharge/data/arrow.py
@@ -0,0 +1,302 @@
+"""Store spectrum data as Arrow tables."""
+from collections.abc import Callable, Generator, Iterable
+from os import PathLike
+from pathlib import Path
+
+import polars as pl
+import pyarrow as pa
+import pyarrow.parquet as pq
+
+from .fields import CustomField
+from .parsers import ParserFactory
+
+
+def spectra_to_stream(
+    peak_file: PathLike,
+    *,
+    batch_size: int | None = 100_000,
+    metadata_df: pl.DataFrame | pl.LazyFrame | None = None,
+    ms_level: int | Iterable[int] | None = 2,
+    preprocessing_fn: Callable | Iterable[Callable] | None = None,
+    valid_charge: Iterable[int] | None = None,
+    custom_fields: CustomField | Iterable[CustomField] | None = None,
+    progress: bool = True,
+) -> Generator[pa.RecordBatch]:
+    """Stream mass spectra in an Apache Arrow format, with preprocessing.
+
+    Apache Arrow is a space-efficient, columnar data format
+    that is popular in the data science and engineering community.
+    This function reads data from a mass spectrometry data file and
+    extracts each mass spectrum along with its identifying information.
+    By default, the schema is:
+
+        peak_file: str
+        scan_id: int64
+        ms_level: uint8
+        precursor_mz: float64
+        precursor_charge: int16
+        mz_array: list[float64]
+        intensity_array: list[float64]
+
+    An optional metadata DataFrame can be provided to add additional
+    metadata to each mass spectrum. This DataFrame must contain
+    a ``scan_id`` column containing the integer scan identifier for
+    each mass spectrum. For mzML files, this is generally the integer
+    following ``scan=``, whereas for MGF files this is the zero-indexed
+    offset of the mass spectrum in the file.
+
+    Finally, custom fields can be extracted from the mass spectrometry
+    data file for advanced use. Each must be a CustomField, where the
+    name is the new column and the accessor is a function
+    to extract a value from the corresponding Pyteomics spectrum
+    dictionary. The pyarrow data type must also be specified.
+
+    Parameters
+    ----------
+    peak_file : PathLike
+        The mass spectrometry data file in mzML, mzXML, or MGF format.
+    batch_size : int or None
+        The number of mass spectra in each RecordBatch. ``None`` will load
+        all of the spectra in a single batch.
+    metadata_df : polars.DataFrame or polars.LazyFrame, optional
+        A `polars.DataFrame` containing additional
+        metadata for the spectra. This is merged on the `scan_id` column,
+        which must be present, and optionally on a `peak_file` column,
+        if present.
+    ms_level : int, list of int, or None, optional
+        The level(s) of tandem mass spectra to keep. `None` will retain
+        all spectra.
+    preprocessing_fn : Callable or Iterable[Callable], optional
+        The function(s) used to preprocess the mass spectra. `None`,
+        the default, filters for the top 200 peaks above m/z 140,
+        square root transforms the intensities and scales them to unit norm.
+        See the preprocessing module for details and additional options.
+    valid_charge : int or list of int, optional
+        Only consider spectra with the specified precursor charges.
+        If `None`, any precursor charge is accepted.
+    custom_fields : CustomField or iterable of CustomField
+        Additional fields to extract during peak file parsing.
+    progress : bool, optional
+        Enable or disable the progress bar.
+
+    Returns
+    -------
+    Generator of pyarrow.RecordBatch
+        Batches of parsed spectra.
+    """
+    parser_args = {
+        "ms_level": ms_level,
+        "valid_charge": valid_charge,
+        "preprocessing_fn": preprocessing_fn,
+        "custom_fields": custom_fields,
+        "progress": progress,
+    }
+
+    on_cols = ["scan_id"]
+    validation = "1:1"
+    if metadata_df is not None:
+        metadata_df = metadata_df.lazy()
+        if "peak_file" in metadata_df.columns:
+            # Validation is only supported when on is a single column.
+            # Adding a footgun here to remove later...
+            validation = "m:m"
+            on_cols.append("peak_file")
+
+    parser = ParserFactory.get_parser(peak_file, **parser_args)
+    for batch in parser.iter_batches(batch_size=batch_size):
+        if metadata_df is not None:
+            batch = (
+                pl.from_arrow(batch)
+                .lazy()
+                .join(
+                    metadata_df,
+                    on=on_cols,
+                    how="left",
+                    validate=validation,
+                )
+                .collect()
+                .to_arrow()
+                .to_batches(max_chunksize=batch_size)[0]
+            )
+
+        yield batch
+
+
+def spectra_to_parquet(
+    peak_file: PathLike,
+    *,
+    parquet_file: PathLike | None = None,
+    batch_size: int = 100_000,
+    metadata_df: pl.DataFrame | None = None,
+    ms_level: int | Iterable[int] | None = 2,
+    preprocessing_fn: Callable | Iterable[Callable] | None = None,
+    valid_charge: Iterable[int] | None = None,
+    custom_fields: CustomField | Iterable[CustomField] | None = None,
+    progress: bool = True,
+) -> Path:
+    """Stream mass spectra to Apache Parquet, with preprocessing.
+
+    Apache Parquet is a space-efficient, columnar data storage format
+    that is popular in the data science and engineering community.
+    This function reads data from a mass spectrometry data file and
+    extracts each mass spectrum along with its identifying information.
+    By default, the schema is:
+
+        peak_file: str
+        scan_id: int64
+        ms_level: uint8
+        precursor_mz: float64
+        precursor_charge: int16
+        mz_array: list[float64]
+        intensity_array: list[float64]
+
+    An optional metadata DataFrame can be provided to add additional
+    metadata to each mass spectrum. This DataFrame must contain
+    a ``scan_id`` column containing the integer scan identifier for
+    each mass spectrum. For mzML files, this is generally the integer
+    following ``scan=``, whereas for MGF files this is the zero-indexed
+    offset of the mass spectrum in the file.
+
+    Finally, custom fields can be extracted from the mass spectrometry
+    data file for advanced use. Each must be a CustomField, where the
+    name is the new column and the accessor is a function
+    to extract a value from the corresponding Pyteomics spectrum
+    dictionary. The pyarrow data type must also be specified.
+
+    Parameters
+    ----------
+    peak_file : PathLike
+        The mass spectrometry data file in mzML, mzXML, or MGF format.
+    parquet_file : PathLike, optional
+        The output file. By default this is the input file stem with a
+        `.parquet` extension.
+    batch_size : int
+        The number of mass spectra to process simultaneously.
+    metadata_df : polars.DataFrame or polars.LazyFrame, optional
+        A `polars.DataFrame` containing additional
+        metadata for the spectra. This is merged on the `scan_id` column,
+        which must be present, and optionally on a `peak_file` column,
+        if present.
+    ms_level : int, list of int, or None, optional
+        The level(s) of tandem mass spectra to keep. `None` will retain
+        all spectra.
+    preprocessing_fn : Callable or Iterable[Callable], optional
+        The function(s) used to preprocess the mass spectra. `None`,
+        the default, filters for the top 200 peaks above m/z 140,
+        square root transforms the intensities and scales them to unit norm.
+        See the preprocessing module for details and additional options.
+    valid_charge : int or list of int, optional
+        Only consider spectra with the specified precursor charges. If `None`,
+        any precursor charge is accepted.
+    custom_fields : CustomField or iterable of CustomField
+        Additional fields to extract during peak file parsing.
+    progress : bool, optional
+        Enable or disable the progress bar.
+
+    Returns
+    -------
+    Path
+        The Parquet file that was written.
+    """
+    streamer = spectra_to_stream(
+        peak_file=peak_file,
+        batch_size=batch_size,
+        metadata_df=metadata_df,
+        ms_level=ms_level,
+        preprocessing_fn=preprocessing_fn,
+        valid_charge=valid_charge,
+        custom_fields=custom_fields,
+        progress=progress,
+    )
+
+    if parquet_file is None:
+        parquet_file = Path(Path(peak_file).stem).with_suffix(".parquet")
+
+    writer = None
+    for batch in streamer:
+        if writer is None:
+            writer = pq.ParquetWriter(parquet_file, schema=batch.schema)
+
+        writer.write_batch(batch)
+
+    return parquet_file
+
+
+def spectra_to_df(
+    peak_file: PathLike,
+    *,
+    metadata_df: pl.DataFrame | None = None,
+    ms_level: int | Iterable[int] | None = 2,
+    preprocessing_fn: Callable | Iterable[Callable] | None = None,
+    valid_charge: Iterable[int] | None = None,
+    custom_fields: CustomField | Iterable[CustomField] | None = None,
+    progress: bool = True,
+) -> pl.DataFrame:
+    """Read mass spectra into a Polars DataFrame.
+
+    This function reads data from a mass spectrometry data file and
+    extracts each mass spectrum along with its identifying information
+    into a Polars DataFrame. By default, the schema is:
+
+        peak_file: str
+        scan_id: int64
+        ms_level: uint8
+        precursor_mz: float64
+        precursor_charge: int16
+        mz_array: list[float64]
+        intensity_array: list[float64]
+
+    An optional metadata DataFrame can be provided to add additional
+    metadata to each mass spectrum. This DataFrame must contain
+    a ``scan_id`` column containing the integer scan identifier for
+    each mass spectrum. For mzML files, this is generally the integer
+    following ``scan=``, whereas for MGF files this is the zero-indexed
+    offset of the mass spectrum in the file.
+
+    Finally, custom fields can be extracted from the mass spectrometry
+    data file for advanced use. Each must be a CustomField, where the
+    name is the new column and the accessor is a function
+    to extract a value from the corresponding Pyteomics spectrum
+    dictionary. The pyarrow data type must also be specified.
+
+    Parameters
+    ----------
+    peak_file : PathLike
+        The mass spectrometry data file in mzML, mzXML, or MGF format.
+    metadata_df : polars.DataFrame or polars.LazyFrame, optional
+        A `polars.DataFrame` containing additional
+        metadata for the spectra. This is merged on the `scan_id` column,
+        which must be present, and optionally on a `peak_file` column,
+        if present.
+    ms_level : int, list of int, or None, optional
+        The level(s) of tandem mass spectra to keep. `None` will retain
+        all spectra.
+    preprocessing_fn : Callable or Iterable[Callable], optional
+        The function(s) used to preprocess the mass spectra.
+        `None`, the default, filters for the top 200 peaks above m/z 140,
+        square root transforms the intensities and scales them to unit norm.
+        See the preprocessing module for details and additional options.
+    valid_charge : int or list of int, optional
+        Only consider spectra with the specified precursor charges. If `None`,
+        any precursor charge is accepted.
+    custom_fields : CustomField or iterable of CustomField
+        Additional fields to extract during peak file parsing.
+    progress : bool, optional
+        Enable or disable the progress bar.
+
+    Returns
+    -------
+    polars.DataFrame
+        The parsed mass spectra.
+    """
+    streamer = spectra_to_stream(
+        peak_file=peak_file,
+        batch_size=None,
+        metadata_df=metadata_df,
+        ms_level=ms_level,
+        preprocessing_fn=preprocessing_fn,
+        valid_charge=valid_charge,
+        custom_fields=custom_fields,
+        progress=progress,
+    )
+
+    return pl.from_arrow(streamer)
diff --git a/depthcharge/data/fields.py b/depthcharge/data/fields.py
new file mode 100644
index 0000000..7d77665
--- /dev/null
+++ b/depthcharge/data/fields.py
@@ -0,0 +1,31 @@
+"""Custom fields for the Arrow Schema."""
+from collections.abc import Callable
+from dataclasses import dataclass
+
+import pyarrow as pa
+
+
+@dataclass
+class CustomField:
+    """An additional field to extract during peak file parsing.
+
+    The accessor function is used to extract additional information
+    from each spectrum during parsing.
+
+    Parameters
+    ----------
+    name: str
+        The resulting column name in the Arrow schema.
+    accessor: Callable
+        A function to access the value of interest. The input will
+        depend on the parser that is used. Currently, we use
+        Pyteomics to parse mzML, MGF, and mzXML files, so the
+        parameter for this function would be the dictionary for
+        each spectrum.
+    dtype: pyarrow.DataType
+        The expected Arrow data type for the column in the schema.
+    """
+
+    name: str
+    accessor: Callable
+    dtype: pa.DataType
diff --git a/depthcharge/data/parsers.py b/depthcharge/data/parsers.py
index 73415d2..139a24d 100644
--- a/depthcharge/data/parsers.py
+++ b/depthcharge/data/parsers.py
@@ -6,8 +6,9 @@
 from collections.abc import Callable, Iterable
 from os import PathLike
 from pathlib import Path
+from typing import Any

-import numpy as np
+import pyarrow as pa
 from pyteomics.mgf import MGF
 from pyteomics.mzml import MzML
 from pyteomics.mzxml import MzXML
@@ -15,6 +16,8 @@

 from .. import utils
 from ..primitives import MassSpectrum
+from . import preprocessing
+from .fields import CustomField

 LOGGER = logging.getLogger(__name__)
@@ -24,7 +26,7 @@ class BaseParser(ABC):

     Parameters
     ----------
-    ms_data_file : PathLike
+    peak_file : PathLike
         The peak file to parse.
     ms_level : int
         The MS level of the spectra to parse.
@@ -33,38 +35,86 @@ class BaseParser(ABC):
     valid_charge : Iterable[int], optional
         Only consider spectra with the specified precursor charges.
         If `None`, any precursor charge is accepted.
+    custom_fields : CustomField or iterable of CustomField, optional
+        Additional fields to extract during peak file parsing. Each
+        defines the resulting column name, an accessor function used to
+        retrieve the value from the spectrum dictionary produced by the
+        corresponding Pyteomics parser, and the Arrow data type.
+    progress : bool, optional
+        Enable or disable the progress bar.
     id_type : str, optional
         The Hupo-PSI prefix for the spectrum identifier.
""" def __init__( self, - ms_data_file: PathLike, - ms_level: int, + peak_file: PathLike, + ms_level: int | Iterable[int] | None = 2, preprocessing_fn: Callable | Iterable[Callable] | None = None, valid_charge: Iterable[int] | None = None, + custom_fields: dict[str, str | Iterable[str]] | None = None, + progress: bool = True, id_type: str = "scan", ) -> None: """Initialize the BaseParser.""" - self.path = Path(ms_data_file) - self.ms_level = ms_level + self.peak_file = Path(peak_file) + self.progress = progress + self.ms_level = ( + ms_level if ms_level is None else set(utils.listify(ms_level)) + ) + if preprocessing_fn is None: - self.preprocessing_fn = [] + self.preprocessing_fn = [ + preprocessing.set_mz_range(min_mz=140), + preprocessing.filter_intensity(max_num_peaks=200), + preprocessing.scale_intensity(scaling="root"), + preprocessing.scale_to_unit_norm, + ] else: self.preprocessing_fn = utils.listify(preprocessing_fn) self.valid_charge = None if valid_charge is None else set(valid_charge) + self.custom_fields = custom_fields self.id_type = id_type - self.offset = None - self.precursor_mz = [] - self.precursor_charge = [] - self.scan_id = [] - self.mz_arrays = [] - self.intensity_arrays = [] - self.annotations = None + + # Check format: + self.sniff() + + # Used during parsing: + self._batch = None + + # Define the schema + self.schema = pa.schema( + [ + pa.field("peak_file", pa.string()), + pa.field("scan_id", pa.int64()), + pa.field("ms_level", pa.uint8()), + pa.field("precursor_mz", pa.float64()), + pa.field("precursor_charge", pa.int16()), + pa.field("mz_array", pa.list_(pa.float64())), + pa.field("intensity_array", pa.list_(pa.float64())), + ] + ) + + if self.custom_fields is not None: + self.custom_fields = utils.listify(self.custom_fields) + for field in self.custom_fields: + self.schema = self.schema.append( + pa.field(field.name, field.dtype) + ) @abstractmethod - def open(self) -> Iterable: + def sniff(self) -> None: + """Quickly test a file for the correct type. + + Raises + ------ + IOError + Raised if the file is not the expected format. + """ + + @abstractmethod + def open(self) -> Iterable[dict]: """Open the file as an iterable.""" @abstractmethod @@ -82,66 +132,115 @@ def parse_spectrum(self, spectrum: dict) -> MassSpectrum | None: The parsed mass spectrum or None if it is skipped. """ - def read(self) -> BaseParser: - """Read the ms data file. + def parse_custom_fields(self, spectrum: dict) -> dict[str, Any]: + """Parse user-provided fields. + + Parameters + ---------- + spectrum : dict + The dictionary defining the spectrum in a given format. Returns ------- - Self + dict + The parsed value of each, whatever it may be. """ + out = {} + if self.custom_fields is None: + return out + + for field in self.custom_fields: + out[field.name] = field.accessor(spectrum) + + return out + + def iter_batches(self, batch_size: int | None) -> pa.RecordBatch: + """Iterate over batches of mass spectra in the Arrow format. + + Parameters + ---------- + batch_size : int or None + The number of spectra in a batch. ``None`` loads all of + the spectra in a single batch. + + Yields + ------ + RecordBatch + A batch of spectra and their metadata. 
+ """ + batch_size = float("inf") if batch_size is None else batch_size + pbar_args = { + "desc": self.peak_file.name, + "unit": " spectra", + "disable": not self.progress, + } + n_skipped = 0 + last_exc = None with self.open() as spectra: - for spectrum in tqdm(spectra, desc=str(self.path), unit="spectra"): + self._batch = None + for spectrum in tqdm(spectra, **pbar_args): try: - spectrum = self.parse_spectrum(spectrum) - if spectrum is None: + parsed = self.parse_spectrum(spectrum) + if parsed is None: continue if self.preprocessing_fn is not None: for processor in self.preprocessing_fn: - spectrum = processor(spectrum) - - self.mz_arrays.append(spectrum.mz) - self.intensity_arrays.append(spectrum.intensity) - self.precursor_mz.append(spectrum.precursor_mz) - self.precursor_charge.append(spectrum.precursor_charge) - self.scan_id.append(_parse_scan_id(spectrum.scan_id)) - if self.annotations is not None: - self.annotations.append(spectrum.label) - except (IndexError, KeyError, ValueError): + parsed = processor(parsed) + + entry = { + "peak_file": self.peak_file.name, + "scan_id": _parse_scan_id(parsed.scan_id), + "ms_level": parsed.ms_level, + "precursor_mz": parsed.precursor_mz, + "precursor_charge": parsed.precursor_charge, + "mz_array": parsed.mz, + "intensity_array": parsed.intensity, + } + + except (IndexError, KeyError, ValueError) as exc: + last_exc = exc n_skipped += 1 + continue + + # Parse custom fields: + entry.update(self.parse_custom_fields(spectrum)) + self._update_batch(entry) + + # Update the batch: + if len(self._batch["scan_id"]) == batch_size: + yield self._yield_batch() + + # Get the remainder: + if self._batch is not None: + yield self._yield_batch() if n_skipped: LOGGER.warning( - "Skipped %d spectra with invalid precursor info", n_skipped + "Skipped %d spectra with invalid information", n_skipped ) + LOGGER.debug("Last error: %s", str(last_exc)) - self.precursor_mz = np.array(self.precursor_mz, dtype=np.float64) - self.precursor_charge = np.array( - self.precursor_charge, - dtype=np.uint8, - ) - - self.scan_id = np.array(self.scan_id) - - # Build the index - sizes = np.array([0] + [s.shape[0] for s in self.mz_arrays]) - self.offset = sizes[:-1].cumsum() - self.mz_arrays = np.concatenate(self.mz_arrays).astype(np.float64) - self.intensity_arrays = np.concatenate(self.intensity_arrays).astype( - np.float32 - ) - return self + def _update_batch(self, entry: dict) -> None: + """Update the batch. - @property - def n_spectra(self) -> int: - """The number of spectra.""" - return self.offset.shape[0] + Parameters + ---------- + entry : dict + The elemtn to add. + """ + if self._batch is None: + self._batch = {k: [v] for k, v in entry.items()} + else: + for key, val in entry.items(): + self._batch[key].append(val) - @property - def n_peaks(self) -> int: - """The number of peaks in the file.""" - return self.mz_arrays.shape[0] + def _yield_batch(self) -> pa.RecordBatch: + """Yield the batch.""" + out = pa.RecordBatch.from_pydict(self._batch, schema=self.schema) + self._batch = None + return out class MzmlParser(BaseParser): @@ -149,7 +248,7 @@ class MzmlParser(BaseParser): Parameters ---------- - ms_data_file : PathLike + peak_file : PathLike The mzML file to parse. ms_level : int The MS level of the spectra to parse. @@ -158,11 +257,31 @@ class MzmlParser(BaseParser): valid_charge : Iterable[int], optional Only consider spectra with the specified precursor charges. If `None`, any precursor charge is accepted. 
+    custom_fields : CustomField or iterable of CustomField, optional
+        Additional fields to extract during peak file parsing. Each
+        defines the resulting column name, an accessor function used to
+        retrieve the value from the spectrum dictionary produced by the
+        corresponding Pyteomics parser, and the Arrow data type.
+    progress : bool, optional
+        Enable or disable the progress bar.
     """

+    def sniff(self) -> None:
+        """Quickly test a file for the correct type.
+
+        Raises
+        ------
+        IOError
+            Raised if the file is not the expected format.
+        """
+        with self.peak_file.open() as mzdat:
+            next(mzdat)
+            if "http://psi.hupo.org/ms/mzml" not in next(mzdat):
+                raise OSError("Not an mzML file.")
+
     def open(self) -> Iterable[dict]:
         """Open the mzML file for reading."""
-        return MzML(str(self.path))
+        return MzML(str(self.peak_file))

     def parse_spectrum(self, spectrum: dict) -> MassSpectrum | None:
         """Parse a single spectrum.
@@ -177,12 +296,28 @@ def parse_spectrum(self, spectrum: dict) -> MassSpectrum | None:
         MassSpectrum or None
             The parsed mass spectrum or None if not at the correct MS level.
         """
-        if spectrum["ms level"] != self.ms_level:
+        ms_level = spectrum["ms level"]
+        if self.ms_level is not None and ms_level not in self.ms_level:
             return None

-        if self.ms_level > 1:
-            precursor = spectrum["precursorList"]["precursor"][0]
-            precursor_ion = precursor["selectedIonList"]["selectedIon"][0]
+        if ms_level > 1:
+            precursor = spectrum["precursorList"]["precursor"]
+            if len(precursor) > 1:
+                LOGGER.warning(
+                    "More than one precursor found for spectrum %s. "
+                    "Only the first will be retained.",
+                    spectrum["id"],
+                )
+
+            precursor_ion = precursor[0]["selectedIonList"]["selectedIon"]
+            if len(precursor_ion) > 1:
+                LOGGER.warning(
+                    "More than one selected ion found for spectrum %s. "
+                    "Only the first will be retained.",
+                    spectrum["id"],
+                )
+
+            precursor_ion = precursor_ion[0]
             precursor_mz = float(precursor_ion["selected ion m/z"])
             if "charge state" in precursor_ion:
                 precursor_charge = int(precursor_ion["charge state"])
@@ -195,15 +330,16 @@ def parse_spectrum(self, spectrum: dict) -> MassSpectrum | None:

         if self.valid_charge is None or precursor_charge in self.valid_charge:
             return MassSpectrum(
-                filename=str(self.path),
+                filename=str(self.peak_file),
                 scan_id=spectrum["id"],
                 mz=spectrum["m/z array"],
                 intensity=spectrum["intensity array"],
+                ms_level=ms_level,
                 precursor_mz=precursor_mz,
                 precursor_charge=precursor_charge,
             )

-        raise ValueError("Invalid precursor charge")
+        raise ValueError("Invalid precursor charge.")


 class MzxmlParser(BaseParser):
@@ -211,7 +347,7 @@ class MzxmlParser(BaseParser):

     Parameters
     ----------
-    ms_data_file : PathLike
+    peak_file : PathLike
         The mzXML file to parse.
     ms_level : int
         The MS level of the spectra to parse.
@@ -220,11 +356,32 @@ class MzxmlParser(BaseParser):
     valid_charge : Iterable[int], optional
         Only consider spectra with the specified precursor charges.
         If `None`, any precursor charge is accepted.
+    custom_fields : CustomField or iterable of CustomField, optional
+        Additional fields to extract during peak file parsing. Each
+        defines the resulting column name, an accessor function used to
+        retrieve the value from the spectrum dictionary produced by the
+        corresponding Pyteomics parser, and the Arrow data type.
+    progress : bool, optional
+        Enable or disable the progress bar.
     """

+    def sniff(self) -> None:
+        """Quickly test a file for the correct type.
+
+        Raises
+        ------
+        IOError
+            Raised if the file is not the expected format.
+ """ + scent = "http://sashimi.sourceforge.net/schema_revision/mzXML" + with self.peak_file.open() as mzdat: + next(mzdat) + if scent not in next(mzdat): + raise OSError("Not an mzXML file.") + def open(self) -> Iterable[dict]: """Open the mzXML file for reading.""" - return MzXML(str(self.path)) + return MzXML(str(self.peak_file)) def parse_spectrum(self, spectrum: dict) -> MassSpectrum | None: """Parse a single spectrum. @@ -239,10 +396,11 @@ def parse_spectrum(self, spectrum: dict) -> MassSpectrum | None: MassSpectrum The parsed mass spectrum. """ - if spectrum["msLevel"] != self.ms_level: + ms_level = spectrum["msLevel"] + if self.ms_level is not None and ms_level not in self.ms_level: return None - if self.ms_level > 1: + if ms_level > 1: precursor = spectrum["precursorMz"][0] precursor_mz = float(precursor["precursorMz"]) precursor_charge = int(precursor.get("precursorCharge", 0)) @@ -251,10 +409,11 @@ def parse_spectrum(self, spectrum: dict) -> MassSpectrum | None: if self.valid_charge is None or precursor_charge in self.valid_charge: return MassSpectrum( - filename=str(self.path), + filename=str(self.peak_file), scan_id=spectrum["id"], mz=spectrum["m/z array"], intensity=spectrum["intensity array"], + ms_level=ms_level, precursor_mz=precursor_mz, precursor_charge=precursor_charge, ) @@ -267,7 +426,7 @@ class MgfParser(BaseParser): Parameters ---------- - ms_data_file : PathLike + peak_file : PathLike The MGF file to parse. ms_level : int The MS level of the spectra to parse. @@ -276,34 +435,55 @@ class MgfParser(BaseParser): valid_charge : Iterable[int], optional Only consider spectra with the specified precursor charges. If `None`, any precursor charge is accepted. - annotations : bool - Include peptide annotations. + custom_fields : dict of str to list of str, optional + Additional field to extract during peak file parsing. The key must + be the resulting column name and value must be an interable of + containing the necessary keys to retreive the value from the + spectrum from the corresponding Pyteomics parser. + progress : bool, optional + Enable or disable the progress bar. """ def __init__( self, - ms_data_file: PathLike, + peak_file: PathLike, ms_level: int = 2, preprocessing_fn: Callable | Iterable[Callable] | None = None, valid_charge: Iterable[int] | None = None, - annotations: bool = False, + custom_fields: dict[str, Iterable[str]] | None = None, + progress: bool = True, ) -> None: """Initialize the MgfParser.""" super().__init__( - ms_data_file, + peak_file, ms_level=ms_level, preprocessing_fn=preprocessing_fn, valid_charge=valid_charge, + custom_fields=custom_fields, + progress=progress, id_type="index", ) - if annotations: - self.annotations = [] - self._counter = -1 + if ms_level is not None: + self._assumed_ms_level = sorted(self.ms_level)[0] + else: + self._assumed_ms_level = None + + def sniff(self) -> None: + """Quickly test a file for the correct type. + + Raises + ------ + IOError + Raised if the file is not the expected format. + """ + with self.peak_file.open() as mzdat: + if not next(mzdat).startswith("BEGIN IONS"): + raise OSError("Not an MGF file.") def open(self) -> Iterable[dict]: """Open the MGF file for reading.""" - return MGF(str(self.path)) + return MGF(str(self.peak_file)) def parse_spectrum(self, spectrum: dict) -> MassSpectrum: """Parse a single spectrum. @@ -314,30 +494,24 @@ def parse_spectrum(self, spectrum: dict) -> MassSpectrum: The dictionary defining the spectrum in MGF format. 
""" self._counter += 1 - - if self.ms_level > 1: + if self.ms_level is not None and 1 not in self.ms_level: precursor_mz = float(spectrum["params"]["pepmass"][0]) precursor_charge = int(spectrum["params"].get("charge", [0])[0]) else: precursor_mz, precursor_charge = None, 0 - if self.annotations is not None: - label = spectrum["params"].get("seq") - else: - label = None - if self.valid_charge is None or precursor_charge in self.valid_charge: return MassSpectrum( - filename=str(self.path), + filename=str(self.peak_file), scan_id=self._counter, mz=spectrum["m/z array"], intensity=spectrum["intensity array"], + ms_level=self._assumed_ms_level, precursor_mz=precursor_mz, precursor_charge=precursor_charge, - label=label, ) - raise ValueError("Invalid precursor charge") + raise ValueError("Invalid precursor charge.") def _parse_scan_id(scan_str: str | int) -> int: @@ -364,6 +538,38 @@ def _parse_scan_id(scan_str: str | int) -> int: try: return int(scan_str[scan_str.find("scan=") + len("scan=") :]) except ValueError: - pass + try: + return int(scan_str[scan_str.find("index=") + len("index=") :]) + except ValueError: + pass raise ValueError("Failed to parse scan number") + + +class ParserFactory: + """Figure out what parser to use.""" + + parsers = [ + MzmlParser, + MzxmlParser, + MgfParser, + ] + + @classmethod + def get_parser(cls, peak_file: PathLike, **kwargs: dict) -> BaseParser: + """Get the correct parser for a peak file. + + Parameters + ---------- + peak_file: PathLike + The peak file to parse. + kwargs : dict + Keyword arguments to pass to the parser. + """ + for parser in cls.parsers: + try: + return parser(peak_file, **kwargs) + except OSError: + pass + + raise OSError("Unknown file format.") diff --git a/depthcharge/data/spectrum_datasets.py b/depthcharge/data/spectrum_datasets.py index bad97b3..1609948 100644 --- a/depthcharge/data/spectrum_datasets.py +++ b/depthcharge/data/spectrum_datasets.py @@ -1,513 +1,324 @@ -"""Parse mass spectra into an HDF5 file format.""" +"""Serve mass spectra to neural networks.""" from __future__ import annotations -import hashlib import logging import uuid -from collections.abc import Callable, Iterable +import warnings +from collections.abc import Generator, Iterable from os import PathLike from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any +from typing import Any, Literal -import dill -import h5py -import numpy as np +import lance +import polars as pl +import pyarrow as pa +import pyarrow.parquet as pq import torch -from torch.utils.data import DataLoader, Dataset +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset, IterableDataset from .. import utils -from ..primitives import MassSpectrum from ..tokenizers import PeptideTokenizer -from . import preprocessing -from .parsers import MgfParser, MzmlParser, MzxmlParser +from . import arrow LOGGER = logging.getLogger(__name__) -class SpectrumDataset(Dataset): - """Store and access a collection of mass spectra. - - Parses one or more MS data files, adding the mass spectra to an HDF5 - file that indexes the mass spectra from one or multiple files. - This allows depthcharge to access random mass spectra from many different - files quickly and without loading them all into memory. - - Parameters - ---------- - ms_data_files : PathLike or list of PathLike, optional - The mzML, mzXML, or MGF files to include in this collection. Files - can be added later using the ``.add_file()`` method. 
-    ms_level : int, optional
-        The level of tandem mass spectra to use.
-    preprocessing_fn : Callable or Iterable[Callable], optional
-        The function(s) used to preprocess the mass spectra. ``None``,
-        the default, filters for the top 200 peaks above m/z 140,
-        square root transforms the intensities and scales them to unit norm.
-        See the preprocessing module for details and additional options.
-    valid_charge : Iterable[int], optional
-        Only consider spectra with the specified precursor charges. If `None`,
-        any precursor charge is accepted.
-    index_path : PathLike, optional.
-        The name and path of the HDF5 file index. If the path does
-        not contain the `.h5` or `.hdf5` extension, `.hdf5` will be added.
-        If ``None``, a file will be created in a temporary directory.
-    overwrite : bool, optional
-        Overwrite previously indexed files? If ``False`` and new files are
-        provided, they will be appended to the collection.
+class CollateFnMixin:
+    """Define a common collate function."""

-    Attributes
-    ----------
-    ms_files : list of str
-    path : Path
-    ms_level : int
-    valid_charge : Optional[Iterable[int]]
-    annotated : bool
-    overwrite : bool
-    n_spectra : int
-    n_peaks : int
-    """
-
-    _annotated = False
+    def __init__(self) -> None:
+        """Specify the float precision."""
+        self.precision = torch.float64

-    def __init__(
+    def collate_fn(
         self,
-        ms_data_files: PathLike | Iterable[PathLike] = None,
-        ms_level: int = 2,
-        preprocessing_fn: Callable | Iterable[Callable] | None = None,
-        valid_charge: Iterable[int] | None = None,
-        index_path: PathLike | None = None,
-        overwrite: bool = False,
-    ) -> None:
-        """Initialize a SpectrumIndex."""
-        self._tmpdir = None
-        if index_path is None:
-            # Create a random temporary file:
-            self._tmpdir = TemporaryDirectory()
-            index_path = Path(self._tmpdir.name) / f"{uuid.uuid4()}.hdf5"
-
-        index_path = Path(index_path)
-        if index_path.suffix not in [".h5", ".hdf5"]:
-            index_path = Path(index_path)
-            index_path = Path(str(index_path) + ".hdf5")
-
-        # Set attributes and check parameters:
-        self._path = index_path
-        self._ms_level = utils.check_positive_int(ms_level, "ms_level")
-        self._valid_charge = valid_charge
-        self._overwrite = bool(overwrite)
-        self._handle = None
-        self._file_offsets = np.array([0])
-        self._file_map = {}
-        self._locs = {}
-        self._offsets = None
-
-        if preprocessing_fn is not None:
-            self._preprocessing_fn = utils.listify(preprocessing_fn)
-        else:
-            self._preprocessing_fn = [
-                preprocessing.set_mz_range(min_mz=140),
-                preprocessing.filter_intensity(max_num_peaks=200),
-                preprocessing.scale_intensity(scaling="root"),
-                preprocessing.scale_to_unit_norm,
-            ]
-
-        # Create the file if it doesn't exist.
-        if not self.path.exists() or self.overwrite:
-            with h5py.File(self.path, "w") as index:
-                index.attrs["ms_level"] = self.ms_level
-                index.attrs["n_spectra"] = 0
-                index.attrs["n_peaks"] = 0
-                index.attrs["annotated"] = self.annotated
-                index.attrs["preprocessing_fn"] = _hash_obj(
-                    tuple(self.preprocessing_fn)
-                )
-        else:
-            self._validate_index()
+        batch: Iterable[dict[str, Any]],
+    ) -> dict[str, torch.Tensor | list[Any]]:
+        """The collate function for a SpectrumDataset.

-        # Now parse spectra.
-        if ms_data_files is not None:
-            ms_data_files = utils.listify(ms_data_files)
-            for ms_file in ms_data_files:
-                self.add_file(ms_file)
-
-    def _reindex(self) -> None:
-        """Update the file mappings and offsets."""
-        offsets = [0]
-        for idx in range(len(self._handle)):
-            grp = self._handle[str(idx)]
-            offsets.append(grp.attrs["n_spectra"])
-            self._file_map[grp.attrs["path"]] = idx
-
-        self._file_offsets = np.cumsum([0] + offsets)
-
-        # Build a map of 1D indices to 2D locations:
-        grp_idx = 0
-        for lin_idx in range(offsets[-1]):
-            grp_idx += lin_idx >= offsets[grp_idx + 1]
-            row_idx = lin_idx - offsets[grp_idx]
-            self._locs[lin_idx] = (grp_idx, row_idx)
-
-        self._offsets = None
-        _ = self.offsets  # Reinitialize the offsets.
-
-    def _validate_index(self) -> None:
-        """Validate that the index is appropriate for this dataset."""
-        preproc_hash = _hash_obj(tuple(self.preprocessing_fn))
-        with self:
-            try:
-                assert self._handle.attrs["ms_level"] == self.ms_level
-                assert self._handle.attrs["preprocessing_fn"] == preproc_hash
-                if self._annotated:
-                    assert self._handle.attrs["annotated"]
-            except (KeyError, AssertionError):
-                raise ValueError(
-                    f"'{self.path}' already exists, but was created with "
-                    "incompatible parameters. Use 'overwrite=True' to "
-                    "overwrite it."
-                )
-
-        self._reindex()
-
-    def _get_parser(
-        self,
-        ms_data_file: PathLike,
-    ) -> MzmlParser | MzxmlParser | MgfParser:
-        """Get the parser for the MS data file.
+        Transform compatible data types into PyTorch tensors and
+        pad the m/z and intensity arrays of each mass spectrum with
+        zeros so that they can be stacked into a tensor.

         Parameters
         ----------
-        ms_data_file : PathLike
-            The mass spectrometry data file to be parsed.
+        batch : iterable of dict
+            A batch of data.

         Returns
         -------
-        MzmlParser, MzxmlParser, or MgfParser
-            The appropriate parser for the file.
+        dict of str, tensor or list
+            A dictionary mapping the columns of the lance dataset
+            to a PyTorch tensor or list of values.
         """
-        kw_args = {
-            "ms_level": self.ms_level,
-            "valid_charge": self.valid_charge,
-            "preprocessing_fn": self.preprocessing_fn,
-        }
+        mz_array = nn.utils.rnn.pad_sequence(
+            [s.pop("mz_array") for s in batch],
+            batch_first=True,
+        )

-        if ms_data_file.suffix.lower() == ".mzml":
-            return MzmlParser(ms_data_file, **kw_args)
+        intensity_array = nn.utils.rnn.pad_sequence(
+            [s.pop("intensity_array") for s in batch],
+            batch_first=True,
+        )

-        if ms_data_file.suffix.lower() == ".mzxml":
-            return MzxmlParser(ms_data_file, **kw_args)
+        batch = torch.utils.data.default_collate(batch)
+        batch["mz_array"] = mz_array
+        batch["intensity_array"] = intensity_array

-        if ms_data_file.suffix.lower() == ".mgf":
-            return MgfParser(ms_data_file, **kw_args)
+        for key, val in batch.items():
+            if isinstance(val, torch.Tensor) and torch.is_floating_point(val):
+                batch[key] = val.type(self.precision)

-        raise ValueError("Only mzML, mzXML, and MGF files are supported.")
+        return batch

-    def _assemble_metadata(
+    def loader(
         self,
-        parser: MzmlParser | MzxmlParser | MgfParser,
-    ) -> np.ndarray:
-        """Assemble the metadata.
+        precision: Literal[
+            torch.float16, torch.float32, torch.float64
+        ] = torch.float64,
+        **kwargs: dict,
+    ) -> DataLoader:
+        """Create a suitable PyTorch DataLoader."""
+        self.precision = precision
+        if kwargs.get("collate_fn", False):
+            warnings.warn("The default collate_fn was overridden.")
+        else:
+            kwargs["collate_fn"] = self.collate_fn

         Parameters
         ----------
-        parser : MzmlParser, MzxmlParser, MgfParser
-            The parser to use.
+ return DataLoader(self, **kwargs) - Returns - ------- - numpy.ndarray of shape (n_spectra,) - The file metadata. - """ - meta_types = [ - ("precursor_mz", np.float32), - ("precursor_charge", np.uint8), - ("offset", np.uint64), - ("scan_id", np.uint32), - ] - - metadata = np.empty(parser.n_spectra, dtype=meta_types) - metadata["precursor_mz"] = parser.precursor_mz - metadata["precursor_charge"] = parser.precursor_charge - metadata["offset"] = parser.offset - metadata["scan_id"] = parser.scan_id - return metadata - - def add_file(self, ms_data_file: PathLike) -> None: - """Add a mass spectrometry data file to the index. - Parameters - ---------- - ms_data_file : PathLike - The mass spectrometry data file to add. It must be in an mzML or - MGF file format and use an ``.mzML``, ``.mzXML``, or ``.mgf`` file - extension. - """ - ms_data_file = Path(ms_data_file) - if str(ms_data_file) in self._file_map: - return - - # Invalidate current offsets: - self._offsets = None - - # Read the file: - parser = self._get_parser(ms_data_file) - parser.read() - - # Create the tables: - metadata = self._assemble_metadata(parser) - - spectrum_types = [ - ("mz_array", np.float64), - ("intensity_array", np.float32), - ] - - spectra = np.zeros(parser.n_peaks, dtype=spectrum_types) - spectra["mz_array"] = parser.mz_arrays - spectra["intensity_array"] = parser.intensity_arrays - - # Write the tables: - with h5py.File(self.path, "a") as index: - group_index = len(index) - group = index.create_group(str(group_index)) - group.attrs["path"] = str(ms_data_file) - group.attrs["n_spectra"] = parser.n_spectra - group.attrs["n_peaks"] = parser.n_peaks - group.attrs["id_type"] = parser.id_type - - # Update overall stats: - index.attrs["n_spectra"] += parser.n_spectra - index.attrs["n_peaks"] += parser.n_peaks - - # Add the datasets: - group.create_dataset( - "metadata", - data=metadata, - ) +class SpectrumDataset(Dataset, CollateFnMixin): + """Store and access a collection of mass spectra. - group.create_dataset( - "spectra", - data=spectra, - ) + Parse and/or add mass spectra to an index in the + [lance data format](https://lancedb.github.io/lance/index.html). + This format enables fast random access to spectra for training. + This file is then served as a PyTorch Dataset, allowing spectra + to be accessed efficiently for training and inference. - try: - group.create_dataset( - "annotations", - data=parser.annotations, - dtype=h5py.string_dtype(), - ) - except (KeyError, AttributeError, TypeError): - pass - - self._file_map[str(ms_data_file)] = group_index - end_offset = self._file_offsets[-1] + parser.n_spectra - self._file_offsets = np.append(self._file_offsets, [end_offset]) - - # Update the locations: - grp_idx = len(self._file_offsets) - 2 - for row_idx in range(parser.n_spectra): - lin_idx = row_idx + self._file_offsets[-2] - self._locs[lin_idx] = (grp_idx, row_idx) - - def get_spectrum(self, idx: int) -> MassSpectrum: - """Access a mass spectrum. + If you wish to use an existing lance dataset, use the `from_lance()` + method. - Parameters - ---------- - idx : int - The index of the index of the mass spectrum to look-up. - Returns - ------- - tuple of numpy.ndarray - The m/z values, intensity values, precurosr m/z, precurosr charge, - and the spectrum annotation. 
- """ - group_index, row_index = self._locs[idx] - if self._handle is None: - raise RuntimeError("Use the context manager for access.") - - grp = self._handle[str(group_index)] - metadata = grp["metadata"] - spectra = grp["spectra"] - offsets = self.offsets[str(group_index)][row_index : row_index + 2] - - start_offset = offsets[0] - if offsets.shape[0] == 2: - stop_offset = offsets[1] - else: - stop_offset = spectra.shape[0] - - spectrum = spectra[start_offset:stop_offset] - precursor = metadata[row_index] - return MassSpectrum( - filename=grp.attrs["path"], - scan_id=f"{grp.attrs['id_type']}={metadata[row_index]['scan_id']}", - mz=np.array(spectrum["mz_array"]), - intensity=np.array(spectrum["intensity_array"]), - precursor_mz=precursor["precursor_mz"], - precursor_charge=precursor["precursor_charge"], - ) + Parameters + ---------- + spectra : polars.DataFrame, PathLike, or list of PathLike + Spectra to add to this collection. These may be a DataFrame parsed + with `depthcharge.spectra_to_df()`, parquet files created with + `depthcharge.spectra_to_parquet()`, or a peak file in the mzML, + mzXML, or MGF format. Additional spectra can be added later using + the `.add_spectra()` method. + path : PathLike, optional. + The name and path of the lance dataset. If the path does + not contain the `.lance` then it will be added. + If `None`, a file will be created in a temporary directory. + **kwargs : dict + Keyword arguments passed `depthcharge.spectra_to_stream()` for + peak files that are provided. This argument has no affect for + DataFrame or parquet file inputs. - def get_spectrum_id(self, idx: int) -> tuple[str, str]: - """Get the identifier for a mass spectrum. + Attributes + ---------- + peak_files : list of str + path : Path + """ - Parameters - ---------- - idx : int - The index of the mass spectrum in the SpectrumIndex. + def __init__( + self, + spectra: pl.DataFrame | PathLike | Iterable[PathLike], + path: PathLike | None = None, + **kwargs: dict, + ) -> None: + """Initialize a SpectrumDataset.""" + self._tmpdir = None + if path is None: + # Create a random temporary file: + self._tmpdir = TemporaryDirectory() + path = Path(self._tmpdir.name) / f"{uuid.uuid4()}.lance" - Returns - ------- - ms_data_file : str - The mass spectrometry data file from which the mass spectrum was - originally parsed. - identifier : str - The mass spectrum identifier, per PSI recommendations. - """ - group_index, row_index = self._locs[idx] - if self._handle is None: - raise RuntimeError("Use the context manager for access.") + self._path = Path(path) + if self._path.suffix != "lance": + self._path = path.with_suffix(".lance") - grp = self._handle[str(group_index)] - ms_data_file = grp.attrs["path"] - identifier = grp["metadata"][row_index]["scan_id"] - prefix = grp.attrs["id_type"] - return ms_data_file, f"{prefix}={identifier}" + # Now parse spectra. + if spectra is not None: + spectra = utils.listify(spectra) + batch = next(_get_records(spectra, **kwargs)) + lance.write_dataset( + _get_records(spectra, **kwargs), + self._path, + mode="overwrite", + schema=batch.schema, + ) - def loader(self, *args: tuple, **kwargs: dict) -> DataLoader: - """A PyTorch DataLoader for the mass spectra. + elif not self._path.exists(): + raise ValueError("No spectra were provided") + + self._dataset = lance.dataset(self._path) + + def add_spectra( + self, + spectra: pl.DataFrame | PathLike | Iterable[PathLike], + **kwargs: dict, + ) -> SpectrumDataset: + """Add mass spectrometry data to the lance dataset. 
+
+        Note that depthcharge does not verify whether the provided spectra
+        already exist in the lance dataset.

         Parameters
         ----------
-        *args : tuple
-            Arguments passed initialize a torch.utils.data.DataLoader,
-            excluding ``dataset`` and ``collate_fn``.
+        spectra : polars.DataFrame, PathLike, or list of PathLike
+            Spectra to add to this collection. These may be a DataFrame parsed
+            with `depthcharge.spectra_to_df()`, parquet files created with
+            `depthcharge.spectra_to_parquet()`, or a peak file in the mzML,
+            mzXML, or MGF format.
         **kwargs : dict
-            Keyword arguments passed initialize a torch.utils.data.DataLoader,
-            excluding ``dataset`` and ``collate_fn``.
-
-        Returns
-        -------
-        torch.utils.data.DataLoader
-            A DataLoader for the mass spectra.
+            Keyword arguments passed to `depthcharge.spectra_to_stream()` for
+            peak files that are provided. This argument has no effect for
+            DataFrame or parquet file inputs.
         """
-        return DataLoader(self, *args, collate_fn=self.collate_fn, **kwargs)
-
-    def __len__(self) -> int:
-        """The number of spectra in the index."""
-        return self.n_spectra
+        spectra = utils.listify(spectra)
+        batch = next(_get_records(spectra, **kwargs))
+        self._dataset = lance.write_dataset(
+            _get_records(spectra, **kwargs),
+            self._path,
+            mode="append",
+            schema=batch.schema,
+        )

-    def __del__(self) -> None:
-        """Cleanup the temporary directory."""
-        if self._tmpdir is not None:
-            self._tmpdir.cleanup()
+        return self

-    def __getitem__(self, idx: int) -> MassSpectrum:
+    def __getitem__(self, idx: int) -> dict[str, Any]:
         """Access a mass spectrum.

         Parameters
         ----------
         idx : int
-            The overall index of the mass spectrum to retrieve.
+            The index of the mass spectrum to look up.

         Returns
         -------
-        tuple of numpy.ndarray
-            The m/z values, intensity values, precursor m/z, precurosr charge,
-            and the annotation (if available).
+        dict
+            A dictionary representing a row of the dataset. Each
+            key is a column and the value is the value for that
+            row. List columns are automatically converted to
+            PyTorch tensors if the nested data type is compatible.
""" - if self._handle is None: - with self: - return self.get_spectrum(idx) - - return self.get_spectrum(idx) - - def __enter__(self) -> SpectrumDataset: - """Open the index file for reading.""" - if self._handle is None: - self._handle = h5py.File( - self.path, - "r", - rdcc_nbytes=int(3e8), - rdcc_nslots=1024000, - ) - return self + return { + k: _tensorize(v[0]) + for k, v in self._dataset.take([idx]).to_pydict().items() + } - def __exit__(self, *args: str) -> None: - """Close the HDF5 file.""" - self._handle.close() - self._handle = None + def __len__(self) -> int: + """The number of spectra in the lance dataset.""" + return self._dataset.count_rows() + + def __del__(self) -> None: + """Cleanup the temporary directory.""" + if self._tmpdir is not None: + self._tmpdir.cleanup() @property - def ms_files(self) -> list[str]: - """The files currently in the index.""" - return list(self._file_map.keys()) + def peak_files(self) -> list[str]: + """The files currently in the lance dataset.""" + return ( + self._dataset.to_table(columns=["peak_file"]) + .column(0) + .unique() + .to_pylist() + ) @property def path(self) -> Path: - """The path to the underyling HDF5 index file.""" + """The path to the underyling lance dataset.""" return self._path - @property - def ms_level(self) -> int: - """The MS level of tandem mass spectra in the collection.""" - return self._ms_level - - @property - def preprocessing_fn(self) -> list[Callable]: - """The functions for preprocessing MS data.""" - return self._preprocessing_fn - - @property - def valid_charge(self) -> list[int]: - """Valid precursor charges for spectra to be included.""" - return self._valid_charge - - @property - def annotated(self) -> bool: - """Whether or not the index contains spectrum annotations.""" - return self._annotated + @classmethod + def from_lance(cls, path: PathLike, **kwargs: dict) -> SpectrumDataset: + """Load a previously created lance dataset. - @property - def overwrite(self) -> bool: - """Overwrite a previous index?.""" - return self._overwrite + Parameters + ---------- + path : PathLike + The path of the lance dataset. + **kwargs : dict + Keyword arguments passed `depthcharge.spectra_to_stream()` for + peak files that are added. This argument has no affect for + DataFrame or parquet file inputs. + """ + return cls(spectra=None, path=path, **kwargs) - @property - def n_spectra(self) -> int: - """The total number of mass spectra in the index.""" - if self._handle is None: - with self: - return self._handle.attrs["n_spectra"] - return self._handle.attrs["n_spectra"] +class AnnotatedSpectrumDataset(SpectrumDataset): + """Store and access a collection of annotated mass spectra. - @property - def n_peaks(self) -> int: - """The total number of mass peaks in the index.""" - if self._handle is None: - with self: - return self._handle.attrs["n_peaks"] + Parse and/or add mass spectra to an index in the + [lance data format](https://lancedb.github.io/lance/index.html). + This format enables fast random access to spectra for training. + This file is then served as a PyTorch Dataset, allowing spectra + to be accessed efficiently for training and inference. - return self._handle.attrs["n_peaks"] + If you wish to use an existing lance dataset, use the `from_lance()` + method. 
- @property - def offsets(self) -> dict[str, np.array]: - """The offsets denoting where each spectrum starts.""" - if self._offsets is not None: - return self._offsets + Parameters + ---------- + spectra : polars.DataFrame, PathLike, or list of PathLike + Spectra to add to this collection. These may be a DataFrame parsed + with `depthcharge.spectra_to_df()`, parquet files created with + `depthcharge.spectra_to_parquet()`, or a peak file in the mzML, + mzXML, or MGF format. Additional spectra can be added later using + the `.add_spectra()` method. + annotations : str + The column name containing the annotations. + tokenizer : PeptideTokenizer + The tokenizer used to transform the annotations into PyTorch + tensors. + path : PathLike, optional. + The name and path of the lance dataset. If the path does + not contain the `.lance` then it will be added. + If ``None``, a file will be created in a temporary directory. + **kwargs : dict + Keyword arguments passed `depthcharge.spectra_to_stream()` for + peak files that are provided. This argument has no affect for + DataFrame or parquet file inputs. - self._offsets = { - k: v["metadata"]["offset"] for k, v in self._handle.items() - } + Attributes + ---------- + peak_files : list of str + path : Path + tokenizer : PeptideTokenizer + The tokenizer for the annotations. + annotations : str + The annotation column in the dataset. + """ - return self._offsets + def __init__( + self, + spectra: pl.DataFrame | PathLike | Iterable[PathLike], + annotations: str, + tokenizer: PeptideTokenizer, + path: PathLike = None, + **kwargs: dict, + ) -> None: + """Initialize an AnnotatedSpectrumDataset.""" + self.tokenizer = tokenizer + self.annotations = annotations + super().__init__( + spectra=spectra, + path=path, + **kwargs, + ) - @staticmethod def collate_fn( - batch: Iterable[MassSpectrum], - ) -> tuple[torch.Tensor, torch.Tensor]: - """The collate function for a SpectrumDataset. + self, + batch: Iterable[dict[Any]], + ) -> dict[Any]: + """The collate function for a AnnotatedSpectrumDataset. The mass spectra must be padded so that they fit nicely as a tensor. However, the padded elements are ignored during the subsequent steps. @@ -515,197 +326,166 @@ def collate_fn( Parameters ---------- batch : tuple of tuple of torch.Tensor - A batch of data from an SpectrumDataset. + A batch of data from an AnnotatedSpectrumDataset. Returns ------- - spectra : torch.Tensor of shape (batch_size, n_peaks, 2) - The mass spectra to sequence, where ``X[:, :, 0]`` are the m/z - values and ``X[:, :, 1]`` are their associated intensities. - precursors : torch.Tensor of shape (batch_size, 2) - The precursor mass and charge state. + dict of str, tensor or list + A dictionary mapping the columns of the lance dataset + to a PyTorch tensor or list of values. """ - return _collate_fn(batch)[0:2] + batch = super().collate_fn(batch) + batch[self.annotations] = self.tokenizer.tokenize( + batch[self.annotations] + ) + return batch + @classmethod + def from_lance( + cls, + path: PathLike, + annotations: str, + tokenizer: PeptideTokenizer, + **kwargs: dict, + ) -> AnnotatedSpectrumDataset: + """Load a previously created lance dataset. -class AnnotatedSpectrumDataset(SpectrumDataset): - """Store and access a collection of annotated mass spectra. + Parameters + ---------- + path : PathLike + The path of the lance dataset. + annotations : str + The column name containing the annotations. + tokenizer : PeptideTokenizer + The tokenizer used to transform the annotations into PyTorch + tensors. 
+    @classmethod
+    def from_lance(
+        cls,
+        path: PathLike,
+        annotations: str,
+        tokenizer: PeptideTokenizer,
+        **kwargs: dict,
+    ) -> AnnotatedSpectrumDataset:
+        """Load a previously created lance dataset.
 
-class AnnotatedSpectrumDataset(SpectrumDataset):
-    """Store and access a collection of annotated mass spectra.
+        Parameters
+        ----------
+        path : PathLike
+            The path of the lance dataset.
+        annotations : str
+            The column name containing the annotations.
+        tokenizer : PeptideTokenizer
+            The tokenizer used to transform the annotations into PyTorch
+            tensors.
+        **kwargs : dict
+            Keyword arguments passed to `depthcharge.spectra_to_stream()`
+            for peak files that are added. This argument has no effect for
+            DataFrame or parquet file inputs.
+        """
+        return cls(
+            spectra=None,
+            annotations=annotations,
+            tokenizer=tokenizer,
+            path=path,
+            **kwargs,
+        )
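
The annotated counterpart ties together `custom_fields` and a tokenizer. A hedged sketch, assuming an MGF file whose headers carry a `SEQ=` field (the same pattern the tests in this patch use); the file names are placeholders:

```python
import pyarrow as pa

from depthcharge.data import AnnotatedSpectrumDataset, CustomField
from depthcharge.tokenizers import PeptideTokenizer

# Pull each spectrum's peptide sequence out of the MGF header while parsing:
seq_field = CustomField("seq", lambda x: x["params"]["seq"], pa.string())

dataset = AnnotatedSpectrumDataset(
    "annotated.mgf",
    "seq",                 # the annotation column
    PeptideTokenizer(),
    path="annotated.lance",
    custom_fields=seq_field,
)

# Batches are dictionaries, and the annotation column arrives tokenized:
batch = next(iter(dataset.loader(batch_size=2, num_workers=0)))
print(batch["seq"])
```
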
""" - _annotated = True - def __init__( self, - tokenizer: PeptideTokenizer, - ms_data_files: PathLike | Iterable[PathLike] = None, - ms_level: int = 2, - preprocessing_fn: Callable | Iterable[Callable] | None = None, - valid_charge: Iterable[int] | None = None, - index_path: PathLike | None = None, - overwrite: bool = False, + spectra: pl.DataFrame | PathLike | Iterable[PathLike], + batch_size: int, + **kwargs: dict, ) -> None: - """Initialize an AnnotatedSpectrumIndex.""" - self.tokenizer = tokenizer - super().__init__( - ms_data_files=ms_data_files, - ms_level=ms_level, - preprocessing_fn=preprocessing_fn, - valid_charge=valid_charge, - index_path=index_path, - overwrite=overwrite, + """Initialize a StreamingSpectrumDataset.""" + self.batch_size = batch_size + self._spectra = utils.listify(spectra) + self._kwargs = kwargs + + def __iter__(self) -> dict[str, Any]: + """Yield a batch mass spectra.""" + records = _get_records( + self._spectra, + batch_size=self.batch_size, + **self._kwargs, ) + for batch in records: + for row in batch.to_pylist(): + yield {k: _tensorize(v) for k, v in row.items()} - def _get_parser(self, ms_data_file: str) -> MgfParser: - """Get the parser for the MS data file.""" - if ms_data_file.suffix.lower() == ".mgf": - return MgfParser( - ms_data_file, - ms_level=self.ms_level, - annotations=True, - preprocessing_fn=self.preprocessing_fn, - ) + def loader(self, **kwargs: dict) -> DataLoader: + """Create a suitable PyTorch DataLoader.""" + if kwargs.get("num_workers", 0) > 1: + warnings.warn("'num_workers' > 1 may have unexpected behavior.") - raise ValueError("Only MGF files are currently supported.") + return super().loader(**kwargs) - def get_spectrum(self, idx: int) -> MassSpectrum: - """Access a mass spectrum. - Parameters - ---------- - idx : int - The index of the mass spectrum in the AnnotatedSpectrumIndex. +def _get_records( + data: list[pl.DataFrame | PathLike], **kwargs: dict +) -> Generator[pa.RecordBatch]: + """Yields RecordBatches for data. - Returns - ------- - MassSpectrum - The mass spectrum, labeled with its annotation. - """ - spectrum = super().get_spectrum(idx) - group_index, row_index = self._locs[idx] - grp = self._handle[str(group_index)] - annotations = grp["annotations"] - spectrum.label = annotations[row_index].decode() - return spectrum - - @property - def annotations(self) -> np.ndarray[str]: - """Retrieve all of the annotations in the index.""" - annotations = [] - for grp in self._handle.values(): + Parameters + ---------- + data : list of polars.DataFrame or PathLike + The data to add. + **kwargs : dict + Keyword arguments for the parser. + """ + for spectra in data: + try: + spectra = spectra.lazy().collect().to_arrow().to_batches() + except AttributeError: try: - annotations.append(grp["annotations"]) - except KeyError: - pass - - return np.concatenate(annotations) - - def collate_fn( - self, - batch: Iterable[MassSpectrum], - ) -> tuple[torch.Tensor, torch.Tensor, np.ndarray[str]]: - """The collate function for a AnnotatedSpectrumDataset. - - The mass spectra must be padded so that they fit nicely as a tensor. - However, the padded elements are ignored during the subsequent steps. + spectra = pq.ParquetFile(spectra).iter_batches() + except (pa.ArrowInvalid, TypeError): + spectra = arrow.spectra_to_stream(spectra, **kwargs) - Parameters - ---------- - batch : tuple of tuple of torch.Tensor - A batch of data from an AnnotatedSpectrumDataset. 
-    def get_spectrum(self, idx: int) -> MassSpectrum:
-        """Access a mass spectrum.
-
-        Parameters
-        ----------
-        idx : int
-            The index of the mass spectrum in the AnnotatedSpectrumIndex.
+def _get_records(
+    data: list[pl.DataFrame | PathLike], **kwargs: dict
+) -> Generator[pa.RecordBatch]:
+    """Yield RecordBatches for the data.
 
-        Returns
-        -------
-        MassSpectrum
-            The mass spectrum, labeled with its annotation.
-        """
-        spectrum = super().get_spectrum(idx)
-        group_index, row_index = self._locs[idx]
-        grp = self._handle[str(group_index)]
-        annotations = grp["annotations"]
-        spectrum.label = annotations[row_index].decode()
-        return spectrum
-
-    @property
-    def annotations(self) -> np.ndarray[str]:
-        """Retrieve all of the annotations in the index."""
-        annotations = []
-        for grp in self._handle.values():
+    Parameters
+    ----------
+    data : list of polars.DataFrame or PathLike
+        The data to add.
+    **kwargs : dict
+        Keyword arguments for the parser.
+    """
+    for spectra in data:
+        try:
+            spectra = spectra.lazy().collect().to_arrow().to_batches()
+        except AttributeError:
             try:
-                annotations.append(grp["annotations"])
-            except KeyError:
-                pass
-
-        return np.concatenate(annotations)
-
-    def collate_fn(
-        self,
-        batch: Iterable[MassSpectrum],
-    ) -> tuple[torch.Tensor, torch.Tensor, np.ndarray[str]]:
-        """The collate function for a AnnotatedSpectrumDataset.
-
-        The mass spectra must be padded so that they fit nicely as a tensor.
-        However, the padded elements are ignored during the subsequent steps.
+                spectra = pq.ParquetFile(spectra).iter_batches()
+            except (pa.ArrowInvalid, TypeError):
+                spectra = arrow.spectra_to_stream(spectra, **kwargs)
 
-        Parameters
-        ----------
-        batch : tuple of tuple of torch.Tensor
-            A batch of data from an AnnotatedSpectrumDataset.
+        for batch in spectra:
+            yield batch
 
-        Returns
-        -------
-        spectra : torch.Tensor of shape (batch_size, n_peaks, 2)
-            The mass spectra to sequence, where ``X[:, :, 0]`` are the m/z
-            values and ``X[:, :, 1]`` are their associated intensities.
-        precursors : torch.Tensor of shape (batch_size, 2)
-            The precursor mass and charge state.
-        annotations : np.ndarray[str]
-            The spectrum annotations.
-        """
-        spectra, precursors, annotations = _collate_fn(batch)
-        tokens = self.tokenizer.tokenize(annotations)
-        return spectra, precursors, tokens
-
-
-def _collate_fn(
-    batch: Iterable[MassSpectrum],
-) -> tuple[torch.Tensor, torch.Tensor, list[str | None]]:
-    """The collate function for a SpectrumDataset.
 
-    The mass spectra must be padded so that they fit nicely as a tensor.
-    However, the padded elements are ignored during the subsequent steps.
+def _tensorize(obj: Any) -> Any:  # noqa: ANN401
+    """Turn lists into tensors.
 
     Parameters
     ----------
-    batch : tuple of tuple of torch.Tensor
-        A batch of data from an SpectrumDataset.
+    obj : any object
+        If a list, attempt to make a tensor. If not, or if the conversion
+        fails, return the object unchanged.
 
     Returns
     -------
-    spectra : torch.Tensor of shape (batch_size, n_peaks, 2)
-        The mass spectra to sequence, where ``X[:, :, 0]`` are the m/z
-        values and ``X[:, :, 1]`` are their associated intensities.
-    precursors : torch.Tensor of shape (batch_size, 2)
-        The precursor mass and charge state.
-    annotations : np.ndarray of shape (batch_size,)
-        The annotations, if any exist.
+    Any
+        Whatever type the object is, unless it has been transformed into
+        a PyTorch tensor.
     """
-    spectra = []
-    masses = []
-    charges = []
-    annotations = []
-    for spec in batch:
-        spectra.append(spec.to_tensor())
-        masses.append(spec.precursor_mass)
-        charges.append(spec.precursor_charge)
-        annotations.append(spec.label)
-
-    precursors = torch.vstack([torch.tensor(masses), torch.tensor(charges)])
-    spectra = torch.nn.utils.rnn.pad_sequence(
-        spectra,
-        batch_first=True,
-    )
-    return spectra, precursors.T.float(), annotations
-
-
-def _hash_obj(obj: Any) -> str:  # noqa: ANN401
-    """SHA1 hash for a picklable object."""
-    out = hashlib.sha1()
-    out.update(dill.dumps(obj))
-    return out.hexdigest()
+    if not isinstance(obj, list):
+        return obj
+
+    try:
+        return torch.tensor(obj)
+    except ValueError:
+        pass
+
+    return obj
diff --git a/depthcharge/primitives.py b/depthcharge/primitives.py
index 2c3ff26..80430b6 100644
--- a/depthcharge/primitives.py
+++ b/depthcharge/primitives.py
@@ -358,6 +358,7 @@ def __init__(
         scan_id: str,
         mz: ArrayLike,
         intensity: ArrayLike,
+        ms_level: int = None,
         retention_time: float | None = None,
         ion_mobility: float | None = None,
         precursor_mz: float | None = None,
@@ -367,6 +368,7 @@ def __init__(
         """Initialize a MassSpectrum."""
         self.filename = filename
         self.scan_id = scan_id
+        self.ms_level = ms_level
         self.label = label
 
         # Not currently supported by spectrum_utils:
diff --git a/depthcharge/testing.py b/depthcharge/testing.py
new file mode 100644
index 0000000..353b7a0
--- /dev/null
+++ b/depthcharge/testing.py
@@ -0,0 +1,64 @@
+"""Helper functions for testing."""
+from typing import Any
+
+import torch
+
+
+def assert_dicts_equal(
+    dict1: dict[Any], dict2: dict[Any], **kwargs: dict
+) -> None:
+    """Assert two dictionaries are equal, while considering tensors.
+
+    Parameters
+    ----------
+    dict1 : dict
+        The first dictionary to compare.
+    dict2 : dict
+        The second dictionary to compare.
+ **kwargs : dict + Keyword arguments passed to `torch.testing.assert_close` + + Raises + ------ + AssertionError + Indicates that the two dictionaries are not equal. + """ + bad_keys = [] + assert set(dict1.keys()) == set(dict2.keys()) + + for key, val1 in dict1.items(): + try: + val2 = dict2[key] + except KeyError: + bad_keys.append(key) + continue + + try: + assert type(val1) is type(val2) + except AssertionError: + bad_keys.append(key) + continue + + try: + assert val1 == val2 + continue + except AssertionError: + bad_keys.append(key) + continue + except RuntimeError: + pass + + try: + # Works on numpy arrays too. + torch.testing.assert_close(val1, val2, **kwargs) + continue + except AssertionError: + bad_keys.append(key) + continue + + if not bad_keys: + return + + raise AssertionError( + f"Dictionaries did not match at the following keys: {bad_keys}" + ) diff --git a/depthcharge/tokenizers/peptides.py b/depthcharge/tokenizers/peptides.py index fe251c8..bddad42 100644 --- a/depthcharge/tokenizers/peptides.py +++ b/depthcharge/tokenizers/peptides.py @@ -94,6 +94,21 @@ def __init__( super().__init__(list(self.residues.keys())) + def __getstate__(self) -> dict: + """How to pickle the object.""" + self.residues = dict(self.residues) + return self.__dict__ + + def __setstate__(self, state: dict) -> None: + """How to unpickle the object.""" + self.__dict__ = state + residues = self.residues + self.residues = nb.typed.Dict.empty( + nb.types.unicode_type, + nb.types.float64, + ) + self.residues.update(residues) + def split(self, sequence: str) -> list[str]: """Split a ProForma peptide sequence. diff --git a/depthcharge/transformers/spectra.py b/depthcharge/transformers/spectra.py index d2ea493..3f5b826 100644 --- a/depthcharge/transformers/spectra.py +++ b/depthcharge/transformers/spectra.py @@ -9,6 +9,12 @@ class SpectrumTransformerEncoder(torch.nn.Module): """A Transformer encoder for input mass spectra. + Use this PyTorch module to embed mass spectra. By default, nothing + other than the m/z and intensity arrays for each mass spectrum are + considered. However, arbitrary information can be integrated into the + spectrum representation by subclassing this class and overwriting the + `precursor_hook()` method. + Parameters ---------- d_model : int, optional @@ -24,7 +30,22 @@ class SpectrumTransformerEncoder(torch.nn.Module): dropout : float, optional The dropout probability for all layers. peak_encoder : PeakEncoder or bool, optional - Sinusoidal encodings m/z and intensityvalues of each peak. + The function to encode the (m/z, intensity) tuples of each mass + spectrum. `True` uses the default sinusoidal encoding and `False` + instead performs a 1 to `d_model` learned linear projection. + + Attributes + ---------- + d_model : int + nhead : int + dim_feedforward : int + n_layers : int + dropout : float + peak_encoder : torch.nn.Module or Callable + The function to encode the (m/z, intensity) tuples of each mass + spectrum. + transformer_encoder : torch.nn.TransformerEncoder + The Transformer encoder layers. 
""" def __init__( @@ -38,8 +59,11 @@ def __init__( ) -> None: """Initialize a SpectrumEncoder.""" super().__init__() - - self.latent_spectrum = torch.nn.Parameter(torch.randn(1, 1, d_model)) + self._d_model = d_model + self._nhead = nhead + self._dim_feedforward = dim_feedforward + self._n_layers = n_layers + self._dropout = dropout if callable(peak_encoder): self.peak_encoder = peak_encoder @@ -62,20 +86,48 @@ def __init__( num_layers=n_layers, ) + @property + def d_model(self) -> int: + """The latent dimensionality of the model.""" + return self._d_model + + @property + def nhead(self) -> int: + """The number of attention heads.""" + return self._nhead + + @property + def dim_feedforward(self) -> int: + """The dimensionality of the Transformer feedforward layers.""" + return self._dim_feedforward + + @property + def n_layers(self) -> int: + """The number of Transformer layers.""" + return self._n_layers + + @property + def dropout(self) -> float: + """The dropout for the transformer layers.""" + return self._dropout + def forward( self, - spectra: torch.Tensor, + mz_array: torch.Tensor, + intensity_array: torch.Tensor, + **kwargs: dict, ) -> tuple[torch.Tensor, torch.Tensor]: - """Embed a mass spectrum. + """Embed a batch of mass spectra. Parameters ---------- - spectra : torch.Tensor of shape (n_spectra, n_peaks, 2) - The spectra to embed. Axis 0 represents a mass spectrum, axis 1 - contains the peaks in the mass spectrum, and axis 2 is essentially - a 2-tuple specifying the m/z-intensity pair for each peak. These - should be zero-padded, such that all of the spectra in the batch - are the same length. + mz_array : torch.Tensor of shape (n_spectra, n_peaks) + The zero-padded m/z dimension for a batch of mass spectra. + intensity_array : torch.Tensor of shape (n_spectra, n_peaks) + The zero-padded intensity dimension for a batch of mass spctra. + **kwargs : dict + Additional fields provided by the data loader. These may be + used by overwriting the `precursor_hook()` method in a subclass. Returns ------- @@ -85,20 +137,60 @@ def forward( mem_mask : torch.Tensor The memory mask specifying which elements were padding in X. """ + spectra = torch.stack([mz_array, intensity_array], dim=2) + n_batch = spectra.shape[0] zeros = ~spectra.sum(dim=2).bool() - mask = [ - torch.tensor([[False]] * spectra.shape[0]).type_as(zeros), - zeros, - ] - mask = torch.cat(mask, dim=1) + mask = torch.cat( + [torch.tensor([[False]] * n_batch).type_as(zeros), zeros], dim=1 + ) peaks = self.peak_encoder(spectra) - # Add the spectrum representation to each input: - latent_spectra = self.latent_spectrum.expand(peaks.shape[0], -1, -1) + # Add the precursor information: + latent_spectra = self.precursor_hook( + mz_array=mz_array, + intensity_array=intensity_array, + **kwargs, + ) - peaks = torch.cat([latent_spectra, peaks], dim=1) + peaks = torch.cat([latent_spectra[:, None, :], peaks], dim=1) return self.transformer_encoder(peaks, src_key_padding_mask=mask), mask + def precursor_hook( + self, + mz_array: torch.Tensor, + intensity_array: torch.Tensor, + **kwargs: dict, + ) -> torch.Tensor: + """Define how additional information in the batch may be used. + + Overwrite this method to define custom functionality dependent on + information in the batch. Examples would be to incorporate any + combination of the mass, charge, retention time, or + ion mobility of a precursor ion. 
+    def precursor_hook(
+        self,
+        mz_array: torch.Tensor,
+        intensity_array: torch.Tensor,
+        **kwargs: dict,
+    ) -> torch.Tensor:
+        """Define how additional information in the batch may be used.
+
+        Overwrite this method to define custom functionality dependent on
+        information in the batch. Examples would be to incorporate any
+        combination of the mass, charge, retention time, or
+        ion mobility of a precursor ion.
+
+        The representation returned by this method is prepended to the
+        peak representations that are fed into the Transformer encoder and
+        ultimately contributes to the spectrum representation that is the
+        first element of the sequence in the model output.
+
+        By default, this method returns a tensor of zeros.
+
+        Parameters
+        ----------
+        mz_array : torch.Tensor of shape (n_spectra, n_peaks)
+            The zero-padded m/z dimension for a batch of mass spectra.
+        intensity_array : torch.Tensor of shape (n_spectra, n_peaks)
+            The zero-padded intensity dimension for a batch of mass spectra.
+        **kwargs : dict
+            The additional data passed with the batch.
+
+        Returns
+        -------
+        torch.Tensor of shape (batch_size, d_model)
+            The precursor representations.
+        """
+        return torch.zeros((mz_array.shape[0], self.d_model)).type_as(mz_array)
+
     @property
     def device(self) -> torch.device:
         """The current device for the model."""
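
To make the hook concrete, here is a hedged sketch of a subclass that injects the precursor charge into the latent spectrum token. The `charge` batch key is an assumption about what the DataLoader supplies (the unit tests later in this patch use the same pattern):

```python
import torch

from depthcharge.transformers import SpectrumTransformerEncoder


class ChargeAwareEncoder(SpectrumTransformerEncoder):
    """Condition the spectrum embedding on the precursor charge."""

    def precursor_hook(self, mz_array, intensity_array, **kwargs):
        # Broadcast each spectrum's scalar charge across d_model;
        # "charge" is an assumed key in the batch dictionary:
        return kwargs["charge"][:, None].expand(-1, self.d_model)


model = ChargeAwareEncoder(d_model=8, nhead=1, dim_feedforward=12)
mz = torch.rand(2, 5)          # zero-padded m/z values
intensity = torch.rand(2, 5)   # zero-padded intensities
charge = torch.tensor([2.0, 3.0])

# Extra keyword arguments flow through forward() to the hook:
embedded, mask = model(mz, intensity, charge=charge)
print(embedded.shape)          # (2, 6, 8): precursor token + 5 peaks
```

Because the default hook returns zeros, a plain `SpectrumTransformerEncoder` simply ignores any extra batch fields.
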
diff --git a/depthcharge/utils.py b/depthcharge/utils.py
index a0f4511..3b3255b 100644
--- a/depthcharge/utils.py
+++ b/depthcharge/utils.py
@@ -1,63 +1,18 @@
 """Common utility functions."""
 from typing import Any
 
+import polars as pl
+
 
 def listify(obj: Any) -> list[Any]:  # noqa: ANN401
     """Turn an object into a list, but don't split strings."""
     try:
-        assert not isinstance(obj, str)
+        invalid = [str, pl.DataFrame, pl.LazyFrame]
+        if any(isinstance(obj, c) for c in invalid):
+            raise TypeError
+
         iter(obj)
     except (AssertionError, TypeError):
         obj = [obj]
 
     return list(obj)
-
-
-def check_int(integer: int, name: str) -> int:
-    """Verify that an object is an integer, or coercible to one.
-
-    Parameters
-    ----------
-    integer : int
-        The integer to check.
-    name : str
-        The name to print in the error message if it fails.
-
-    Returns
-    -------
-    int
-        The coerced integer.
-    """
-    if isinstance(integer, int):
-        return integer
-
-    # Else if it is a float:
-    coerced = int(integer)
-    if coerced != integer:
-        raise ValueError(f"'{name}' must be an integer.")
-
-    return coerced
-
-
-def check_positive_int(integer: int, name: str) -> int:
-    """Verify that an object is an integer and positive.
-
-    Parameters
-    ----------
-    integer : int
-        The integer to check.
-    name : str
-        The name to print in the error message if it fails.
-
-    Returns
-    -------
-    int
-        The coerced integer.
-    """
-    try:
-        integer = check_int(integer, name)
-        assert integer > 0
-    except (ValueError, AssertionError):
-        raise ValueError(f"'{name}' must be a positive integer.")
-
-    return integer
diff --git a/depthcharge/version.py b/depthcharge/version.py
index f178557..f8771ad 100644
--- a/depthcharge/version.py
+++ b/depthcharge/version.py
@@ -1,5 +1,5 @@
 """Get the version information."""
-from importlib.metadata import version, PackageNotFoundError
+from importlib.metadata import PackageNotFoundError, version
 
 
 def _get_version() -> str:
diff --git a/pyproject.toml b/pyproject.toml
index 14d10d8..fdf615b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,12 +15,14 @@ classifiers = [
 ]
 requires-python = ">=3.10"
 dependencies = [
-    "torch>=1.11.0",
+    "torch>=2.0.0",
+    "polars>=0.19.0",
+    "pyarrow>=12.0.1",
+    "pylance>=0.7.5",
     "pyteomics>=4.4.2",
     "numpy>=1.18.1",
     "numba>=0.48.0",
     "lxml>=4.9.1",
-    "h5py>=3.7.0",
     "einops>=0.4.1",
     "tqdm>=4.65.0",
     "lark>=1.1.4",
diff --git a/tests/conftest.py b/tests/conftest.py
index b3c6755..aa19076 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -41,7 +41,6 @@ def mgf_small(tmp_path):
     return _create_mgf(peptides, mgf_file)
 
 
-# Utility functions -----------------------------------------------------------
def _create_mgf_entry(peptide, charge=2):
     """Create a MassIVE-KB style MGF entry for a single PSM.
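
Because the tests that follow lean heavily on `custom_fields`, a brief sketch of the idea may help: a `CustomField` names an output column, provides an accessor into the parser's raw spectrum entry, and declares the column's Arrow type. The MGF `title` header used here is illustrative:

```python
import pyarrow as pa

from depthcharge.data import CustomField, spectra_to_df

# Keep each MGF spectrum's TITLE header as an extra column. The accessor
# receives the raw pyteomics entry for the spectrum:
title_field = CustomField(
    "title",                               # name of the new column
    lambda spec: spec["params"]["title"],  # where to find the value
    pa.string(),                           # Arrow type of the column
)

df = spectra_to_df("example.mgf", custom_fields=title_field)
print(df.columns)
```
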
diff --git a/tests/unit_tests/test_data/test_arrow.py b/tests/unit_tests/test_data/test_arrow.py
new file mode 100644
index 0000000..540d5f8
--- /dev/null
+++ b/tests/unit_tests/test_data/test_arrow.py
@@ -0,0 +1,121 @@
+"""Test the arrow functionality."""
+import polars as pl
+import pyarrow as pa
+import pytest
+
+from depthcharge.data.arrow import (
+    spectra_to_df,
+    spectra_to_parquet,
+    spectra_to_stream,
+)
+from depthcharge.data.fields import CustomField
+from depthcharge.data.preprocessing import scale_to_unit_norm
+
+METADATA_DF1 = pl.DataFrame({"scan_id": [501, 507, 10], "blah": [True] * 3})
+METADATA_DF2 = pl.DataFrame(
+    {
+        "scan_id": [501, 507, 10],
+        "blah": [True] * 3,
+        "peak_file": ["TMT10-Trail-8.mzML"] * 3,
+    },
+)
+PARAM_NAMES = [
+    "ms_level",
+    "preprocessing_fn",
+    "valid_charge",
+    "custom_fields",
+    "metadata_df",
+    "progress",
+    "shape",
+]
+
+custom_field = CustomField("index", lambda x: x["index"], pa.int64())
+PARAM_VALS = [
+    (2, None, None, None, None, True, (4, 7)),
+    (1, None, None, None, None, True, (4, 7)),
+    (3, None, None, None, None, True, (3, 7)),
+    (2, None, [3], None, None, True, (3, 7)),
+    (None, None, None, None, None, True, (11, 7)),
+    (2, scale_to_unit_norm, None, custom_field, None, True, (4, 8)),
+    (2, None, None, None, METADATA_DF1, True, (4, 8)),
+    (2, None, None, None, METADATA_DF2, False, (4, 8)),
+]
+
+
+@pytest.mark.parametrize(PARAM_NAMES, PARAM_VALS)
+def test_to_df(
+    real_mzml,
+    ms_level,
+    preprocessing_fn,
+    valid_charge,
+    custom_fields,
+    metadata_df,
+    progress,
+    shape,
+):
+    """Test parsing to a DataFrame."""
+    parsed = spectra_to_df(
+        real_mzml,
+        ms_level=ms_level,
+        preprocessing_fn=preprocessing_fn,
+        valid_charge=valid_charge,
+        custom_fields=custom_fields,
+        metadata_df=metadata_df,
+        progress=progress,
+    )
+
+    assert parsed.shape == shape
+
+
+@pytest.mark.parametrize(PARAM_NAMES, PARAM_VALS)
+def test_to_parquet(
+    tmp_path,
+    real_mzml,
+    ms_level,
+    preprocessing_fn,
+    valid_charge,
+    custom_fields,
+    metadata_df,
+    progress,
+    shape,
+):
+    """Test parsing to parquet."""
+    out = spectra_to_parquet(
+        real_mzml,
+        parquet_file=tmp_path / "test",
+        ms_level=ms_level,
+        preprocessing_fn=preprocessing_fn,
+        valid_charge=valid_charge,
+        custom_fields=custom_fields,
+        metadata_df=metadata_df,
+        progress=progress,
+    )
+
+    parsed = pl.read_parquet(out)
+    assert parsed.shape == shape
+
+
+@pytest.mark.parametrize(PARAM_NAMES, PARAM_VALS)
+def test_to_stream(
+    real_mzml,
+    ms_level,
+    preprocessing_fn,
+    valid_charge,
+    custom_fields,
+    metadata_df,
+    progress,
+    shape,
+):
+    """Test parsing to a stream."""
+    out = spectra_to_stream(
+        real_mzml,
+        batch_size=1,
+        ms_level=ms_level,
+        preprocessing_fn=preprocessing_fn,
+        valid_charge=valid_charge,
+        custom_fields=custom_fields,
+        metadata_df=metadata_df,
+        progress=progress,
+    )
+    parsed = pl.from_arrow(list(out))
+    assert parsed.shape == shape
diff --git a/tests/unit_tests/test_data/test_datasets.py b/tests/unit_tests/test_data/test_datasets.py
index 33984ac..a2a7be8 100644
--- a/tests/unit_tests/test_data/test_datasets.py
+++ b/tests/unit_tests/test_data/test_datasets.py
@@ -1,16 +1,20 @@
 """Test the datasets."""
-import functools
+import pickle
 import shutil
 
+import pyarrow as pa
 import pytest
 import torch
 
 from depthcharge.data import (
     AnnotatedSpectrumDataset,
+    CustomField,
     PeptideDataset,
     SpectrumDataset,
+    StreamingSpectrumDataset,
+    arrow,
 )
-from depthcharge.primitives import MassSpectrum
+from depthcharge.testing import assert_dicts_equal
 from depthcharge.tokenizers import PeptideTokenizer
 
 
@@ -20,20 +24,13 @@ def tokenizer():
     return PeptideTokenizer()
 
 
-def test_spectrum_id(mgf_small, tmp_path):
-    """Test the mgf index."""
-    mgf_small2 = tmp_path / "mgf_small2.mgf"
-    shutil.copy(mgf_small, mgf_small2)
-
-    dataset = SpectrumDataset([mgf_small, mgf_small2])
-    with dataset:
-        assert dataset.get_spectrum_id(0) == (str(mgf_small), "index=0")
-        assert dataset.get_spectrum_id(3) == (str(mgf_small2), "index=1")
+def test_addition(mgf_small, tmp_path):
+    """Test adding a file."""
+    dataset = SpectrumDataset(mgf_small, path=tmp_path / "test")
+    assert len(dataset) == 2
 
-    dataset = AnnotatedSpectrumDataset(tokenizer, [mgf_small, mgf_small2])
-    with dataset:
-        assert dataset.get_spectrum_id(0) == (str(mgf_small), "index=0")
-        assert dataset.get_spectrum_id(3) == (str(mgf_small2), "index=1")
+    dataset = dataset.add_spectra(mgf_small)
+    assert len(dataset) == 4
 
 
 def test_indexing(mgf_small, tmp_path):
@@ -41,237 +38,105 @@ def test_indexing(mgf_small, tmp_path):
     mgf_small2 = tmp_path / "mgf_small2.mgf"
     shutil.copy(mgf_small, mgf_small2)
 
-    dataset = SpectrumDataset([mgf_small, mgf_small2])
-    spec = dataset[0]
-    assert isinstance(spec, MassSpectrum)
-    assert spec.label is None
-
-    dataset = AnnotatedSpectrumDataset(tokenizer, [mgf_small, mgf_small2])
-    spec = dataset[0]
-    assert isinstance(spec, MassSpectrum)
-    assert spec.label == "LESLIEK"
-    assert spec.to_tensor().shape[1] == 2
-    assert dataset[3].label == "EDITHR"
-
-
-def test_valid_charge(mgf_medium, tmp_path):
-    """Test that the valid_charge argument works."""
-    mkindex = functools.partial(
-        SpectrumDataset,
-        index_path=tmp_path / "index.hdf5",
-        ms_data_files=mgf_medium,
-        overwrite=True,
+    dataset = SpectrumDataset(
+        [mgf_small, mgf_small2],
+        path=tmp_path / "test",
     )
 
-    index = mkindex()
-    assert index.n_spectra == 100
-
-    index = mkindex(valid_charge=[0, 2, 3, 4])
-    assert index.n_spectra == 100
-
-    index = mkindex(valid_charge=[2, 3, 4])
-    assert index.n_spectra == 99
-
-    index = mkindex(valid_charge=[2, 3])
-    assert index.n_spectra == 98
-
-
-def test_mgf(mgf_small, tmp_path):
-    """Test the mgf index."""
-    mgf_small2 = tmp_path / "mgf_small2.mgf"
-    
shutil.copy(mgf_small, mgf_small2) - - index = SpectrumDataset([mgf_small, mgf_small2]) - assert index.ms_files == [str(mgf_small), str(mgf_small2)] - assert index.ms_level == 2 - assert not index.annotated - assert not index.overwrite - assert index.n_spectra == 4 - assert index.n_peaks == 66 - - with index: - assert index.get_spectrum_id(0) == (str(mgf_small), "index=0") - assert index.get_spectrum_id(3) == (str(mgf_small2), "index=1") - - index = AnnotatedSpectrumDataset(tokenizer, [mgf_small, mgf_small2]) - assert index.ms_files == [str(mgf_small), str(mgf_small2)] - assert index.ms_level == 2 - assert index.annotated - assert not index.overwrite - assert index.n_spectra == 4 - assert index.n_peaks == 66 - - with index: - assert index.get_spectrum_id(0) == (str(mgf_small), "index=0") - assert index.get_spectrum_id(3) == (str(mgf_small2), "index=1") - - -def test_mzml_index(real_mzml, tmp_path): - """Test an mzML index.""" - real_mzml2 = tmp_path / "real_mzml2.mzML" - shutil.copy(real_mzml, real_mzml2) - - index = SpectrumDataset([real_mzml, real_mzml2], 2, []) - assert index.ms_files == [str(real_mzml), str(real_mzml2)] - assert index.ms_level == 2 - assert not index.annotated - assert not index.overwrite - assert index.n_spectra == 8 - assert index.n_peaks == 726 - - with index: - assert index.get_spectrum_id(0) == (str(real_mzml), "scan=501") - assert index.get_spectrum_id(7) == (str(real_mzml2), "scan=510") - - # MS3 - index = SpectrumDataset([real_mzml, real_mzml2], 3, []) - assert index.ms_files == [str(real_mzml), str(real_mzml2)] - assert index.ms_level == 3 - assert index.n_spectra == 6 - assert index.n_peaks == 194 - - with index: - assert index.get_spectrum_id(0) == (str(real_mzml), "scan=502") - assert index.get_spectrum_id(5) == (str(real_mzml2), "scan=508") - - # MS1 - index = SpectrumDataset([real_mzml, real_mzml2], 1, []) - assert index.ms_files == [str(real_mzml), str(real_mzml2)] - assert index.ms_level == 1 - assert index.n_spectra == 8 - assert index.n_peaks == 4316 - - with index: - assert index.get_spectrum_id(0) == (str(real_mzml), "scan=500") - assert index.get_spectrum_id(5) == (str(real_mzml2), "scan=503") - - -def test_mzxml_index(real_mzxml, tmp_path): - """Test an mzXML index.""" - real_mzxml2 = tmp_path / "real_mzxml2.mzXML" - shutil.copy(real_mzxml, real_mzxml2) - - index = SpectrumDataset([real_mzxml, real_mzxml2], 2, []) - assert index.ms_files == [str(real_mzxml), str(real_mzxml2)] - assert index.ms_level == 2 - assert not index.annotated - assert not index.overwrite - assert index.n_spectra == 8 - assert index.n_peaks == 726 - - with index: - assert index.get_spectrum_id(0) == (str(real_mzxml), "scan=501") - assert index.get_spectrum_id(7) == (str(real_mzxml2), "scan=510") - - # MS3 - index = SpectrumDataset([real_mzxml, real_mzxml2], 3, []) - assert index.ms_files == [str(real_mzxml), str(real_mzxml2)] - assert index.ms_level == 3 - assert index.n_spectra == 6 - assert index.n_peaks == 194 - - with index: - assert index.get_spectrum_id(0) == (str(real_mzxml), "scan=502") - assert index.get_spectrum_id(5) == (str(real_mzxml2), "scan=508") - - # MS1 - index = SpectrumDataset([real_mzxml, real_mzxml2], 1, []) - assert index.ms_files == [str(real_mzxml), str(real_mzxml2)] - assert index.ms_level == 1 - assert index.n_spectra == 8 - assert index.n_peaks == 4316 - - with index: - assert index.get_spectrum_id(0) == (str(real_mzxml), "scan=500") - assert index.get_spectrum_id(5) == (str(real_mzxml2), "scan=503") - - -def test_spectrum_index_reuse(mgf_small, 
tmp_path):
-    """Reuse a previously created (annotated) spectrum index."""
-    plain_index = tmp_path / "plain.hdf5"
-    ann_index = tmp_path / "ann.hdf5"
-
-    index = SpectrumDataset(mgf_small, index_path=plain_index)
-    index2 = SpectrumDataset(index_path=plain_index)
-    assert index.ms_level == index2.ms_level
-    assert index.annotated == index2.annotated
-    assert not index2.annotated
-    assert index.n_peaks == index2.n_peaks
-    assert index.n_spectra == index2.n_spectra
-
-    index3 = AnnotatedSpectrumDataset(
-        tokenizer, mgf_small, index_path=ann_index
-    )
-    index4 = AnnotatedSpectrumDataset(tokenizer, index_path=ann_index)
-    assert index3.ms_level == index4.ms_level
-    assert index3.annotated == index4.annotated
-    assert index4.annotated
-    assert index3.n_peaks == index4.n_peaks
-    assert index3.n_spectra == index4.n_spectra
-
-    # An annotated spectrum dataset may be loaded as a spectrum dataset,
-    # but not vice versa:
-    SpectrumDataset(index_path=ann_index)
-    with pytest.raises(ValueError):
-        AnnotatedSpectrumDataset(tokenizer, index_path=plain_index)
-
-    # Verify we invalidate correctly:
-    with pytest.raises(ValueError):
-        SpectrumDataset(index_path=plain_index, ms_level=1)
-
-    with pytest.raises(ValueError):
-        SpectrumDataset(index_path=plain_index, preprocessing_fn=[])
-
-    with pytest.raises(ValueError):
-        AnnotatedSpectrumDataset(tokenizer, index_path=ann_index, ms_level=3)
-
-    with pytest.raises(ValueError):
-        AnnotatedSpectrumDataset(
-            tokenizer, index_path=plain_index, preprocessing_fn=[]
-        )
-
+    assert dataset.path == tmp_path / "test.lance"
 
-def test_spectrum_indexing_bug(tmp_path, mgf_small):
-    """Test that we've fixed reindexing upon reload."""
-    dset1 = AnnotatedSpectrumDataset(
-        PeptideTokenizer(), mgf_small, 2, index_path=tmp_path / "test.hdf5"
+    spec = dataset[0]
+    assert len(spec) == 7
+    assert spec["peak_file"] == "small.mgf"
+    assert spec["scan_id"] == 0
+    assert spec["ms_level"] == 2
+    assert abs(spec["precursor_mz"] - 416.2448) < 0.001
+
+    dataset = AnnotatedSpectrumDataset(
+        [mgf_small, mgf_small2],
+        "seq",
+        tokenizer,
+        tmp_path / "test.lance",
+        preprocessing_fn=[],
+        custom_fields=CustomField(
+            "seq", lambda x: x["params"]["seq"], pa.string()
+        ),
     )
-
-    dset2 = AnnotatedSpectrumDataset(
-        PeptideTokenizer(), mgf_small, 2, index_path=tmp_path / "test.hdf5"
+    spec = dataset[0]
+    assert len(spec) == 8
+    assert spec["seq"] == "LESLIEK"
+    assert spec["mz_array"].shape == (14,)
+
+    spec2 = dataset[3]
+    assert spec2["seq"] == "EDITHR"
+    assert spec2["mz_array"].shape == (24,)
+
+
+def test_load(tmp_path, mgf_small):
+    """Test saving and loading a dataset."""
+    db_path = tmp_path / "test.lance"
+
+    AnnotatedSpectrumDataset(
+        mgf_small,
+        "seq",
+        tokenizer,
+        db_path,
+        preprocessing_fn=[],
+        custom_fields=CustomField(
+            "seq", lambda x: x["params"]["seq"], pa.string()
+        ),
     )
-    assert dset1._locs == dset2._locs
 
+    dataset = AnnotatedSpectrumDataset.from_lance(db_path, "seq", tokenizer)
+    spec = dataset[0]
+    assert len(spec) == 8
+    assert spec["seq"] == "LESLIEK"
+    assert spec["mz_array"].shape == (14,)
 
-def test_preprocessing_fn(mgf_small):
-    """Test preprocessing functions."""
-    dset = SpectrumDataset(mgf_small)
-    loader = dset.loader(batch_size=1, num_workers=0)
+    spec2 = dataset[1]
+    assert spec2["seq"] == "EDITHR"
+    assert spec2["mz_array"].shape == (24,)
 
-    spec, *_ = next(iter(loader))
-    assert (spec[:, :, 1] < 1).all()
+    dataset = SpectrumDataset.from_lance(db_path)
+    spec = dataset[0]
+    assert len(spec) == 8
+    assert spec["peak_file"] == "small.mgf"
spec["scan_id"] == 0 + assert spec["ms_level"] == 2 + assert (spec["precursor_mz"] - 416.2448) < 0.001 + + +def test_formats(tmp_path, real_mgf, real_mzml, real_mzxml): + """Test all of the supported formats.""" + df = arrow.spectra_to_df(real_mgf) + parquet = arrow.spectra_to_parquet( + real_mgf, + parquet_file=tmp_path / "test.parquet", + ) - dset = SpectrumDataset(mgf_small, preprocessing_fn=[]) - loader = dset.loader(batch_size=1, num_workers=0) + data = [df, real_mgf, real_mzml, real_mzxml, parquet] + for input_type in data: + SpectrumDataset( + spectra=input_type, + path=tmp_path / "test", + ) - spec, *_ = next(iter(loader)) - assert (spec[:, :, 1] == 1).all() - def my_func(spec): - """A simple test function.""" - spec.intensity[:] = 2.0 - return spec +def test_streaming_spectra(mgf_small): + """Test the streaming dataset.""" + streamer = StreamingSpectrumDataset(mgf_small, batch_size=1) + spec = next(iter(streamer)) + expected = SpectrumDataset(mgf_small)[0] + assert_dicts_equal(spec, expected) - dset = SpectrumDataset(mgf_small, preprocessing_fn=my_func) - loader = dset.loader(batch_size=1, num_workers=0) - spec, *_ = next(iter(loader)) - assert (spec[:, :, 1] == 2).all() + streamer = StreamingSpectrumDataset(mgf_small, batch_size=2) + spec = next(iter(streamer)) + assert_dicts_equal(spec, expected) -def test_peptide_dataset(): +def test_peptide_dataset(tokenizer): """Test the peptide dataset.""" - tokenizer = PeptideTokenizer() seqs = ["LESLIEK", "EDITHR"] charges = torch.tensor([2, 3]) dset = PeptideDataset(tokenizer, seqs, charges) @@ -296,3 +161,35 @@ def test_peptide_dataset(): torch.testing.assert_close(dset.tokens, tokenizer.tokenize(seqs)) torch.testing.assert_close(dset.charges, charges) + + +def test_pickle(tokenizer, tmp_path, mgf_small): + """Test that datasets can be pickled.""" + dataset = SpectrumDataset(mgf_small, path=tmp_path / "test") + pkl_file = tmp_path / "test.pkl" + with pkl_file.open("wb+") as pkl: + pickle.dump(dataset, pkl) + + with pkl_file.open("rb") as pkl: + loaded = pickle.load(pkl) + + assert len(dataset) == len(loaded) + + dataset = AnnotatedSpectrumDataset( + [mgf_small], + tokenizer, + "seq", + tmp_path / "test.lance", + custom_fields=CustomField( + "seq", lambda x: x["params"]["seq"], pa.string() + ), + ) + pkl_file = tmp_path / "test.pkl" + + with pkl_file.open("wb+") as pkl: + pickle.dump(dataset, pkl) + + with pkl_file.open("rb") as pkl: + loaded = pickle.load(pkl) + + assert len(dataset) == len(loaded) diff --git a/tests/unit_tests/test_data/test_loaders.py b/tests/unit_tests/test_data/test_loaders.py index b06cded..18d9779 100644 --- a/tests/unit_tests/test_data/test_loaders.py +++ b/tests/unit_tests/test_data/test_loaders.py @@ -1,44 +1,70 @@ """Test PyTorch DataLoaders.""" +import pyarrow as pa +import pytest import torch from depthcharge.data import ( AnnotatedSpectrumDataset, + CustomField, PeptideDataset, SpectrumDataset, + StreamingSpectrumDataset, ) +from depthcharge.testing import assert_dicts_equal from depthcharge.tokenizers import PeptideTokenizer -def test_spectrum_loader(mgf_small): - """Test initialization of with a SpectrumIndex.""" - dset = SpectrumDataset(mgf_small, 2) +def test_spectrum_loader(mgf_small, tmp_path): + """Test a normal spectrum dataset.""" + dset = SpectrumDataset(mgf_small, tmp_path / "test") loader = dset.loader(batch_size=1, num_workers=0) - batch = next(iter(loader)) - assert len(batch) == 2 - assert batch[0].shape == (1, 13, 2) - assert batch[1].shape == (1, 2) + assert len(batch) == 7 + assert 
batch["mz_array"].shape == (1, 13) + assert isinstance(batch["mz_array"], torch.Tensor) - dset = SpectrumDataset(mgf_small, 2, []) - loader = dset.loader(batch_size=1, num_workers=0) + with pytest.warns(UserWarning): + dset.loader(collate_fn=torch.utils.data.default_collate) - batch = next(iter(loader)) - assert len(batch) == 2 - assert batch[0].shape == (1, 14, 2) - assert batch[1].shape == (1, 2) + +def test_streaming_spectrum_loader(mgf_small, tmp_path): + """Test streaming spectra.""" + streamer = StreamingSpectrumDataset(mgf_small, 2) + loader = streamer.loader(batch_size=2, num_workers=0) + stream_batch = next(iter(loader)) + with pytest.warns(UserWarning): + streamer.loader(collate_fn=torch.utils.data.default_collate) + + with pytest.warns(UserWarning): + streamer.loader(num_workers=2) + + dset = SpectrumDataset(mgf_small, tmp_path / "test") + loader = dset.loader(batch_size=2, num_workers=0) + map_batch = next(iter(loader)) + assert_dicts_equal(stream_batch, map_batch) def test_ann_spectrum_loader(mgf_small): """Test initialization of with a SpectrumIndex.""" tokenizer = PeptideTokenizer() - dset = AnnotatedSpectrumDataset(tokenizer, mgf_small) + dset = AnnotatedSpectrumDataset( + mgf_small, + "seq", + tokenizer, + custom_fields=CustomField( + "seq", lambda x: x["params"]["seq"], pa.string() + ), + ) loader = dset.loader(batch_size=1, num_workers=0) batch = next(iter(loader)) - assert len(batch) == 3 - assert batch[0].shape == (1, 13, 2) - assert batch[1].shape == (1, 2) - assert batch[2].shape == (1, 7) + assert len(batch) == 8 + assert batch["mz_array"].shape == (1, 13) + assert isinstance(batch["mz_array"], torch.Tensor) + torch.testing.assert_close(batch["seq"], tokenizer.tokenize(["LESLIEK"])) + + with pytest.warns(UserWarning): + dset.loader(collate_fn=torch.utils.data.default_collate) def test_peptide_loader(): @@ -55,7 +81,7 @@ def test_peptide_loader(): batch[0], tokenizer.tokenize(seqs)[:2, :], ) - torch.testing.assert_close(batch[1], torch.tensor(charges[:2], dtype=int)) + torch.testing.assert_close(batch[1], charges[:2]) args = (torch.tensor([1, 2, 3]), torch.tensor([[1, 1], [2, 2], [3, 3]])) dset = PeptideDataset(tokenizer, seqs, charges, *args) @@ -63,6 +89,6 @@ def test_peptide_loader(): batch = next(iter(loader)) assert len(batch) == 4 - torch.testing.assert_close(batch[1], torch.tensor(charges[:2], dtype=int)) + torch.testing.assert_close(batch[1], charges[:2]) torch.testing.assert_close(batch[2], torch.tensor([1, 2])) torch.testing.assert_close(batch[3], args[1][:2, :]) diff --git a/tests/unit_tests/test_data/test_parsers.py b/tests/unit_tests/test_data/test_parsers.py index 19f715f..cfec339 100644 --- a/tests/unit_tests/test_data/test_parsers.py +++ b/tests/unit_tests/test_data/test_parsers.py @@ -1,110 +1,213 @@ """Test that parsers work.""" +import polars as pl +import pyarrow as pa +import pytest +from polars.testing import assert_frame_equal, assert_series_equal -import numpy as np - -from depthcharge.data.parsers import MgfParser, MzmlParser, MzxmlParser +from depthcharge.data import CustomField +from depthcharge.data.parsers import ( + MgfParser, + MzmlParser, + MzxmlParser, + ParserFactory, +) +from depthcharge.data.preprocessing import scale_to_unit_norm SMALL_MGF_MZS = [ - 114.09134044, - 147.11280416, - 243.13393353, - 276.15539725, - 330.16596194, - 389.23946123, - 443.25002591, - 502.32352521, - 556.33408989, - 589.35555361, - 685.37668298, - 718.3981467, - 813.47164599, - 831.48221068, - 65.52857301, - 88.06311432, - 123.04204452, - 130.04986955, - 
156.59257025, - 175.11895217, - 179.58407651, - 207.11640948, - 230.10791575, - 245.07681258, - 263.65844147, - 298.63737167, - 312.17786403, - 321.17191298, - 358.16087656, - 376.68792719, - 385.69320953, - 413.2255425, - 459.20855502, - 526.30960648, - 596.26746688, - 641.3365495, - 752.36857791, - 770.37914259, + [ + 114.09134044, + 147.11280416, + 243.13393353, + 276.15539725, + 330.16596194, + 389.23946123, + 443.25002591, + 502.32352521, + 556.33408989, + 589.35555361, + 685.37668298, + 718.3981467, + 813.47164599, + 831.48221068, + ], + [ + 65.52857301, + 88.06311432, + 123.04204452, + 130.04986955, + 156.59257025, + 175.11895217, + 179.58407651, + 207.11640948, + 230.10791575, + 245.07681258, + 263.65844147, + 298.63737167, + 312.17786403, + 321.17191298, + 358.16087656, + 376.68792719, + 385.69320953, + 413.2255425, + 459.20855502, + 526.30960648, + 596.26746688, + 641.3365495, + 752.36857791, + 770.37914259, + ], ] +MGF_FIELD = CustomField("t", lambda x: x["params"]["title"], pa.string()) +MZML_FIELD = CustomField("index", lambda x: x["index"], pa.int64()) +MZXML_FIELD = CustomField("CE", lambda x: x["collisionEnergy"], pa.float64()) + def test_mgf_and_base(mgf_small): """MGF file with a missing charge.""" - parser = MgfParser(mgf_small).read() - np.testing.assert_allclose( - parser.precursor_mz, - np.array([416.24474357, 257.464565]), + parsed = pl.from_arrow( + MgfParser(mgf_small, preprocessing_fn=[]).iter_batches(None) ) - np.testing.assert_equal( - parser.precursor_charge, - np.array([2, 3]), + expected = pl.DataFrame( + { + "peak_file": [mgf_small.name] * 2, + "scan_id": [0, 1], + "ms_level": [2, 2], + "precursor_mz": [416.24474357, 257.464565], + "precursor_charge": [2, 3], + "mz_array": SMALL_MGF_MZS, + "intensity_array": [ + [1.0] * len(SMALL_MGF_MZS[0]), + [1.0] * len(SMALL_MGF_MZS[1]), + ], + } + ).with_columns( + [ + pl.col("intensity_array").cast(pl.List(pl.Float64)), + pl.col("ms_level").cast(pl.UInt8), + pl.col("precursor_charge").cast(pl.Int16), + ] ) - np.testing.assert_equal( - parser.offset, - np.array([0, 14]), + + assert parsed.shape == (2, 7) + assert_frame_equal(parsed, expected) + + parsed = pl.from_arrow( + MgfParser(mgf_small, valid_charge=[2]).iter_batches(2), ) - np.testing.assert_allclose( - parser.mz_arrays, - np.array(SMALL_MGF_MZS), + assert parsed.shape == (1, 7) + assert isinstance(ParserFactory.get_parser(mgf_small), MgfParser) + + +@pytest.mark.parametrize( + ["ms_level", "preprocessing_fn", "valid_charge", "custom_fields", "shape"], + [ + (2, None, None, None, (4, 7)), + (1, None, None, None, (4, 7)), + (3, None, None, None, (3, 7)), + (2, None, [3], None, (3, 7)), + (None, None, None, None, (11, 7)), + (2, scale_to_unit_norm, None, MZML_FIELD, (4, 8)), + ], +) +def test_mzml( + real_mzml, ms_level, preprocessing_fn, valid_charge, custom_fields, shape +): + """A simple mzML test.""" + parsed = pl.from_arrow( + MzmlParser( + real_mzml, + ms_level=ms_level, + preprocessing_fn=preprocessing_fn, + valid_charge=valid_charge, + custom_fields=custom_fields, + ).iter_batches(None) ) + assert parsed.shape == shape - np.testing.assert_allclose( - parser.intensity_arrays, np.ones(len(SMALL_MGF_MZS), dtype=int) + +@pytest.mark.parametrize( + ["ms_level", "preprocessing_fn", "valid_charge", "custom_fields", "shape"], + [ + (2, None, None, None, (4, 7)), + (1, None, None, None, (4, 7)), + (3, None, None, None, (3, 7)), + (2, None, [3], None, (3, 7)), + (None, None, None, None, (11, 7)), + (2, scale_to_unit_norm, None, MZXML_FIELD, (4, 8)), + ], +) +def 
test_mzxml(
+    real_mzxml, ms_level, preprocessing_fn, valid_charge, custom_fields, shape
+):
+    """A simple mzXML test."""
+    parsed = pl.from_arrow(
+        MzxmlParser(
+            real_mzxml,
+            ms_level=ms_level,
+            preprocessing_fn=preprocessing_fn,
+            valid_charge=valid_charge,
+            custom_fields=custom_fields,
+        ).iter_batches(None)
     )
+    assert parsed.shape == shape
 
-    parser = MgfParser(mgf_small, valid_charge=[2]).read()
-    assert parser.precursor_charge.shape == (1,)
 
-    parser = MgfParser(mgf_small, ms_level=1).read()
-    np.testing.assert_equal(
-        parser.precursor_charge,
-        np.array([0, 0]),
+@pytest.mark.parametrize(
+    ["ms_level", "preprocessing_fn", "valid_charge", "custom_fields", "shape"],
+    [
+        (2, None, None, None, (7, 7)),
+        (1, None, None, None, (7, 7)),
+        (3, None, None, None, (7, 7)),
+        (2, None, [3], None, (3, 7)),
+        (None, None, None, None, (7, 7)),
+        (2, scale_to_unit_norm, None, MGF_FIELD, (7, 8)),
+    ],
+)
+def test_mgf(
+    real_mgf, ms_level, preprocessing_fn, valid_charge, custom_fields, shape
+):
+    """A simple MGF test."""
+    parsed = pl.from_arrow(
+        MgfParser(
+            real_mgf,
+            ms_level=ms_level,
+            preprocessing_fn=preprocessing_fn,
+            valid_charge=valid_charge,
+            custom_fields=custom_fields,
+        ).iter_batches(None)
     )
-    np.testing.assert_equal(
-        parser.precursor_mz,
-        np.array([np.nan, np.nan]),
+    assert parsed.shape == shape
+
+
+def test_custom_fields(mgf_small):
+    """Test that custom fields are working."""
+    parsed = pl.from_arrow(
+        MgfParser(
+            mgf_small,
+            custom_fields=CustomField(
+                "seq", lambda x: x["params"]["seq"], pa.string()
+            ),
+        ).iter_batches(None)
     )
+    expected = pl.Series("seq", ["LESLIEK", "EDITHR"])
+    assert_series_equal(parsed["seq"], expected)
 
-def test_mzml(real_mzml):
-    """A simple mzML test."""
-    parser = MzmlParser(real_mzml, 2).read()
-    prev_len = len(parser.mz_arrays)
-    assert len(parser.intensity_arrays) == prev_len
-    assert len(parser.precursor_charge) == 4
+    with pytest.raises(KeyError):
+        pl.from_arrow(
+            MgfParser(
+                mgf_small,
+                custom_fields=CustomField(
+                    "seq", lambda x: x["params"]["bar"], pa.string()
+                ),
+            ).iter_batches(None)
+        )
 
-    parser = MzmlParser(real_mzml, 2, valid_charge=[3]).read()
-    assert len(parser.precursor_charge) == 3
-    assert len(parser.intensity_arrays) < prev_len
-    assert len(parser.mz_arrays) < prev_len
 
+def test_invalid_file(tmp_path):
+    """Test an invalid file raises an error."""
+    (tmp_path / "blah.txt").touch()
 
-def test_mzxml(real_mzxml):
-    """A simple mzML test."""
-    parser = MzxmlParser(real_mzxml, 2).read()
-    prev_len = len(parser.mz_arrays)
-    assert len(parser.intensity_arrays) == prev_len
-    assert len(parser.precursor_charge) == 4
-
-    parser = MzxmlParser(real_mzxml, 2, valid_charge=[3]).read()
-    assert len(parser.precursor_charge) == 3
-    assert len(parser.intensity_arrays) < prev_len
-    assert len(parser.mz_arrays) < prev_len
+    with pytest.raises(OSError):
+        ParserFactory().get_parser(tmp_path / "blah.txt")
diff --git a/tests/unit_tests/test_testing.py b/tests/unit_tests/test_testing.py
new file mode 100644
index 0000000..5d475d1
--- /dev/null
+++ b/tests/unit_tests/test_testing.py
@@ -0,0 +1,27 @@
+"""Ironically test that the testing functions are working."""
+import numpy as np
+import pytest
+import torch
+
+from depthcharge.testing import assert_dicts_equal
+
+
+@pytest.mark.parametrize(
+    ["dict1", "dict2", "error"],
+    [
+        ({"a": "a", "b": "b"}, {"a": "a", "b": "b"}, False),
+        ({"a": "a"}, {"a": "c"}, True),
+        ({"a": "a"}, {"b": "a"}, True),
+        ({"a": torch.tensor([1])}, {"a": torch.tensor([1])}, False),
+        ({"a": 
torch.tensor([1])}, {"a": torch.tensor([2])}, True),
+        ({"a": torch.tensor([1])}, {"a": 1}, True),
+        ({"a": np.array([1])}, {"a": np.array([1])}, False),
+    ],
+)
+def test_assert_dicts_equal(dict1, dict2, error):
+    """Test the dict equal function."""
+    if error:
+        with pytest.raises(AssertionError):
+            assert_dicts_equal(dict1, dict2)
+    else:
+        assert_dicts_equal(dict1, dict2)
diff --git a/tests/unit_tests/test_transformers/test_peptide_transformers.py b/tests/unit_tests/test_transformers/test_peptide_transformers.py
index b345b6a..d9debae 100644
--- a/tests/unit_tests/test_transformers/test_peptide_transformers.py
+++ b/tests/unit_tests/test_transformers/test_peptide_transformers.py
@@ -42,7 +42,7 @@ def test_peptide_decoder():
     precursors = torch.tensor([[100.0, 2], [200.0, 3]])
 
     encoder = SpectrumTransformerEncoder(8, 1, 12)
-    memory, mem_mask = encoder(spectra)
+    memory, mem_mask = encoder(spectra[:, :, 0], spectra[:, :, 1])
 
     decoder = PeptideTransformerDecoder(n_tokens, 8, 1, 12, max_charge=3)
     scores = decoder(peptides, precursors, memory, mem_mask)
diff --git a/tests/unit_tests/test_transformers/test_spectrum_transformers.py b/tests/unit_tests/test_transformers/test_spectrum_transformers.py
index 843b1d6..d3fdc01 100644
--- a/tests/unit_tests/test_transformers/test_spectrum_transformers.py
+++ b/tests/unit_tests/test_transformers/test_spectrum_transformers.py
@@ -1,12 +1,14 @@
 """Test the spectrum transformers."""
+import pytest
 import torch
 
 from depthcharge.encoders import PeakEncoder
 from depthcharge.transformers import SpectrumTransformerEncoder
 
 
-def test_spectrum_encoder():
-    """Test that a spectrum encoder will run."""
+@pytest.fixture
+def batch():
+    """A batch of mass spectra."""
     spectra = torch.tensor(
         [
             [[100.1, 0.1], [200.2, 0.2], [300.3, 0.3]],
@@ -14,17 +16,70 @@ def test_spectrum_encoder():
         ]
     )
 
+    batch_dict = {
+        "mz_array": spectra[:, :, 0],
+        "intensity_array": spectra[:, :, 1],
+        "charge": torch.tensor([1.0, 2.0]),
+    }
+
+    return batch_dict
+
+
+def test_spectrum_encoder(batch):
+    """Test that a spectrum encoder will run."""
     model = SpectrumTransformerEncoder(8, 1, 12)
-    emb, mask = model(spectra)
+    emb, mask = model(**batch)
     assert emb.shape == (2, 4, 8)
     assert mask.sum() == 1
 
     model = SpectrumTransformerEncoder(8, 1, 12, peak_encoder=PeakEncoder(8))
-    emb, mask = model(spectra)
+    emb, mask = model(**batch)
     assert emb.shape == (2, 4, 8)
     assert mask.sum() == 1
 
     model = SpectrumTransformerEncoder(8, 1, 12, peak_encoder=False)
-    emb, mask = model(spectra)
+    emb, mask = model(**batch)
     assert emb.shape == (2, 4, 8)
     assert mask.sum() == 1
+
+
+def test_precursor_hook(batch):
+    """Test that the hook works."""
+
+    class MyEncoder(SpectrumTransformerEncoder):
+        """A silly class."""
+
+        def precursor_hook(self, mz_array, intensity_array, **kwargs):
+            """A silly hook."""
+            return kwargs["charge"].expand(self.d_model, -1).T
+
+    model1 = MyEncoder(8, 1, 12)
+    emb1, mask1 = model1(**batch)
+    assert emb1.shape == (2, 4, 8)
+    assert mask1.sum() == 1
+
+    model2 = SpectrumTransformerEncoder(8, 1, 12)
+    emb2, mask2 = model2(**batch)
+    assert emb2.shape == (2, 4, 8)
+    assert mask2.sum() == 1
+
+    for elem in zip(emb1.flatten(), emb2.flatten()):
+        if elem:
+            assert elem[0] != elem[1]
+
+
+def test_properties():
+    """Test that the properties work."""
+    model = SpectrumTransformerEncoder(
+        d_model=256,
+        nhead=16,
+        dim_feedforward=48,
+        n_layers=3,
+        dropout=0.1,
+    )
+
+    assert model.d_model == 256
+    assert model.nhead == 16
+    assert model.dim_feedforward == 48
+    assert model.n_layers == 3
+    assert 
model.dropout == 0.1
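
Taken together, the pieces in this patch support an end-to-end workflow along these lines (a hedged sketch; the file name and model hyperparameters are illustrative):

```python
import torch

from depthcharge.data import SpectrumDataset
from depthcharge.transformers import SpectrumTransformerEncoder

# Parse spectra into a lance dataset and serve dictionary batches:
dataset = SpectrumDataset("example.mzML", path="spectra.lance")
loader = dataset.loader(batch_size=8, num_workers=0)

# Embed each batch of spectra:
model = SpectrumTransformerEncoder(d_model=64, nhead=8, dim_feedforward=128)
for batch in loader:
    embedded, mask = model(
        batch["mz_array"].float(),
        batch["intensity_array"].float(),
    )
    print(embedded.shape)  # (batch, n_peaks + 1, 64)
    break
```
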