From cfdeeb43bac117cfa7fdbf2364ff799c1e930c6e Mon Sep 17 00:00:00 2001 From: William Fondrie Date: Fri, 6 Oct 2023 19:59:14 -0700 Subject: [PATCH] Add precision and fix pickling --- depthcharge/data/spectrum_datasets.py | 20 +++++++++++-- depthcharge/tokenizers/peptides.py | 15 ++++++++++ tests/unit_tests/test_data/test_datasets.py | 31 +++++++++++++++++++++ 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/depthcharge/data/spectrum_datasets.py b/depthcharge/data/spectrum_datasets.py index e1b17c3..1609948 100644 --- a/depthcharge/data/spectrum_datasets.py +++ b/depthcharge/data/spectrum_datasets.py @@ -8,7 +8,7 @@ from os import PathLike from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any +from typing import Any, Literal import lance import polars as pl @@ -28,6 +28,10 @@ class CollateFnMixin: """Define a common collate function.""" + def __init__(self) -> None: + """Specify the float precision.""" + self.precision = torch.float64 + def collate_fn( self, batch: Iterable[dict[Any]], @@ -62,10 +66,22 @@ def collate_fn( batch = torch.utils.data.default_collate(batch) batch["mz_array"] = mz_array batch["intensity_array"] = intensity_array + + for key, val in batch.items(): + if isinstance(val, torch.Tensor) and torch.is_floating_point(val): + batch[key] = val.type(self.precision) + return batch - def loader(self, **kwargs: dict) -> DataLoader: + def loader( + self, + precision: Literal[ + torch.float16, torch.float32, torch.float64 + ] = torch.float64, + **kwargs: dict, + ) -> DataLoader: """Create a suitable PyTorch DataLoader.""" + self.precision = precision if kwargs.get("collate_fn", False): warnings.warn("The default collate_fn was overridden.") else: diff --git a/depthcharge/tokenizers/peptides.py b/depthcharge/tokenizers/peptides.py index fe251c8..bddad42 100644 --- a/depthcharge/tokenizers/peptides.py +++ b/depthcharge/tokenizers/peptides.py @@ -94,6 +94,21 @@ def __init__( super().__init__(list(self.residues.keys())) + def __getstate__(self) -> dict: + """How to pickle the object.""" + self.residues = dict(self.residues) + return self.__dict__ + + def __setstate__(self, state: dict) -> None: + """How to unpickle the object.""" + self.__dict__ = state + residues = self.residues + self.residues = nb.typed.Dict.empty( + nb.types.unicode_type, + nb.types.float64, + ) + self.residues.update(residues) + def split(self, sequence: str) -> list[str]: """Split a ProForma peptide sequence. diff --git a/tests/unit_tests/test_data/test_datasets.py b/tests/unit_tests/test_data/test_datasets.py index 8a4bee3..61a609c 100644 --- a/tests/unit_tests/test_data/test_datasets.py +++ b/tests/unit_tests/test_data/test_datasets.py @@ -1,4 +1,5 @@ """Test the datasets.""" +import pickle import shutil import pytest @@ -154,3 +155,33 @@ def test_peptide_dataset(tokenizer): torch.testing.assert_close(dset.tokens, tokenizer.tokenize(seqs)) torch.testing.assert_close(dset.charges, charges) + + +def test_pickle(tokenizer, tmp_path, mgf_small): + """Test that datasets can be pickled.""" + dataset = SpectrumDataset(mgf_small, path=tmp_path / "test") + pkl_file = tmp_path / "test.pkl" + with pkl_file.open("wb+") as pkl: + pickle.dump(dataset, pkl) + + with pkl_file.open("rb") as pkl: + loaded = pickle.load(pkl) + + assert len(dataset) == len(loaded) + + dataset = AnnotatedSpectrumDataset( + [mgf_small], + tokenizer, + "seq", + tmp_path / "test.lance", + custom_fields={"seq": ["params", "seq"]}, + ) + pkl_file = tmp_path / "test.pkl" + + with pkl_file.open("wb+") as pkl: + pickle.dump(dataset, pkl) + + with pkl_file.open("rb") as pkl: + loaded = pickle.load(pkl) + + assert len(dataset) == len(loaded)