From bfd8df75615a42c32dcfb7e4f57acef92f6c6981 Mon Sep 17 00:00:00 2001 From: Isaac Virshup Date: Tue, 10 Oct 2023 23:35:12 +0200 Subject: [PATCH] simplify --- anndata/_core/merge.py | 22 +++++++++---------- anndata/compat/__init__.py | 9 ++++++++ anndata/experimental/merge.py | 15 ++++--------- .../multi_files/_anncollection.py | 14 ++++-------- 4 files changed, 28 insertions(+), 32 deletions(-) diff --git a/anndata/_core/merge.py b/anndata/_core/merge.py index e29bd759b..2df90e290 100644 --- a/anndata/_core/merge.py +++ b/anndata/_core/merge.py @@ -22,13 +22,19 @@ import numpy as np import pandas as pd from natsort import natsorted -from packaging.version import parse as parse_version from scipy import sparse from scipy.sparse import spmatrix from anndata._warnings import ExperimentalFeatureWarning -from ..compat import AwkArray, CupyArray, CupyCSRMatrix, CupySparseMatrix, DaskArray +from ..compat import ( + AwkArray, + CupyArray, + CupyCSRMatrix, + CupySparseMatrix, + DaskArray, + _map_cat_to_str, +) from ..utils import asarray, dim_len from .anndata import AnnData from .index import _subset, make_slice @@ -1240,15 +1246,9 @@ def concat( [pd.Series(dim_indices(a, axis=axis)) for a in adatas], ignore_index=True ) if index_unique is not None: - if parse_version(pd.__version__) >= parse_version("2.0"): - # Argument added in pandas 2.0 - concat_indices = concat_indices.str.cat( - label_col.map(str, na_action="ignore"), sep=index_unique - ) - else: - concat_indices = concat_indices.str.cat( - label_col.map(str), sep=index_unique - ) + concat_indices = concat_indices.str.cat( + _map_cat_to_str(label_col), sep=index_unique + ) concat_indices = pd.Index(concat_indices) alt_indices = merge_indices( diff --git a/anndata/compat/__init__.py b/anndata/compat/__init__.py index d6ab5c2f3..10a41d037 100644 --- a/anndata/compat/__init__.py +++ b/anndata/compat/__init__.py @@ -14,6 +14,7 @@ import h5py import numpy as np import pandas as pd +from packaging.version import parse as _parse_version from scipy.sparse import issparse, spmatrix from .exceptiongroups import add_note # noqa: F401 @@ -391,3 +392,11 @@ def _safe_transpose(x): return _transpose_by_block(x) else: return x.T + + +def _map_cat_to_str(cat: pd.Categorical): + if _parse_version(pd.__version__) >= _parse_version("2.0"): + # Argument added in pandas 2.0 + return cat.map(str, na_action="ignore") + else: + return cat.map(str) diff --git a/anndata/experimental/merge.py b/anndata/experimental/merge.py index fef14a2ea..59c0623a8 100644 --- a/anndata/experimental/merge.py +++ b/anndata/experimental/merge.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd -from packaging.version import parse as parse_version from scipy.sparse import csc_matrix, csr_matrix from .._core.file_backing import to_memory @@ -33,7 +32,7 @@ ) from .._core.sparse_dataset import BaseCompressedSparseDataset, sparse_dataset from .._io.specs import read_elem, write_elem -from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup +from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup, _map_cat_to_str from . import read_dispatched SPARSE_MATRIX = {"csc_matrix", "csr_matrix"} @@ -597,15 +596,9 @@ def concat_on_disk( [pd.Series(_df_index(g[dim])) for g in groups], ignore_index=True ) if index_unique is not None: - if parse_version(pd.__version__) >= parse_version("2.0"): - # Argument added in pandas 2.0 - concat_indices = concat_indices.str.cat( - label_col.map(str, na_action="ignore"), sep=index_unique - ) - else: - concat_indices = concat_indices.str.cat( - label_col.map(str), sep=index_unique - ) + concat_indices = concat_indices.str.cat( + _map_cat_to_str(label_col), sep=index_unique + ) # Resulting indices for {dim} and {alt_dim} concat_indices = pd.Index(concat_indices) diff --git a/anndata/experimental/multi_files/_anncollection.py b/anndata/experimental/multi_files/_anncollection.py index b5fa42a7d..acacdc8d3 100644 --- a/anndata/experimental/multi_files/_anncollection.py +++ b/anndata/experimental/multi_files/_anncollection.py @@ -8,7 +8,6 @@ import numpy as np import pandas as pd from h5py import Dataset -from packaging.version import parse as parse_version from ..._core.aligned_mapping import AxisArrays from ..._core.anndata import AnnData @@ -16,6 +15,7 @@ from ..._core.merge import concat_arrays, inner_concat_aligned_mapping from ..._core.sparse_dataset import BaseCompressedSparseDataset from ..._core.views import _resolve_idx +from ...compat import _map_cat_to_str ATTRS = ["obs", "obsm", "layers"] @@ -721,15 +721,9 @@ def __init__( categories=keys, ) if index_unique is not None: - if parse_version(pd.__version__) >= parse_version("2.0"): - # Argument added in pandas 2.0 - concat_indices = concat_indices.str.cat( - label_col.map(str, na_action="ignore"), sep=index_unique - ) - else: - concat_indices = concat_indices.str.cat( - label_col.map(str), sep=index_unique - ) + concat_indices = concat_indices.str.cat( + _map_cat_to_str(label_col), sep=index_unique + ) self.obs_names = pd.Index(concat_indices) if not self.obs_names.is_unique: