Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup committed Oct 10, 2023
1 parent 799cfc3 commit bfd8df7
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 32 deletions.
22 changes: 11 additions & 11 deletions anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,19 @@
import numpy as np
import pandas as pd
from natsort import natsorted
from packaging.version import parse as parse_version
from scipy import sparse
from scipy.sparse import spmatrix

from anndata._warnings import ExperimentalFeatureWarning

from ..compat import AwkArray, CupyArray, CupyCSRMatrix, CupySparseMatrix, DaskArray
from ..compat import (
AwkArray,
CupyArray,
CupyCSRMatrix,
CupySparseMatrix,
DaskArray,
_map_cat_to_str,
)
from ..utils import asarray, dim_len
from .anndata import AnnData
from .index import _subset, make_slice
Expand Down Expand Up @@ -1240,15 +1246,9 @@ def concat(
[pd.Series(dim_indices(a, axis=axis)) for a in adatas], ignore_index=True
)
if index_unique is not None:
if parse_version(pd.__version__) >= parse_version("2.0"):
# Argument added in pandas 2.0
concat_indices = concat_indices.str.cat(
label_col.map(str, na_action="ignore"), sep=index_unique
)
else:
concat_indices = concat_indices.str.cat(
label_col.map(str), sep=index_unique
)
concat_indices = concat_indices.str.cat(
_map_cat_to_str(label_col), sep=index_unique
)
concat_indices = pd.Index(concat_indices)

alt_indices = merge_indices(
Expand Down
9 changes: 9 additions & 0 deletions anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import h5py
import numpy as np
import pandas as pd
from packaging.version import parse as _parse_version
from scipy.sparse import issparse, spmatrix

from .exceptiongroups import add_note # noqa: F401
Expand Down Expand Up @@ -391,3 +392,11 @@ def _safe_transpose(x):
return _transpose_by_block(x)
else:
return x.T


def _map_cat_to_str(cat: pd.Categorical):
    """Map the values of a categorical to ``str``, leaving missing values alone.

    Compatibility shim around ``Categorical.map``: when the installed pandas
    accepts the ``na_action`` keyword on ``Categorical.map`` it is passed
    explicitly as ``"ignore"``; otherwise ``map`` is called without it.

    Parameters
    ----------
    cat
        The categorical whose values should be converted with ``str``.

    Returns
    -------
    Whatever ``cat.map`` returns for the ``str`` mapper.
    """
    import inspect

    # Feature-detect the keyword rather than comparing version strings:
    # `Categorical.map` only gained `na_action` in pandas 2.1, so a
    # ">= 2.0" gate raises TypeError on pandas 2.0.x.
    map_params = inspect.signature(pd.Categorical.map).parameters
    if "na_action" in map_params:
        return cat.map(str, na_action="ignore")
    return cat.map(str)
15 changes: 4 additions & 11 deletions anndata/experimental/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import numpy as np
import pandas as pd
from packaging.version import parse as parse_version
from scipy.sparse import csc_matrix, csr_matrix

from .._core.file_backing import to_memory
Expand All @@ -33,7 +32,7 @@
)
from .._core.sparse_dataset import BaseCompressedSparseDataset, sparse_dataset
from .._io.specs import read_elem, write_elem
from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup
from ..compat import H5Array, H5Group, ZarrArray, ZarrGroup, _map_cat_to_str
from . import read_dispatched

SPARSE_MATRIX = {"csc_matrix", "csr_matrix"}
Expand Down Expand Up @@ -597,15 +596,9 @@ def concat_on_disk(
[pd.Series(_df_index(g[dim])) for g in groups], ignore_index=True
)
if index_unique is not None:
if parse_version(pd.__version__) >= parse_version("2.0"):
# Argument added in pandas 2.0
concat_indices = concat_indices.str.cat(
label_col.map(str, na_action="ignore"), sep=index_unique
)
else:
concat_indices = concat_indices.str.cat(
label_col.map(str), sep=index_unique
)
concat_indices = concat_indices.str.cat(
_map_cat_to_str(label_col), sep=index_unique
)

# Resulting indices for {dim} and {alt_dim}
concat_indices = pd.Index(concat_indices)
Expand Down
14 changes: 4 additions & 10 deletions anndata/experimental/multi_files/_anncollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import numpy as np
import pandas as pd
from h5py import Dataset
from packaging.version import parse as parse_version

from ..._core.aligned_mapping import AxisArrays
from ..._core.anndata import AnnData
from ..._core.index import Index, _normalize_index, _normalize_indices
from ..._core.merge import concat_arrays, inner_concat_aligned_mapping
from ..._core.sparse_dataset import BaseCompressedSparseDataset
from ..._core.views import _resolve_idx
from ...compat import _map_cat_to_str

ATTRS = ["obs", "obsm", "layers"]

Expand Down Expand Up @@ -721,15 +721,9 @@ def __init__(
categories=keys,
)
if index_unique is not None:
if parse_version(pd.__version__) >= parse_version("2.0"):
# Argument added in pandas 2.0
concat_indices = concat_indices.str.cat(
label_col.map(str, na_action="ignore"), sep=index_unique
)
else:
concat_indices = concat_indices.str.cat(
label_col.map(str), sep=index_unique
)
concat_indices = concat_indices.str.cat(
_map_cat_to_str(label_col), sep=index_unique
)
self.obs_names = pd.Index(concat_indices)

if not self.obs_names.is_unique:
Expand Down

0 comments on commit bfd8df7

Please sign in to comment.