diff --git a/docs/release-notes/0.11.0.md b/docs/release-notes/0.11.0.md index 84ec3851b..b9ed0babc 100644 --- a/docs/release-notes/0.11.0.md +++ b/docs/release-notes/0.11.0.md @@ -2,6 +2,7 @@ ```{rubric} Features ``` +* Allow `axis` parameter of e.g. {func}`anndata.concat` to accept `'obs'` and `'var'` {pr}`1244` {user}`flying-sheep` * Add `settings` object with methods for altering internally-used options, like checking for uniqueness on `obs`' index {pr}`1270` {user}`ilan-gold` * Add `remove_unused_categories` option to `anndata.settings` to override current behavior. Default is `True` (i.e., previous behavior). Please refer to the [documentation](https://anndata.readthedocs.io/en/latest/generated/anndata.settings.html) for usage. {pr}`1340` {user}`ilan-gold` * `scipy.sparse.csr_array` and `scipy.sparse.csc_array` are now supported when constructing `AnnData` objects {pr}`1028` {user}`ilan-gold` {user}`isaac-virshup` diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index f57dfb272..102547d1a 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -20,7 +20,7 @@ from anndata._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning from anndata.compat import AwkArray -from ..utils import deprecated, dim_len, ensure_df_homogeneous, warn_once +from ..utils import axis_len, deprecated, ensure_df_homogeneous, warn_once from .access import ElementRef from .index import _subset from .views import as_view, view_update @@ -69,10 +69,10 @@ def _validate_value(self, val: V, key: str) -> V: # stacklevel=3, ) for i, axis in enumerate(self.axes): - if self.parent.shape[axis] == dim_len(val, i): + if self.parent.shape[axis] == axis_len(val, i): continue right_shape = tuple(self.parent.shape[a] for a in self.axes) - actual_shape = tuple(dim_len(val, a) for a, _ in enumerate(self.axes)) + actual_shape = tuple(axis_len(val, a) for a, _ in enumerate(self.axes)) if actual_shape[i] is None 
and isinstance(val, AwkArray): dim = ("obs", "var")[i] msg = ( diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 90ab34b1e..4ad818e62 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -40,7 +40,7 @@ _move_adj_mtx, ) from ..logging import anndata_logger as logger -from ..utils import convert_to_dict, deprecated, dim_len, ensure_df_homogeneous +from ..utils import axis_len, convert_to_dict, deprecated, ensure_df_homogeneous from .access import ElementRef from .aligned_df import _gen_dataframe from .aligned_mapping import ( @@ -1843,7 +1843,7 @@ def _check_dimensions(self, key=None): if "obsm" in key: obsm = self._obsm if ( - not all([dim_len(o, 0) == self.n_obs for o in obsm.values()]) + not all([axis_len(o, 0) == self.n_obs for o in obsm.values()]) and len(obsm.dim_names) != self.n_obs ): raise ValueError( @@ -1853,7 +1853,7 @@ def _check_dimensions(self, key=None): if "varm" in key: varm = self._varm if ( - not all([dim_len(v, 0) == self.n_vars for v in varm.values()]) + not all([axis_len(v, 0) == self.n_vars for v in varm.values()]) and len(varm.dim_names) != self.n_vars ): raise ValueError( diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 69f85c898..909415165 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -38,7 +38,7 @@ SpArray, _map_cat_to_str, ) -from ..utils import asarray, dim_len, warn_once +from ..utils import asarray, axis_len, warn_once from .anndata import AnnData from .index import _subset, make_slice @@ -536,7 +536,7 @@ def apply(self, el, *, axis, fill_value=None): Missing values are to be replaced with `fill_value`. 
""" - if self.no_change and (dim_len(el, axis) == len(self.old_idx)): + if self.no_change and (axis_len(el, axis) == len(self.old_idx)): return el if isinstance(el, pd.DataFrame): return self._apply_to_df(el, axis=axis, fill_value=fill_value) @@ -1017,32 +1017,20 @@ def merge_outer(mappings, batch_keys, *, join_index="-", merge=merge_unique): return out -def _resolve_dim(*, dim: str = None, axis: int = None) -> tuple[int, str]: - _dims = ("obs", "var") - if (dim is None and axis is None) or (dim is not None and axis is not None): - raise ValueError( - f"Must pass exactly one of `dim` or `axis`. Got: dim={dim}, axis={axis}." - ) - elif dim is not None and dim not in _dims: - raise ValueError(f"`dim` must be one of ('obs', 'var'), was {dim}") - elif axis is not None and axis not in (0, 1): - raise ValueError(f"`axis` must be either 0 or 1, was {axis}") - if dim is not None: - return _dims.index(dim), dim - else: - return axis, _dims[axis] +def _resolve_axis( + axis: Literal["obs", 0, "var", 1], +) -> tuple[Literal[0], Literal["obs"]] | tuple[Literal[1], Literal["var"]]: + if axis in {0, "obs"}: + return (0, "obs") + if axis in {1, "var"}: + return (1, "var") + raise ValueError(f"`axis` must be either 0, 1, 'obs', or 'var', was {axis}") -def dim_indices(adata, *, axis=None, dim=None) -> pd.Index: +def axis_indices(adata: AnnData, axis: Literal["obs", 0, "var", 1]) -> pd.Index: """Helper function to get adata.{dim}_names.""" - _, dim = _resolve_dim(axis=axis, dim=dim) - return getattr(adata, f"{dim}_names") - - -def dim_size(adata, *, axis=None, dim=None) -> int: - """Helper function to get adata.shape[dim].""" - ax, _ = _resolve_dim(axis, dim) - return adata.shape[ax] + _, axis_name = _resolve_axis(axis) + return getattr(adata, f"{axis_name}_names") # TODO: Resolve https://github.com/scverse/anndata/issues/678 and remove this function @@ -1071,7 +1059,7 @@ def concat_Xs(adatas, reindexers, axis, fill_value): def concat( adatas: Collection[AnnData] | 
typing.Mapping[str, AnnData], *, - axis: Literal[0, 1] = 0, + axis: Literal["obs", 0, "var", 1] = "obs", join: Literal["inner", "outer"] = "inner", merge: StrategiesLiteral | Callable | None = None, uns_merge: StrategiesLiteral | Callable | None = None, @@ -1178,7 +1166,7 @@ def concat( s2 2 3 s3 4 5 s4 7 8 - >>> ad.concat([a, c], axis=1).to_df() + >>> ad.concat([a, c], axis="var").to_df() var1 var2 var3 var4 s1 0 1 10 11 s2 2 3 12 13 @@ -1205,6 +1193,19 @@ def concat( s3 4 5 6 s4 7 8 9 + Using the axis’ index instead of its name + + >>> ad.concat([a, b], axis=0).to_df() # Equivalent to axis="obs" + var1 var2 + s1 0 1 + s2 2 3 + s3 4 5 + s4 7 8 + >>> ad.concat([a, c], axis=1).to_df() # Equivalent to axis="var" + var1 var2 var3 var4 + s1 0 1 10 11 + s2 2 3 12 13 + Keeping track of source objects >>> ad.concat({"a": a, "b": b}, label="batch").obs @@ -1273,8 +1274,8 @@ def concat( if keys is None: keys = np.arange(len(adatas)).astype(str) - axis, dim = _resolve_dim(axis=axis) - alt_axis, alt_dim = _resolve_dim(axis=1 - axis) + axis, axis_name = _resolve_axis(axis) + alt_axis, alt_axis_name = _resolve_axis(axis=1 - axis) # Label column label_col = pd.Categorical.from_codes( @@ -1284,7 +1285,7 @@ def concat( # Combining indexes concat_indices = pd.concat( - [pd.Series(dim_indices(a, axis=axis)) for a in adatas], ignore_index=True + [pd.Series(axis_indices(a, axis=axis)) for a in adatas], ignore_index=True ) if index_unique is not None: concat_indices = concat_indices.str.cat( @@ -1293,16 +1294,16 @@ def concat( concat_indices = pd.Index(concat_indices) alt_indices = merge_indices( - [dim_indices(a, axis=alt_axis) for a in adatas], join=join + [axis_indices(a, axis=alt_axis) for a in adatas], join=join ) reindexers = [ - gen_reindexer(alt_indices, dim_indices(a, axis=alt_axis)) for a in adatas + gen_reindexer(alt_indices, axis_indices(a, axis=alt_axis)) for a in adatas ] # Annotation for concatenation axis - check_combinable_cols([getattr(a, dim).columns for a in 
adatas], join=join) + check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join) concat_annot = pd.concat( - unify_dtypes(getattr(a, dim) for a in adatas), + unify_dtypes(getattr(a, axis_name) for a in adatas), join=join, ignore_index=True, ) @@ -1312,7 +1313,7 @@ def concat( # Annotation for other axis alt_annot = merge_dataframes( - [getattr(a, alt_dim) for a in adatas], alt_indices, merge + [getattr(a, alt_axis_name) for a in adatas], alt_indices, merge ) X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value) @@ -1332,11 +1333,11 @@ def concat( [a.layers for a in adatas], axis=axis, reindexers=reindexers ) concat_mapping = concat_aligned_mapping( - [getattr(a, f"{dim}m") for a in adatas], index=concat_indices + [getattr(a, f"{axis_name}m") for a in adatas], index=concat_indices ) if pairwise: concat_pairwise = concat_pairwise_mapping( - mappings=[getattr(a, f"{dim}p") for a in adatas], + mappings=[getattr(a, f"{axis_name}p") for a in adatas], shapes=[a.shape[axis] for a in adatas], join_keys=join_keys, ) @@ -1346,13 +1347,16 @@ def concat( # TODO: Reindex lazily, so we don't have to make those copies until we're sure we need the element alt_mapping = merge( [ - {k: r(v, axis=0) for k, v in getattr(a, f"{alt_dim}m").items()} + {k: r(v, axis=0) for k, v in getattr(a, f"{alt_axis_name}m").items()} for r, a in zip(reindexers, adatas) ], ) alt_pairwise = merge( [ - {k: r(r(v, axis=0), axis=1) for k, v in getattr(a, f"{alt_dim}p").items()} + { + k: r(r(v, axis=0), axis=1) + for k, v in getattr(a, f"{alt_axis_name}p").items() + } for r, a in zip(reindexers, adatas) ] ) @@ -1388,12 +1392,12 @@ def concat( **{ "X": X, "layers": layers, - dim: concat_annot, - alt_dim: alt_annot, - f"{dim}m": concat_mapping, - f"{alt_dim}m": alt_mapping, - f"{dim}p": concat_pairwise, - f"{alt_dim}p": alt_pairwise, + axis_name: concat_annot, + alt_axis_name: alt_annot, + f"{axis_name}m": concat_mapping, + f"{alt_axis_name}m": alt_mapping, + 
f"{axis_name}p": concat_pairwise, + f"{alt_axis_name}p": alt_pairwise, "uns": uns, "raw": raw, } diff --git a/src/anndata/experimental/merge.py b/src/anndata/experimental/merge.py index 2e0e925a2..aa6f47e9b 100644 --- a/src/anndata/experimental/merge.py +++ b/src/anndata/experimental/merge.py @@ -16,7 +16,7 @@ MissingVal, Reindexer, StrategiesLiteral, - _resolve_dim, + _resolve_axis, concat_arrays, gen_inner_reindexers, gen_reindexer, @@ -367,34 +367,36 @@ def _write_concat_sequence( ) -def _write_alt_mapping(groups, output_group, alt_dim, alt_indices, merge): - alt_mapping = merge([read_as_backed(g[alt_dim]) for g in groups]) +def _write_alt_mapping(groups, output_group, alt_axis_name, alt_indices, merge): + alt_mapping = merge([read_as_backed(g[alt_axis_name]) for g in groups]) # If its empty, we need to write an empty dataframe with the correct index if not alt_mapping: alt_df = pd.DataFrame(index=alt_indices) - write_elem(output_group, alt_dim, alt_df) + write_elem(output_group, alt_axis_name, alt_df) else: - write_elem(output_group, alt_dim, alt_mapping) + write_elem(output_group, alt_axis_name, alt_mapping) -def _write_alt_annot(groups, output_group, alt_dim, alt_indices, merge): +def _write_alt_annot(groups, output_group, alt_axis_name, alt_indices, merge): # Annotation for other axis alt_annot = merge_dataframes( - [read_elem(g[alt_dim]) for g in groups], alt_indices, merge + [read_elem(g[alt_axis_name]) for g in groups], alt_indices, merge ) - write_elem(output_group, alt_dim, alt_annot) + write_elem(output_group, alt_axis_name, alt_annot) -def _write_dim_annot(groups, output_group, dim, concat_indices, label, label_col, join): +def _write_axis_annot( + groups, output_group, axis_name, concat_indices, label, label_col, join +): concat_annot = pd.concat( - unify_dtypes(read_elem(g[dim]) for g in groups), + unify_dtypes(read_elem(g[axis_name]) for g in groups), join=join, ignore_index=True, ) concat_annot.index = concat_indices if label is not None: 
concat_annot[label] = label_col - write_elem(output_group, dim, concat_annot) + write_elem(output_group, axis_name, concat_annot) def concat_on_disk( @@ -402,7 +404,7 @@ def concat_on_disk( out_file: str | os.PathLike, *, max_loaded_elems: int = 100_000_000, - axis: Literal[0, 1] = 0, + axis: Literal["obs", 0, "var", 1] = 0, join: Literal["inner", "outer"] = "inner", merge: StrategiesLiteral | Callable[[Collection[Mapping]], Mapping] | None = None, uns_merge: ( @@ -563,17 +565,17 @@ def concat_on_disk( if keys is None: keys = np.arange(len(in_files)).astype(str) - _, dim = _resolve_dim(axis=axis) - _, alt_dim = _resolve_dim(axis=1 - axis) + axis, axis_name = _resolve_axis(axis) + _, alt_axis_name = _resolve_axis(1 - axis) output_group = as_group(out_file, mode="w") groups = [as_group(f) for f in in_files] use_reindexing = False - alt_dims = [_df_index(g[alt_dim]) for g in groups] - # All dim_names must be equal if reindexing not applied - if not _indices_equal(alt_dims): + alt_idxs = [_df_index(g[alt_axis_name]) for g in groups] + # All {axis_name}_names must be equal if reindexing not applied + if not _indices_equal(alt_idxs): use_reindexing = True # All groups must be anndata @@ -594,34 +596,36 @@ def concat_on_disk( # Combining indexes concat_indices = pd.concat( - [pd.Series(_df_index(g[dim])) for g in groups], ignore_index=True + [pd.Series(_df_index(g[axis_name])) for g in groups], ignore_index=True ) if index_unique is not None: concat_indices = concat_indices.str.cat( _map_cat_to_str(label_col), sep=index_unique ) - # Resulting indices for {dim} and {alt_dim} + # Resulting indices for {axis_name} and {alt_axis_name} concat_indices = pd.Index(concat_indices) - alt_indices = merge_indices(alt_dims, join=join) + alt_index = merge_indices(alt_idxs, join=join) reindexers = None if use_reindexing: reindexers = [ - gen_reindexer(alt_indices, alt_old_index) for alt_old_index in alt_dims + gen_reindexer(alt_index, alt_old_index) for alt_old_index in alt_idxs ] else: 
reindexers = [IdentityReindexer()] * len(groups) - # Write {dim} - _write_dim_annot(groups, output_group, dim, concat_indices, label, label_col, join) + # Write {axis_name} + _write_axis_annot( + groups, output_group, axis_name, concat_indices, label, label_col, join + ) - # Write {alt_dim} - _write_alt_annot(groups, output_group, alt_dim, alt_indices, merge) + # Write {alt_axis_name} + _write_alt_annot(groups, output_group, alt_axis_name, alt_index, merge) - # Write {alt_dim}m - _write_alt_mapping(groups, output_group, alt_dim, alt_indices, merge) + # Write {alt_axis_name}m + _write_alt_mapping(groups, output_group, alt_axis_name, alt_index, merge) # Write X @@ -635,10 +639,10 @@ def concat_on_disk( max_loaded_elems=max_loaded_elems, ) - # Write Layers and {dim}m + # Write Layers and {axis_name}m mapping_names = [ ( - f"{dim}m", + f"{axis_name}m", concat_indices, 0, None if use_reindexing else [IdentityReindexer()] * len(groups), diff --git a/src/anndata/utils.py b/src/anndata/utils.py index a26affd6d..266b8e3be 100644 --- a/src/anndata/utils.py +++ b/src/anndata/utils.py @@ -3,7 +3,7 @@ import re import warnings from functools import singledispatch, wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import h5py import numpy as np @@ -101,7 +101,7 @@ def convert_to_dict_nonetype(obj: None): @singledispatch -def dim_len(x, axis): +def axis_len(x, axis: Literal[0, 1]) -> int | None: """\ Return the size of an array in dimension `axis`. @@ -175,11 +175,11 @@ def _size_at_depth(layout, depth, lateral_context, **kwargs): lateral_context["out"] = result return ak.contents.EmptyArray() - @dim_len.register(ak.Array) - def dim_len_awkward(array, axis): - """Get the length of an awkward array in a given dimension + @axis_len.register(ak.Array) + def axis_len_awkward(array, axis: Literal[0, 1]) -> int | None: + """Get the length of an awkward array in a given axis - Returns None if the dimension is of variable length. 
+ Returns None if the axis is of variable length. Code adapted from @jpivarski's solution in https://github.com/scikit-hep/awkward/discussions/1654#discussioncomment-3521574 """ diff --git a/tests/test_awkward.py b/tests/test_awkward.py index ac525ce95..9780de4e2 100644 --- a/tests/test_awkward.py +++ b/tests/test_awkward.py @@ -17,7 +17,7 @@ ) from anndata.compat import awkward as ak from anndata.tests.helpers import assert_equal, gen_adata, gen_awkward -from anndata.utils import dim_len +from anndata.utils import axis_len @pytest.mark.parametrize( @@ -66,14 +66,14 @@ [ak.to_regular(ak.Array([["a", "b"], ["c", "d"], ["e", "f"]]), 1), (3, 2)], ], ) -def test_dim_len(array, shape): - """Test that dim_len returns the right value for awkward arrays.""" +def test_axis_len(array, shape): + """Test that axis_len returns the right value for awkward arrays.""" for axis, size in enumerate(shape): - assert size == dim_len(array, axis) + assert size == axis_len(array, axis) # Requesting the size for an axis higher than the array has dimensions should raise a TypeError with pytest.raises(TypeError): - dim_len(array, len(shape)) + axis_len(array, len(shape)) @pytest.mark.parametrize( @@ -361,7 +361,7 @@ def test_concat_mixed_types(key, arrays, expected, join): for a in arrays: shape = np.array([3, 3]) # default shape (in case of missing array) if a is not None: - length = dim_len(a, 0) + length = axis_len(a, 0) shape[axis] = length tmp_adata = gen_adata( diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 28af38ad3..6a0cc95bd 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -6,6 +6,7 @@ from copy import deepcopy from functools import partial, singledispatch from itertools import chain, permutations, product +from operator import attrgetter from typing import Any, Callable, Literal import numpy as np @@ -114,8 +115,8 @@ def fill_val(request): return request.param -@pytest.fixture(params=[0, 1]) -def axis(request) -> Literal[0, 1]: 
+@pytest.fixture(params=["obs", "var"]) +def axis_name(request) -> Literal["obs", "var"]: return request.param @@ -431,6 +432,15 @@ def test_concatenate_obsm_outer(obsm_adatas, fill_val): pd.testing.assert_frame_equal(true_df, cur_df) +@pytest.mark.parametrize( + ("axis", "axis_name"), + [("obs", 0), ("var", 1)], +) +def test_concat_axis_param(axis, axis_name): + a, b = gen_adata((10, 10)), gen_adata((10, 10)) + assert_equal(concat([a, b], axis=axis), concat([a, b], axis=axis_name)) + + def test_concat_annot_join(obsm_adatas, join_type): adatas = [ AnnData(sparse.csr_matrix(a.shape), obs=a.obsm["df"], var=a.var) @@ -824,26 +834,24 @@ def test_awkward_does_not_mix(join_type, other): concat([adata_a, adata_b], join=join_type) -def test_pairwise_concat(axis, array_type): - dim_sizes = [[100, 200, 50], [50, 50, 50]] - if axis: - dim_sizes.reverse() - Ms, Ns = dim_sizes - dim = ("obs", "var")[axis] - alt = ("var", "obs")[axis] - dim_attr = f"{dim}p" - alt_attr = f"{alt}p" +def test_pairwise_concat(axis_name, array_type): + axis, axis_name = merge._resolve_axis(axis_name) + _, alt_axis_name = merge._resolve_axis(1 - axis) + axis_sizes = [[100, 200, 50], [50, 50, 50]] + if axis_name == "var": + axis_sizes.reverse() + Ms, Ns = axis_sizes + axis_attr = f"{axis_name}p" + alt_attr = f"{alt_axis_name}p" - def gen_dim_array(m): + def gen_axis_array(m): return array_type(sparse.random(m, m, format="csr", density=0.1)) adatas = { k: AnnData( - **{ - "X": sparse.csr_matrix((m, n)), - "obsp": {"arr": gen_dim_array(m)}, - "varp": {"arr": gen_dim_array(n)}, - } + X=sparse.csr_matrix((m, n)), + obsp={"arr": gen_axis_array(m)}, + varp={"arr": gen_axis_array(n)}, ) for k, m, n in zip("abc", Ms, Ns) } @@ -852,16 +860,16 @@ def gen_dim_array(m): wo_pairwise = concat(adatas, axis=axis, label="orig", pairwise=False) # Check that argument controls whether elements are included - assert getattr(wo_pairwise, dim_attr) == {} - assert getattr(w_pairwise, dim_attr) != {} + assert 
getattr(wo_pairwise, axis_attr) == {} + assert getattr(w_pairwise, axis_attr) != {} # Check values of included elements full_inds = np.arange(w_pairwise.shape[axis]) - obs_var: pd.DataFrame = getattr(w_pairwise, dim) + obs_var: pd.DataFrame = getattr(w_pairwise, axis_name) groups = obs_var.groupby("orig", observed=True).indices for k, inds in groups.items(): - orig_arr = getattr(adatas[k], dim_attr)["arr"] - full_arr = getattr(w_pairwise, dim_attr)["arr"] + orig_arr = getattr(adatas[k], axis_attr)["arr"] + full_arr = getattr(w_pairwise, axis_attr)["arr"] if isinstance(full_arr, DaskArray): full_arr = full_arr.compute() @@ -884,14 +892,14 @@ def gen_dim_array(m): ) -def test_nan_merge(axis, join_type, array_type): - # concat_dim = ("obs", "var")[axis] - alt_dim = ("var", "obs")[axis] - mapping_attr = f"{alt_dim}m" +def test_nan_merge(axis_name, join_type, array_type): + axis, _ = merge._resolve_axis(axis_name) + alt_axis, alt_axis_name = merge._resolve_axis(1 - axis) + mapping_attr = f"{alt_axis_name}m" adata_shape = (20, 10) arr = array_type( - sparse.random(adata_shape[1 - axis], 10, density=0.1, format="csr") + sparse.random(adata_shape[alt_axis], 10, density=0.1, format="csr") ) arr_nan = arr.copy() with warnings.catch_warnings(): @@ -904,7 +912,7 @@ def test_nan_merge(axis, join_type, array_type): _data = {"X": sparse.csr_matrix(adata_shape), mapping_attr: {"arr": arr_nan}} orig1 = AnnData(**_data) orig2 = AnnData(**_data) - result = concat([orig1, orig2], axis=axis, merge="same") + result = concat([orig1, orig2], axis=axis, join=join_type, merge="same") assert_equal(getattr(orig1, mapping_attr), getattr(result, mapping_attr)) @@ -1149,29 +1157,28 @@ def test_concatenate_uns(unss, merge_strategy, result, value_gen): assert_equal(merged, result, elem_name="uns") -def test_transposed_concat(array_type, axis, join_type, merge_strategy, fill_val): +def test_transposed_concat(array_type, axis_name, join_type, merge_strategy): + axis, axis_name = 
merge._resolve_axis(axis_name) + alt_axis = 1 - axis lhs = gen_adata((10, 10), X_type=array_type, **GEN_ADATA_DASK_ARGS) rhs = gen_adata((10, 12), X_type=array_type, **GEN_ADATA_DASK_ARGS) a = concat([lhs, rhs], axis=axis, join=join_type, merge=merge_strategy) - b = concat( - [lhs.T, rhs.T], axis=abs(axis - 1), join=join_type, merge=merge_strategy - ).T + b = concat([lhs.T, rhs.T], axis=alt_axis, join=join_type, merge=merge_strategy).T assert_equal(a, b) -def test_batch_key(axis): +def test_batch_key(axis_name): """Test that concat only adds a label if the key is provided""" - def get_annot(adata): - return getattr(adata, ("obs", "var")[axis]) + get_annot = attrgetter(axis_name) lhs = gen_adata((10, 10), **GEN_ADATA_DASK_ARGS) rhs = gen_adata((10, 12), **GEN_ADATA_DASK_ARGS) # There is probably a prettier way to do this - annot = get_annot(concat([lhs, rhs], axis=axis)) + annot = get_annot(concat([lhs, rhs], axis=axis_name)) assert ( list( annot.columns.difference( @@ -1181,7 +1188,7 @@ def get_annot(adata): == [] ) - batch_annot = get_annot(concat([lhs, rhs], axis=axis, label="batch")) + batch_annot = get_annot(concat([lhs, rhs], axis=axis_name, label="batch")) assert list( batch_annot.columns.difference( get_annot(lhs).columns.union(get_annot(rhs).columns) @@ -1328,30 +1335,34 @@ def test_bool_promotion(): assert result.obs["bool"].dtype == np.dtype(bool) -def test_concat_names(axis): - def get_annot(adata): - return getattr(adata, ("obs", "var")[axis]) +def test_concat_names(axis_name): + get_annot = attrgetter(axis_name) lhs = gen_adata((10, 10)) rhs = gen_adata((10, 10)) - assert not get_annot(concat([lhs, rhs], axis=axis)).index.is_unique - assert get_annot(concat([lhs, rhs], axis=axis, index_unique="-")).index.is_unique + assert not get_annot(concat([lhs, rhs], axis=axis_name)).index.is_unique + assert get_annot( + concat([lhs, rhs], axis=axis_name, index_unique="-") + ).index.is_unique -def axis_labels(adata, axis): +def axis_labels(adata: AnnData, axis: 
Literal[0, 1]) -> pd.Index: return (adata.obs_names, adata.var_names)[axis] -def expected_shape(a, b, axis, join): - labels = partial(axis_labels, axis=abs(axis - 1)) +def expected_shape( + a: AnnData, b: AnnData, axis: Literal[0, 1], join: Literal["inner", "outer"] +) -> tuple[int, int]: + alt_axis = 1 - axis + labels = partial(axis_labels, axis=alt_axis) shape = [None, None] shape[axis] = a.shape[axis] + b.shape[axis] if join == "inner": - shape[abs(axis - 1)] = len(labels(a).intersection(labels(b))) + shape[alt_axis] = len(labels(a).intersection(labels(b))) elif join == "outer": - shape[abs(axis - 1)] = len(labels(a).union(labels(b))) + shape[alt_axis] = len(labels(a).union(labels(b))) else: raise ValueError() @@ -1361,12 +1372,12 @@ def expected_shape(a, b, axis, join): @pytest.mark.parametrize( "shape", [pytest.param((8, 0), id="no_var"), pytest.param((0, 10), id="no_obs")] ) -def test_concat_size_0_dim(axis, join_type, merge_strategy, shape): - # https://github.com/scverse/anndata/issues/526 +def test_concat_size_0_axis(axis_name, join_type, merge_strategy, shape): + """Regression test for https://github.com/scverse/anndata/issues/526""" + axis, axis_name = merge._resolve_axis(axis_name) + alt_axis = 1 - axis a = gen_adata((5, 7)) b = gen_adata(shape) - alt_axis = 1 - axis - dim = ("obs", "var")[axis] expected_size = expected_shape(a, b, axis=axis, join=join_type) @@ -1404,8 +1415,8 @@ def test_concat_size_0_dim(axis, join_type, merge_strategy, shape): if shape[axis] > 0: b_result = result[axis_idx].copy() - mapping_elem = f"{dim}m" - setattr(b_result, f"{dim}_names", getattr(b, f"{dim}_names")) + mapping_elem = f"{axis_name}m" + setattr(b_result, f"{axis_name}_names", getattr(b, f"{axis_name}_names")) for k, result_elem in getattr(b_result, mapping_elem).items(): elem_name = f"{mapping_elem}/{k}" # pd.concat can have unintuitive return types. 
is similar to numpy promotion @@ -1436,7 +1447,7 @@ def test_concat_outer_aligned_mapping(elem): @mark_legacy_concatenate -def test_concatenate_size_0_dim(): +def test_concatenate_size_0_axis(): # https://github.com/scverse/anndata/issues/526 a = gen_adata((5, 10)) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index fc39df83d..a530ea5a1 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -16,7 +16,7 @@ gen_awkward, report_name, ) -from anndata.utils import dim_len +from anndata.utils import axis_len # Testing to see if all error types can have the key name appended. # Currently fails for 22/118 since they have required arguments. Not sure what to do about that. @@ -72,7 +72,7 @@ def test_gen_awkward(shape, datashape): arr = gen_awkward(shape) for i, s in enumerate(shape): - assert dim_len(arr, i) == s + assert axis_len(arr, i) == s arr_type = ak.types.from_datashape(datashape) assert arr.type == arr_type