Skip to content

Commit

Permalink
Fix uns merge 3d (#1302)
Browse files Browse the repository at this point in the history
* add test

* Better tests

* Allow high dim objects to be compared during merge

* release note

* Remove redundant asarray

* Simplify
  • Loading branch information
ivirshup authored Jan 15, 2024
1 parent 86bcc36 commit 0fa245d
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 2 deletions.
8 changes: 7 additions & 1 deletion anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,13 @@ def equal_dask_array(a, b) -> bool:

@equal.register(np.ndarray)
def equal_array(a, b) -> bool:
return equal(pd.DataFrame(a), pd.DataFrame(asarray(b)))
# Reshaping allows us to compare inputs with >2 dimensions
# We cast to pandas since it will still work with non-numeric types
b = asarray(b)
if a.shape != b.shape:
return False

return equal(pd.DataFrame(a.reshape(-1)), pd.DataFrame(b.reshape(-1)))


@equal.register(CupyArray)
Expand Down
6 changes: 5 additions & 1 deletion anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,11 @@ def assert_equal_ndarray(a, b, exact=False, elem_name=None):
and len(a.dtype) > 1
and len(b.dtype) > 0
):
assert_equal(pd.DataFrame(a), pd.DataFrame(b), exact, elem_name)
# Reshaping to allow >2d arrays
assert a.shape == b.shape, format_msg(elem_name)
assert_equal(
pd.DataFrame(a.reshape(-1)), pd.DataFrame(b.reshape(-1)), exact, elem_name
)
else:
assert np.all(a == b), format_msg(elem_name)

Expand Down
12 changes: 12 additions & 0 deletions anndata/tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
as_dense_dask_array,
assert_equal,
gen_adata,
gen_vstr_recarray,
)
from anndata.utils import asarray

Expand Down Expand Up @@ -1018,6 +1019,15 @@ def gen_something(n):
return np.random.choice(options)(n)


def gen_3d_numeric_array(n):
return np.random.randn(n, n, n)


def gen_3d_recarray(_):
# Ignoring n as it can get quite slow
return gen_vstr_recarray(8, 3).reshape(2, 2, 2)


def gen_concat_params(unss, compat2result):
value_generators = [
lambda x: x,
Expand All @@ -1026,6 +1036,8 @@ def gen_concat_params(unss, compat2result):
gen_list,
gen_sparse,
gen_something,
gen_3d_numeric_array,
gen_3d_recarray,
]
for gen, (mode, result) in product(value_generators, compat2result.items()):
yield pytest.param(unss, mode, result, gen)
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/0.10.5.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
```

* Fix outer concatenation along variables when only a subset of objects had an entry in layers {pr}`1291` {user}`ivirshup`
* Fix comparison of >2d arrays in `uns` during concatenation {pr}`1300` {user}`ivirshup`

```{rubric} Documentation
```
Expand Down

0 comments on commit 0fa245d

Please sign in to comment.