From f412fd37e3f0d6845364c0d7b86b71097ada1671 Mon Sep 17 00:00:00 2001
From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com>
Date: Tue, 10 Dec 2024 07:10:32 -0800
Subject: [PATCH] Backport PR #1780 on branch 0.11.x ((fix): use dask array for missing element in dask concatenation) (#1800)

Co-authored-by: Ilan Gold
---
 src/anndata/_core/merge.py | 36 ++++++++++++++++++++++++++++--------
 tests/test_concatenate.py  | 28 ++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 8 deletions(-)

diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py
index 0dfa5dab2..77672fdda 100644
--- a/src/anndata/_core/merge.py
+++ b/src/anndata/_core/merge.py
@@ -939,19 +939,37 @@ def gen_outer_reindexers(els, shapes, new_index: pd.Index, *, axis=0):
     return reindexers
 
 
+def missing_element(
+    n: int,
+    els: list[SpArray | sparse.csr_matrix | sparse.csc_matrix | np.ndarray | DaskArray],
+    axis: Literal[0, 1] = 0,
+    fill_value: Any | None = None,
+) -> np.ndarray | DaskArray:
+    """Generates value to use when there is a missing element."""
+    should_return_dask = any(isinstance(el, DaskArray) for el in els)
+    try:
+        non_missing_elem = next(el for el in els if not_missing(el))
+    except StopIteration:  # pragma: no cover
+        msg = "All elements are missing when attempting to generate missing elements."
+        raise ValueError(msg)
+    # 0 sized array for in-memory prevents allocating unnecessary memory while preserving broadcasting.
+    off_axis_size = 0 if not should_return_dask else non_missing_elem.shape[axis - 1]
+    shape = (n, off_axis_size) if axis == 0 else (off_axis_size, n)
+    if should_return_dask:
+        import dask.array as da
+
+        return da.full(
+            shape, default_fill_value(els) if fill_value is None else fill_value
+        )
+    return np.zeros(shape, dtype=bool)
+
+
 def outer_concat_aligned_mapping(
     mappings, *, reindexers=None, index=None, axis=0, fill_value=None
 ):
     result = {}
     ns = [m.parent.shape[axis] for m in mappings]
 
-    def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
-        """Generates value to use when there is a missing element."""
-        if axis == 0:
-            return np.zeros((n, 0), dtype=bool)
-        else:
-            return np.zeros((0, n), dtype=bool)
-
     for k in union_keys(mappings):
         els = [m.get(k, MissingVal) for m in mappings]
         if reindexers is None:
@@ -963,7 +981,9 @@ def missing_element(n: int, axis: Literal[0, 1] = 0) -> np.ndarray:
         # We should probably just handle missing elements for all types
         result[k] = concat_arrays(
             [
-                el if not_missing(el) else missing_element(n, axis=axis)
+                el
+                if not_missing(el)
+                else missing_element(n, axis=axis, els=els, fill_value=fill_value)
                 for el, n in zip(els, ns)
             ],
             cur_reindexers,
diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py
index d9f399dd6..2a2e16a5a 100644
--- a/tests/test_concatenate.py
+++ b/tests/test_concatenate.py
@@ -1533,6 +1533,34 @@ def test_concat_different_types_dask(merge_strategy, array_type):
     assert_equal(result2, target2)
 
 
+def test_concat_missing_elem_dask_join(join_type):
+    import dask.array as da
+
+    import anndata as ad
+
+    ad1 = ad.AnnData(X=np.ones((5, 5)))
+    ad2 = ad.AnnData(X=np.zeros((5, 5)), layers={"a": da.ones((5, 5))})
+    ad_in_memory_with_layers = ad2.to_memory()
+
+    result1 = ad.concat([ad1, ad2], join=join_type)
+    result2 = ad.concat([ad1, ad_in_memory_with_layers], join=join_type)
+    assert_equal(result1, result2)
+
+
+def test_impute_dask(axis_name):
+    import dask.array as da
+
+    from anndata._core.merge import _resolve_axis, missing_element
+
+    axis, _ = _resolve_axis(axis_name)
+    els = [da.ones((5, 5))]
+    missing = missing_element(6, els, axis=axis)
+    assert isinstance(missing, DaskArray)
+    in_memory = missing.compute()
+    assert np.all(np.isnan(in_memory))
+    assert in_memory.shape[axis] == 6
+
+
 def test_outer_concat_with_missing_value_for_df():
     # https://github.com/scverse/anndata/issues/901
     # TODO: Extend this test to cover all cases of missing values