From ba06d2a905e7694bc84860f5768a20bfe2d16b46 Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Tue, 10 Dec 2024 02:44:28 -0800 Subject: [PATCH] Backport PR #1746 on branch 0.11.x ((fix): raise error on non-integer floating types in iterables) (#1799) Co-authored-by: Ilan Gold --- docs/release-notes/1746.bugfix.md | 1 + src/anndata/_core/index.py | 6 ++++++ tests/test_views.py | 16 ++++++++++++++++ 3 files changed, 23 insertions(+) create mode 100644 docs/release-notes/1746.bugfix.md diff --git a/docs/release-notes/1746.bugfix.md b/docs/release-notes/1746.bugfix.md new file mode 100644 index 000000000..1923bcdd1 --- /dev/null +++ b/docs/release-notes/1746.bugfix.md @@ -0,0 +1 @@ +Error out on floating point indices that are not actually integers {user}`ilan-gold` diff --git a/src/anndata/_core/index.py b/src/anndata/_core/index.py index f1d72ce0d..53434186a 100644 --- a/src/anndata/_core/index.py +++ b/src/anndata/_core/index.py @@ -82,6 +82,12 @@ def name_idx(i): indexer = np.array(indexer) if len(indexer) == 0: indexer = indexer.astype(int) + if isinstance(indexer, np.ndarray) and np.issubdtype( + indexer.dtype, np.floating + ): + indexer_int = indexer.astype(int) + if np.all((indexer - indexer_int) != 0): + raise IndexError(f"Indexer {indexer!r} has floating point values.") if issubclass(indexer.dtype.type, np.integer | np.floating): return indexer # Might not work for range indexes elif issubclass(indexer.dtype.type, np.bool_): diff --git a/tests/test_views.py b/tests/test_views.py index 6e57e08c7..fb6794dfd 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -814,6 +814,22 @@ def test_index_3d_errors(index: tuple[int | EllipsisType, ...], expected_error: gen_adata((10, 10))[index] +@pytest.mark.parametrize( + "index", + [ + pytest.param(sparse.csr_matrix(np.random.random((1, 10))), id="sparse"), + pytest.param([1.2, 3.4], id="list"), + *( + pytest.param(np.array([1.2, 2.3], dtype=dtype), id=f"ndarray-{dtype}") + for dtype in [np.float32, np.float64] + ), + ], +) +def test_index_float_sequence_raises_error(index): + with pytest.raises(IndexError, match=r"has floating point values"): + gen_adata((10, 10))[index] + + # @pytest.mark.parametrize("dim", ["obs", "var"]) # @pytest.mark.parametrize( # ("idx", "pat"),