diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py index 9360e318..9f386298 100644 --- a/sgkit/stats/aggregation.py +++ b/sgkit/stats/aggregation.py @@ -680,7 +680,7 @@ def variant_stats( -------- :func:`count_variant_genotypes` """ - from .aggregation_numba_fns import count_hom + from .aggregation_numba_fns import count_hom_new_axis variables.validate(ds, {call_genotype: variables.call_genotype_spec}) mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False) @@ -697,7 +697,7 @@ def variant_stats( G = da.asarray(ds[call_genotype].data) H = xr.DataArray( da.map_blocks( - lambda *args: count_hom(*args)[:, np.newaxis, :], + count_hom_new_axis, G, np.zeros(3, np.uint64), drop_axis=2, @@ -796,7 +796,7 @@ def sample_stats( ValueError If the dataset contains mixed-ploidy genotype calls. """ - from .aggregation_numba_fns import count_hom + from .aggregation_numba_fns import count_hom_new_axis variables.validate(ds, {call_genotype: variables.call_genotype_spec}) mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False) @@ -805,7 +805,7 @@ def sample_stats( GT = da.asarray(ds[call_genotype].transpose("samples", "variants", "ploidy").data) H = xr.DataArray( da.map_blocks( - lambda *args: count_hom(*args)[:, np.newaxis, :], + count_hom_new_axis, GT, np.zeros(3, np.uint64), drop_axis=2, diff --git a/sgkit/stats/aggregation_numba_fns.py b/sgkit/stats/aggregation_numba_fns.py index 3335f545..b84b92a0 100644 --- a/sgkit/stats/aggregation_numba_fns.py +++ b/sgkit/stats/aggregation_numba_fns.py @@ -2,6 +2,8 @@ # in a separate file here, and imported dynamically to avoid # initial compilation overhead. +import numpy as np + from sgkit.accelerate import numba_guvectorize, numba_jit from sgkit.typing import ArrayLike @@ -102,3 +104,7 @@ def count_hom( index = _classify_hom(genotypes[i]) if index >= 0: out[index] += 1 + + +def count_hom_new_axis(genotypes: ArrayLike, _: ArrayLike) -> ArrayLike: + return count_hom(genotypes, _)[:, np.newaxis, :] diff --git a/sgkit/stats/popgen.py b/sgkit/stats/popgen.py index d000bdbe..e201dfc9 100644 --- a/sgkit/stats/popgen.py +++ b/sgkit/stats/popgen.py @@ -595,9 +595,7 @@ def pbs( cohorts = cohorts or list(itertools.combinations(range(n_cohorts), 3)) # type: ignore ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts_0", None)) - p = da.map_blocks( - lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64 - ) + p = da.map_blocks(_pbs_cohorts, t, ct, chunks=shape, new_axis=3, dtype=np.float64) assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts) new_ds = create_dataset(