Skip to content

Commit

Permalink
Run tests using Zarr Python v3
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Nov 4, 2024
1 parent 1846566 commit f2983f3
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 25 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,27 @@ jobs:
uses: codecov/codecov-action@v3
with:
token: ${{ secrets.CODECOV_TOKEN }}

test-zarr-version:
name: Test Zarr Python v3
# Scheduled runs only on the origin org
if: (github.event_name == 'schedule' && github.repository_owner == 'sgkit-dev') || (github.event_name != 'schedule')
runs-on: ubuntu-latest
strategy:
matrix:
zarr: [">=3"]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt -r requirements-dev.txt
- name: Install zarr${{ matrix.zarr }}
run: |
python -m pip install --pre 'zarr${{ matrix.zarr }}'
- name: Run tests
run: |
pytest
7 changes: 4 additions & 3 deletions sgkit/io/bgen/bgen_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dask
import dask.array as da
import dask.dataframe as dd
import numcodecs
import numpy as np
import pandas as pd
import xarray as xr
Expand Down Expand Up @@ -348,7 +349,7 @@ def encode_variables(
ds: Dataset,
chunk_length: int,
chunk_width: int,
compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
compressor: Optional[Any] = numcodecs.Blosc(cname="zstd", clevel=7, shuffle=2),
probability_dtype: Optional[Any] = "uint8",
) -> Dict[Hashable, Dict[str, Any]]:
encoding = {}
Expand Down Expand Up @@ -424,7 +425,7 @@ def rechunk_bgen(
*,
chunk_length: int = 10_000,
chunk_width: int = 1_000,
compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
compressor: Optional[Any] = numcodecs.Blosc(cname="zstd", clevel=7, shuffle=2),
probability_dtype: Optional[DType] = "uint8",
max_mem: str = "4GB",
pack: bool = True,
Expand Down Expand Up @@ -538,7 +539,7 @@ def bgen_to_zarr(
chunk_length: int = 10_000,
chunk_width: int = 1_000,
temp_chunk_length: int = 100,
compressor: Optional[Any] = zarr.Blosc(cname="zstd", clevel=7, shuffle=2),
compressor: Optional[Any] = numcodecs.Blosc(cname="zstd", clevel=7, shuffle=2),
probability_dtype: Optional[DType] = "uint8",
max_mem: str = "4GB",
pack: bool = True,
Expand Down
20 changes: 6 additions & 14 deletions sgkit/io/dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from pathlib import Path
from typing import Any, Dict, MutableMapping, Optional, Union

import fsspec
import numcodecs
import xarray as xr
from xarray import Dataset

from sgkit.typing import PathType
from sgkit.utils import has_keyword


def save_dataset(
ds: Dataset,
store: Union[PathType, MutableMapping[str, bytes]],
storage_options: Optional[Dict[str, str]] = None,
auto_rechunk: Optional[bool] = None,
zarr_format: int = 2,
**kwargs: Any,
) -> None:
"""Save a dataset to Zarr storage.
Expand All @@ -35,11 +35,6 @@ def save_dataset(
kwargs
Additional arguments to pass to :meth:`xarray.Dataset.to_zarr`.
"""
if isinstance(store, str):
storage_options = storage_options or {}
store = fsspec.get_mapper(store, **storage_options)
elif isinstance(store, Path):
store = str(store)
if auto_rechunk is None:
auto_rechunk = False
for v in ds:
Expand Down Expand Up @@ -71,7 +66,9 @@ def save_dataset(

# Catch unequal chunking errors to provide a more helpful error message
try:
ds.to_zarr(store, **kwargs)
if has_keyword(ds.to_zarr, "zarr_format"): # from xarray v2024.10.0
kwargs["zarr_format"] = zarr_format
ds.to_zarr(store, storage_options=storage_options, **kwargs)
except ValueError as e:
if "Zarr requires uniform chunk sizes" in str(
e
Expand Down Expand Up @@ -109,12 +106,7 @@ def load_dataset(
Dataset
The dataset loaded from the Zarr store or file system.
"""
if isinstance(store, str):
storage_options = storage_options or {}
store = fsspec.get_mapper(store, **storage_options)
elif isinstance(store, Path):
store = str(store)
ds: Dataset = xr.open_zarr(store, concat_characters=False, **kwargs) # type: ignore[no-untyped-call]
ds: Dataset = xr.open_zarr(store, storage_options=storage_options, concat_characters=False, **kwargs) # type: ignore[no-untyped-call]
for v in ds:
# Workaround for https://github.com/pydata/xarray/issues/4386
if v.endswith("_mask"): # type: ignore
Expand Down
10 changes: 9 additions & 1 deletion sgkit/tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest
import xarray as xr
import zarr
from packaging.version import Version
from xarray import Dataset

from sgkit import load_dataset, save_dataset
Expand Down Expand Up @@ -54,7 +56,10 @@ def test_save_unequal_chunks_error():
n_variant=10, n_sample=10, n_ploidy=10, n_allele=10, n_contig=10
)
# Normal zarr errors shouldn't be caught
with pytest.raises(ValueError, match="path '' contains an array"):
with pytest.raises(
(FileExistsError, ValueError),
match="(path '' contains an array|Store already exists)",
):
save_dataset(ds, {".zarray": ""})

# Make the dataset have unequal chunk sizes across all dimensions
Expand All @@ -74,6 +79,9 @@ def test_save_unequal_chunks_error():
save_dataset(ds, {})


@pytest.mark.skipif(
Version(zarr.__version__).major >= 3, reason="Fails for Zarr Python 3"
)
def test_save_auto_rechunk():
# Make all dimensions the same size for ease of testing
ds = simulate_genotype_call_dataset(
Expand Down
12 changes: 7 additions & 5 deletions sgkit/tests/test_association.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
import pandas as pd
import pytest
import xarray as xr
import zarr
from pandas import DataFrame
from xarray import Dataset

try:
from zarr.storage import ZipStore # v3
except ImportError: # pragma: no cover
from zarr import ZipStore

import sgkit.distarray as da
from sgkit.stats.association import (
gwas_linear_regression,
Expand Down Expand Up @@ -313,12 +317,10 @@ def test_regenie_loco_regression(ndarray_type: str, covariate: bool) -> None:

for ds_name in datasets:
# Load simulated data
genotypes_store = zarr.ZipStore(
genotypes_store = ZipStore(
str(ds_dir / ds_name / "genotypes.zarr.zip"), mode="r"
)
glow_store = zarr.ZipStore(
str(ds_dir / ds_name / glow_offsets_filename), mode="r"
)
glow_store = ZipStore(str(ds_dir / ds_name / glow_offsets_filename), mode="r")

ds = xr.open_zarr(genotypes_store, consolidated=False)
glow_loco_predictions = xr.open_zarr(glow_store, consolidated=False)
Expand Down
8 changes: 6 additions & 2 deletions sgkit/tests/test_regenie.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest
import xarray as xr
import yaml
import zarr
from dask.array import Array
from hypothesis import given, settings
from hypothesis import strategies as st
Expand All @@ -18,6 +17,11 @@
from pandas import DataFrame
from xarray import Dataset

try:
from zarr.storage import ZipStore # v3
except ImportError: # pragma: no cover
from zarr import ZipStore

from sgkit.stats.association import LinearRegressionResult, linear_regression
from sgkit.stats.regenie import (
index_array_blocks,
Expand Down Expand Up @@ -258,7 +262,7 @@ def check_simulation_result(
result_dir = datadir / "result" / run["name"]

# Load simulated data
with zarr.ZipStore(str(dataset_dir / "genotypes.zarr.zip"), mode="r") as store:
with ZipStore(str(dataset_dir / "genotypes.zarr.zip"), mode="r") as store:
ds = xr.open_zarr(store, consolidated=False)
df_covariate = load_covariates(dataset_dir)
df_trait = load_traits(dataset_dir)
Expand Down
8 changes: 8 additions & 0 deletions sgkit/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import warnings
from itertools import product
from typing import Any, Callable, Hashable, List, Mapping, Optional, Set, Tuple, Union
Expand Down Expand Up @@ -425,3 +426,10 @@ def smallest_numpy_int_dtype(value: int) -> Optional[DType]:
if np.iinfo(dtype).min <= value <= np.iinfo(dtype).max:
return dtype
raise OverflowError(f"Value {value} cannot be stored in np.int64")


def has_keyword(func, keyword):
try:
return keyword in inspect.signature(func).parameters
except Exception:
return False

0 comments on commit f2983f3

Please sign in to comment.