Skip to content

Commit

Permalink
Add precision and fix pickling
Browse files Browse the repository at this point in the history
  • Loading branch information
wfondrie committed Oct 7, 2023
1 parent 4eb6a8a commit cfdeeb4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
20 changes: 18 additions & 2 deletions depthcharge/data/spectrum_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from os import PathLike
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
from typing import Any, Literal

import lance
import polars as pl
Expand All @@ -28,6 +28,10 @@
class CollateFnMixin:
"""Define a common collate function."""

def __init__(self) -> None:
    """Set the default floating-point dtype to ``torch.float64``.

    The value is stored on ``self.precision``.
    """
    self.precision: torch.dtype = torch.float64

Check warning on line 33 in depthcharge/data/spectrum_datasets.py

View check run for this annotation

Codecov / codecov/patch

depthcharge/data/spectrum_datasets.py#L33

Added line #L33 was not covered by tests

def collate_fn(
self,
batch: Iterable[dict[Any]],
Expand Down Expand Up @@ -62,10 +66,22 @@ def collate_fn(
batch = torch.utils.data.default_collate(batch)
batch["mz_array"] = mz_array
batch["intensity_array"] = intensity_array

for key, val in batch.items():
if isinstance(val, torch.Tensor) and torch.is_floating_point(val):
batch[key] = val.type(self.precision)

return batch

def loader(self, **kwargs: dict) -> DataLoader:
def loader(
self,
precision: Literal[
torch.float16, torch.float32, torch.float64
] = torch.float64,
**kwargs: dict,
) -> DataLoader:
"""Create a suitable PyTorch DataLoader."""
self.precision = precision
if kwargs.get("collate_fn", False):
warnings.warn("The default collate_fn was overridden.")
else:
Expand Down
15 changes: 15 additions & 0 deletions depthcharge/tokenizers/peptides.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ def __init__(

super().__init__(list(self.residues.keys()))

def __getstate__(self) -> dict:
    """Return a picklable snapshot of the instance state.

    The ``residues`` mapping is converted to a plain ``dict`` in the
    returned state only. The live object is left untouched, so pickling
    never changes the type of ``self.residues`` on the instance that is
    still in use.

    Returns
    -------
    dict
        A shallow copy of ``self.__dict__`` in which ``residues`` is a
        builtin ``dict``.
    """
    # Copy first: assigning into ``self.__dict__`` directly would swap the
    # instance's own ``residues`` for a plain dict as a side effect.
    state = self.__dict__.copy()
    state["residues"] = dict(self.residues)
    return state

def __setstate__(self, state: dict) -> None:
    """Restore the instance from a pickled state.

    The state dictionary becomes the instance ``__dict__``. Its
    ``residues`` entry is then copied into a freshly created numba typed
    dictionary (unicode keys, float64 values), which replaces the
    unpickled mapping.

    Parameters
    ----------
    state : dict
        The attribute dictionary to restore.
    """
    self.__dict__ = state
    typed_residues = nb.typed.Dict.empty(
        key_type=nb.types.unicode_type,
        value_type=nb.types.float64,
    )
    typed_residues.update(self.residues)
    self.residues = typed_residues

def split(self, sequence: str) -> list[str]:
"""Split a ProForma peptide sequence.
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/test_data/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the datasets."""
import pickle
import shutil

import pytest
Expand Down Expand Up @@ -154,3 +155,33 @@ def test_peptide_dataset(tokenizer):

torch.testing.assert_close(dset.tokens, tokenizer.tokenize(seqs))
torch.testing.assert_close(dset.charges, charges)


def test_pickle(tokenizer, tmp_path, mgf_small):
    """Check that both dataset types survive a pickle round trip."""
    pkl_file = tmp_path / "test.pkl"

    def round_trip(obj):
        """Write ``obj`` to ``pkl_file`` and read it straight back."""
        with pkl_file.open("wb+") as handle:
            pickle.dump(obj, handle)

        with pkl_file.open("rb") as handle:
            return pickle.load(handle)

    # Plain spectrum dataset.
    spectra = SpectrumDataset(mgf_small, path=tmp_path / "test")
    restored = round_trip(spectra)
    assert len(spectra) == len(restored)

    # Annotated dataset, which also carries the tokenizer.
    annotated = AnnotatedSpectrumDataset(
        [mgf_small],
        tokenizer,
        "seq",
        tmp_path / "test.lance",
        custom_fields={"seq": ["params", "seq"]},
    )
    restored = round_trip(annotated)
    assert len(annotated) == len(restored)

0 comments on commit cfdeeb4

Please sign in to comment.