From 0fa245d8bafef5e81ad540c5304eb1a9bc325a47 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Mon, 15 Jan 2024 14:26:52 +0100 Subject: [PATCH] Fix uns merge 3d (#1302) * add test * Better tests * Allow high dim objects to be compared during merge * release note * Remove redundant asarray * Simplify --- anndata/_core/merge.py | 8 +++++++- anndata/tests/helpers.py | 6 +++++- anndata/tests/test_concatenate.py | 12 ++++++++++++ docs/release-notes/0.10.5.md | 1 + 4 files changed, 25 insertions(+), 2 deletions(-) diff --git a/anndata/_core/merge.py b/anndata/_core/merge.py index 4429a7f15..48f36be9d 100644 --- a/anndata/_core/merge.py +++ b/anndata/_core/merge.py @@ -141,7 +141,13 @@ def equal_dask_array(a, b) -> bool: @equal.register(np.ndarray) def equal_array(a, b) -> bool: - return equal(pd.DataFrame(a), pd.DataFrame(asarray(b))) + # Reshaping allows us to compare inputs with >2 dimensions + # We cast to pandas since it will still work with non-numeric types + b = asarray(b) + if a.shape != b.shape: + return False + + return equal(pd.DataFrame(a.reshape(-1)), pd.DataFrame(b.reshape(-1))) @equal.register(CupyArray) diff --git a/anndata/tests/helpers.py b/anndata/tests/helpers.py index 428b3c21b..a62d890d0 100644 --- a/anndata/tests/helpers.py +++ b/anndata/tests/helpers.py @@ -413,7 +413,11 @@ def assert_equal_ndarray(a, b, exact=False, elem_name=None): and len(a.dtype) > 1 and len(b.dtype) > 0 ): - assert_equal(pd.DataFrame(a), pd.DataFrame(b), exact, elem_name) + # Reshaping to allow >2d arrays + assert a.shape == b.shape, format_msg(elem_name) + assert_equal( + pd.DataFrame(a.reshape(-1)), pd.DataFrame(b.reshape(-1)), exact, elem_name + ) else: assert np.all(a == b), format_msg(elem_name) diff --git a/anndata/tests/test_concatenate.py b/anndata/tests/test_concatenate.py index 4a51c5230..c0d670cd0 100644 --- a/anndata/tests/test_concatenate.py +++ b/anndata/tests/test_concatenate.py @@ -28,6 +28,7 @@ as_dense_dask_array, assert_equal, gen_adata, + gen_vstr_recarray, ) from anndata.utils import asarray @@ -1018,6 +1019,15 @@ def gen_something(n): return np.random.choice(options)(n) +def gen_3d_numeric_array(n): + return np.random.randn(n, n, n) + + +def gen_3d_recarray(_): + # Ignoring n as it can get quite slow + return gen_vstr_recarray(8, 3).reshape(2, 2, 2) + + def gen_concat_params(unss, compat2result): value_generators = [ lambda x: x, @@ -1026,6 +1036,8 @@ def gen_concat_params(unss, compat2result): gen_list, gen_sparse, gen_something, + gen_3d_numeric_array, + gen_3d_recarray, ] for gen, (mode, result) in product(value_generators, compat2result.items()): yield pytest.param(unss, mode, result, gen) diff --git a/docs/release-notes/0.10.5.md b/docs/release-notes/0.10.5.md index 5310ad681..130c2d0cf 100644 --- a/docs/release-notes/0.10.5.md +++ b/docs/release-notes/0.10.5.md @@ -4,6 +4,7 @@ ``` * Fix outer concatenation along variables when only a subset of objects had an entry in layers {pr}`1291` {user}`ivirshup` +* Fix comparison of >2d arrays in `uns` during concatenation {pr}`1300` {user}`ivirshup` ```{rubric} Documentation ```