Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add semantic dimension identifier to concat API #1244

Merged
merged 29 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
a7a2457
Concat API needs dim
flying-sheep Nov 28, 2023
9400495
mix up concat tests
flying-sheep Nov 28, 2023
e291621
relnote
flying-sheep Nov 28, 2023
1a26aef
Fix tests
flying-sheep Nov 29, 2023
7b036f3
Merge branch 'main' into fix-concat-api
flying-sheep Dec 7, 2023
cc3a1a5
Sort changelog
flying-sheep Dec 7, 2023
1f0a2ca
Merge branch 'main' into fix-concat-api
flying-sheep Dec 12, 2023
002223b
Merge branch 'main' into fix-concat-api
flying-sheep Jan 9, 2024
042263e
Consensus API
flying-sheep Jan 16, 2024
7d6e8d2
fix relnote
flying-sheep Jan 16, 2024
bb9031c
Merge branch 'main' into fix-concat-api
flying-sheep Jan 19, 2024
e3852b3
Merge branch 'main' into fix-concat-api
flying-sheep Jan 23, 2024
3359cd7
Update anndata/_core/merge.py
flying-sheep Jan 23, 2024
fdc6e69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2024
a40afdd
undo unnecessary changes
flying-sheep Jan 23, 2024
59683e8
doctest
flying-sheep Jan 23, 2024
40247a7
Simple axis val test
flying-sheep Jan 23, 2024
2d712c4
move relnote
flying-sheep Jan 23, 2024
0a24155
Use helper
flying-sheep Jan 23, 2024
4667813
Merge branch 'main' into fix-concat-api
flying-sheep Jan 23, 2024
15bd843
Merge branch 'main' into fix-concat-api
flying-sheep Feb 1, 2024
d283bda
move alt axis selection
flying-sheep Feb 2, 2024
0194efe
simpler test_concat_axis_param
flying-sheep Feb 2, 2024
2bbbb91
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2024
7ebfc51
address remaining comments
flying-sheep Feb 13, 2024
fc3e1d6
Merge branch 'main' into fix-concat-api
flying-sheep Feb 13, 2024
bea4ada
Merge branch 'main' into fix-concat-api
flying-sheep Apr 19, 2024
7696e83
Merge branch 'main' into fix-concat-api
flying-sheep Apr 25, 2024
f346188
Merge branch 'main' into fix-concat-api
flying-sheep Jun 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions anndata/_core/aligned_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from anndata._warnings import ExperimentalFeatureWarning, ImplicitModificationWarning
from anndata.compat import AwkArray

from ..utils import deprecated, dim_len, ensure_df_homogeneous, warn_once
from ..utils import axis_len, deprecated, ensure_df_homogeneous, warn_once
from .access import ElementRef
from .index import _subset
from .views import as_view, view_update
Expand Down Expand Up @@ -69,10 +69,10 @@ def _validate_value(self, val: V, key: str) -> V:
# stacklevel=3,
)
for i, axis in enumerate(self.axes):
if self.parent.shape[axis] == dim_len(val, i):
if self.parent.shape[axis] == axis_len(val, i):
continue
right_shape = tuple(self.parent.shape[a] for a in self.axes)
actual_shape = tuple(dim_len(val, a) for a, _ in enumerate(self.axes))
actual_shape = tuple(axis_len(val, a) for a, _ in enumerate(self.axes))
if actual_shape[i] is None and isinstance(val, AwkArray):
dim = ("obs", "var")[i]
msg = (
Expand Down
6 changes: 3 additions & 3 deletions anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
_move_adj_mtx,
)
from ..logging import anndata_logger as logger
from ..utils import convert_to_dict, deprecated, dim_len, ensure_df_homogeneous
from ..utils import axis_len, convert_to_dict, deprecated, ensure_df_homogeneous
from .access import ElementRef
from .aligned_mapping import (
AxisArrays,
Expand Down Expand Up @@ -1920,7 +1920,7 @@ def _check_dimensions(self, key=None):
if "obsm" in key:
obsm = self._obsm
if (
not all([dim_len(o, 0) == self.n_obs for o in obsm.values()])
not all([axis_len(o, 0) == self.n_obs for o in obsm.values()])
and len(obsm.dim_names) != self.n_obs
):
raise ValueError(
Expand All @@ -1930,7 +1930,7 @@ def _check_dimensions(self, key=None):
if "varm" in key:
varm = self._varm
if (
not all([dim_len(v, 0) == self.n_vars for v in varm.values()])
not all([axis_len(v, 0) == self.n_vars for v in varm.values()])
and len(varm.dim_names) != self.n_vars
):
raise ValueError(
Expand Down
79 changes: 39 additions & 40 deletions anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
DaskArray,
_map_cat_to_str,
)
from ..utils import asarray, dim_len, warn_once
from ..utils import asarray, axis_len, warn_once
from .anndata import AnnData
from .index import _subset, make_slice

Expand Down Expand Up @@ -526,7 +526,7 @@ def apply(self, el, *, axis, fill_value=None):

Missing values are to be replaced with `fill_value`.
"""
if self.no_change and (dim_len(el, axis) == len(self.old_idx)):
if self.no_change and (axis_len(el, axis) == len(self.old_idx)):
return el
if isinstance(el, pd.DataFrame):
return self._apply_to_df(el, axis=axis, fill_value=fill_value)
Expand Down Expand Up @@ -991,31 +991,25 @@ def merge_outer(mappings, batch_keys, *, join_index="-", merge=merge_unique):
return out


def _resolve_dim(*, dim: str = None, axis: int = None) -> tuple[int, str]:
_dims = ("obs", "var")
if (dim is None and axis is None) or (dim is not None and axis is not None):
raise ValueError(
f"Must pass exactly one of `dim` or `axis`. Got: dim={dim}, axis={axis}."
)
elif dim is not None and dim not in _dims:
raise ValueError(f"`dim` must be one of ('obs', 'var'), was {dim}")
elif axis is not None and axis not in (0, 1):
raise ValueError(f"`axis` must be either 0 or 1, was {axis}")
if dim is not None:
return _dims.index(dim), dim
else:
return axis, _dims[axis]
def _resolve_axis(
axis: Literal["obs", 0, "var", 1],
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved
) -> tuple[Literal[0], Literal["obs"]] | tuple[Literal[1], Literal["var"]]:
if axis is None:
raise ValueError("Must pass `axis` != None.")
elif axis not in (0, 1, "obs", "var"):
raise ValueError(f"`axis` must be either 0, 1, 'obs', or 'var', was {axis}")
return (0, "obs") if axis in {0, "obs"} else (1, "var")
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved


def axis_indices(adata: AnnData, axis: Literal["obs", 0, "var", 1]) -> pd.Index:
    """Helper function to get ``adata.{axis_name}_names`` for the given axis.

    Parameters
    ----------
    adata
        The annotated data object to read names from.
    axis
        ``0``/``"obs"`` or ``1``/``"var"``.

    Returns
    -------
    The index of names along the requested axis
    (``adata.obs_names`` or ``adata.var_names``).
    """
    _, axis_name = _resolve_axis(axis)
    return getattr(adata, f"{axis_name}_names")


def axis_size(adata: AnnData, axis: Literal["obs", 0, "var", 1]) -> int:
    """Helper function to get ``adata.shape`` along the given axis.

    Parameters
    ----------
    adata
        The annotated data object.
    axis
        ``0``/``"obs"`` or ``1``/``"var"``.

    Returns
    -------
    ``adata.shape[0]`` (``n_obs``) or ``adata.shape[1]`` (``n_vars``).
    """
    ax, _ = _resolve_axis(axis)
    return adata.shape[ax]


Expand Down Expand Up @@ -1045,7 +1039,7 @@ def concat_Xs(adatas, reindexers, axis, fill_value):
def concat(
adatas: Collection[AnnData] | typing.Mapping[str, AnnData],
*,
axis: Literal[0, 1] = 0,
axis: Literal["obs", 0, "var", 1] = "obs",
join: Literal["inner", "outer"] = "inner",
merge: StrategiesLiteral | Callable | None = None,
uns_merge: StrategiesLiteral | Callable | None = None,
Expand Down Expand Up @@ -1152,6 +1146,8 @@ def concat(
s2 2 3
s3 4 5
s4 7 8
>>> # ad.concat([a, c], axis="var").to_df()
>>> # or
flying-sheep marked this conversation as resolved.
Show resolved Hide resolved
>>> ad.concat([a, c], axis=1).to_df()
var1 var2 var3 var4
s1 0 1 10 11
Expand Down Expand Up @@ -1247,8 +1243,8 @@ def concat(
if keys is None:
keys = np.arange(len(adatas)).astype(str)

axis, dim = _resolve_dim(axis=axis)
alt_axis, alt_dim = _resolve_dim(axis=1 - axis)
axis, axis_name = _resolve_axis(axis)
alt_axis, alt_axis_name = _resolve_axis(axis=1 - axis)

# Label column
label_col = pd.Categorical.from_codes(
Expand All @@ -1258,7 +1254,7 @@ def concat(

# Combining indexes
concat_indices = pd.concat(
[pd.Series(dim_indices(a, axis=axis)) for a in adatas], ignore_index=True
[pd.Series(axis_indices(a, axis=axis)) for a in adatas], ignore_index=True
)
if index_unique is not None:
concat_indices = concat_indices.str.cat(
Expand All @@ -1267,16 +1263,16 @@ def concat(
concat_indices = pd.Index(concat_indices)

alt_indices = merge_indices(
[dim_indices(a, axis=alt_axis) for a in adatas], join=join
[axis_indices(a, axis=alt_axis) for a in adatas], join=join
)
reindexers = [
gen_reindexer(alt_indices, dim_indices(a, axis=alt_axis)) for a in adatas
gen_reindexer(alt_indices, axis_indices(a, axis=alt_axis)) for a in adatas
]

# Annotation for concatenation axis
check_combinable_cols([getattr(a, dim).columns for a in adatas], join=join)
check_combinable_cols([getattr(a, axis_name).columns for a in adatas], join=join)
concat_annot = pd.concat(
unify_dtypes(getattr(a, dim) for a in adatas),
unify_dtypes(getattr(a, axis_name) for a in adatas),
join=join,
ignore_index=True,
)
Expand All @@ -1286,7 +1282,7 @@ def concat(

# Annotation for other axis
alt_annot = merge_dataframes(
[getattr(a, alt_dim) for a in adatas], alt_indices, merge
[getattr(a, alt_axis_name) for a in adatas], alt_indices, merge
)

X = concat_Xs(adatas, reindexers, axis=axis, fill_value=fill_value)
Expand All @@ -1306,11 +1302,11 @@ def concat(
[a.layers for a in adatas], axis=axis, reindexers=reindexers
)
concat_mapping = concat_aligned_mapping(
[getattr(a, f"{dim}m") for a in adatas], index=concat_indices
[getattr(a, f"{axis_name}m") for a in adatas], index=concat_indices
)
if pairwise:
concat_pairwise = concat_pairwise_mapping(
mappings=[getattr(a, f"{dim}p") for a in adatas],
mappings=[getattr(a, f"{axis_name}p") for a in adatas],
shapes=[a.shape[axis] for a in adatas],
join_keys=join_keys,
)
Expand All @@ -1320,13 +1316,16 @@ def concat(
# TODO: Reindex lazily, so we don't have to make those copies until we're sure we need the element
alt_mapping = merge(
[
{k: r(v, axis=0) for k, v in getattr(a, f"{alt_dim}m").items()}
{k: r(v, axis=0) for k, v in getattr(a, f"{alt_axis_name}m").items()}
for r, a in zip(reindexers, adatas)
],
)
alt_pairwise = merge(
[
{k: r(r(v, axis=0), axis=1) for k, v in getattr(a, f"{alt_dim}p").items()}
{
k: r(r(v, axis=0), axis=1)
for k, v in getattr(a, f"{alt_axis_name}p").items()
}
for r, a in zip(reindexers, adatas)
]
)
Expand Down Expand Up @@ -1362,12 +1361,12 @@ def concat(
**{
"X": X,
"layers": layers,
dim: concat_annot,
alt_dim: alt_annot,
f"{dim}m": concat_mapping,
f"{alt_dim}m": alt_mapping,
f"{dim}p": concat_pairwise,
f"{alt_dim}p": alt_pairwise,
axis_name: concat_annot,
alt_axis_name: alt_annot,
f"{axis_name}m": concat_mapping,
f"{alt_axis_name}m": alt_mapping,
f"{axis_name}p": concat_pairwise,
f"{alt_axis_name}p": alt_pairwise,
"uns": uns,
"raw": raw,
}
Expand Down
62 changes: 33 additions & 29 deletions anndata/experimental/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
MissingVal,
Reindexer,
StrategiesLiteral,
_resolve_dim,
_resolve_axis,
concat_arrays,
gen_inner_reindexers,
gen_reindexer,
Expand Down Expand Up @@ -364,42 +364,44 @@ def _write_concat_sequence(
)


def _write_alt_mapping(groups, output_group, alt_axis_name, alt_indices, merge):
    """Merge the alternative-axis element across groups and write it out.

    Parameters
    ----------
    groups
        The input (backed) groups being concatenated.
    output_group
        The destination group to write into.
    alt_axis_name
        Name of the non-concatenation axis (``"obs"`` or ``"var"``).
    alt_indices
        Resulting index for the alternative axis after the join.
    merge
        Merge strategy callable applied to the per-group elements.
    """
    alt_mapping = merge([read_as_backed(g[alt_axis_name]) for g in groups])
    # If it's empty, we need to write an empty dataframe with the correct index
    if not alt_mapping:
        alt_df = pd.DataFrame(index=alt_indices)
        write_elem(output_group, alt_axis_name, alt_df)
    else:
        write_elem(output_group, alt_axis_name, alt_mapping)


def _write_alt_annot(groups, output_group, alt_axis_name, alt_indices, merge):
    """Merge per-group annotation for the non-concatenation axis and write it.

    Parameters
    ----------
    groups
        The input groups being concatenated.
    output_group
        The destination group to write into.
    alt_axis_name
        Name of the non-concatenation axis (``"obs"`` or ``"var"``).
    alt_indices
        Resulting index for the alternative axis after the join.
    merge
        Merge strategy passed through to ``merge_dataframes``.
    """
    # Annotation for other axis
    alt_annot = merge_dataframes(
        [read_elem(g[alt_axis_name]) for g in groups], alt_indices, merge
    )
    write_elem(output_group, alt_axis_name, alt_annot)


def _write_axis_annot(
    groups, output_group, axis_name, concat_indices, label, label_col, join
):
    """Concatenate per-group annotation along the concatenation axis and write it.

    Parameters
    ----------
    groups
        The input groups being concatenated.
    output_group
        The destination group to write into.
    axis_name
        Name of the concatenation axis (``"obs"`` or ``"var"``).
    concat_indices
        Combined index for the concatenation axis.
    label
        Optional column name recording which input each row came from;
        skipped when ``None``.
    label_col
        Values for the ``label`` column.
    join
        ``"inner"`` or ``"outer"``, forwarded to :func:`pandas.concat`.
    """
    concat_annot = pd.concat(
        unify_dtypes(read_elem(g[axis_name]) for g in groups),
        join=join,
        ignore_index=True,
    )
    concat_annot.index = concat_indices
    if label is not None:
        concat_annot[label] = label_col
    write_elem(output_group, axis_name, concat_annot)


def concat_on_disk(
in_files: Collection[str | os.PathLike] | Mapping[str, str | os.PathLike],
out_file: str | os.PathLike,
*,
max_loaded_elems: int = 100_000_000,
axis: Literal[0, 1] = 0,
axis: Literal["obs", 0, "var", 1] = 0,
join: Literal["inner", "outer"] = "inner",
merge: StrategiesLiteral | Callable[[Collection[Mapping]], Mapping] | None = None,
uns_merge: (
Expand Down Expand Up @@ -562,17 +564,17 @@ def concat_on_disk(
if keys is None:
keys = np.arange(len(in_files)).astype(str)

_, dim = _resolve_dim(axis=axis)
_, alt_dim = _resolve_dim(axis=1 - axis)
axis, axis_name = _resolve_axis(axis)
_, alt_axis_name = _resolve_axis(1 - axis)

output_group = as_group(out_file, mode="w")
groups = [as_group(f) for f in in_files]

use_reindexing = False

alt_dims = [_df_index(g[alt_dim]) for g in groups]
# All dim_names must be equal if reindexing not applied
if not _indices_equal(alt_dims):
alt_idxs = [_df_index(g[alt_axis_name]) for g in groups]
# All {axis_name}_names must be equal if reindexing not applied
if not _indices_equal(alt_idxs):
use_reindexing = True

# All groups must be anndata
Expand All @@ -593,34 +595,36 @@ def concat_on_disk(

# Combining indexes
concat_indices = pd.concat(
[pd.Series(_df_index(g[dim])) for g in groups], ignore_index=True
[pd.Series(_df_index(g[axis_name])) for g in groups], ignore_index=True
)
if index_unique is not None:
concat_indices = concat_indices.str.cat(
_map_cat_to_str(label_col), sep=index_unique
)

# Resulting indices for {dim} and {alt_dim}
# Resulting indices for {axis_name} and {alt_axis_name}
concat_indices = pd.Index(concat_indices)

alt_indices = merge_indices(alt_dims, join=join)
alt_index = merge_indices(alt_idxs, join=join)

reindexers = None
if use_reindexing:
reindexers = [
gen_reindexer(alt_indices, alt_old_index) for alt_old_index in alt_dims
gen_reindexer(alt_index, alt_old_index) for alt_old_index in alt_idxs
]
else:
reindexers = [IdentityReindexer()] * len(groups)

# Write {dim}
_write_dim_annot(groups, output_group, dim, concat_indices, label, label_col, join)
# Write {axis_name}
_write_axis_annot(
groups, output_group, axis_name, concat_indices, label, label_col, join
)

# Write {alt_dim}
_write_alt_annot(groups, output_group, alt_dim, alt_indices, merge)
# Write {alt_axis_name}
_write_alt_annot(groups, output_group, alt_axis_name, alt_index, merge)

# Write {alt_dim}m
_write_alt_mapping(groups, output_group, alt_dim, alt_indices, merge)
# Write {alt_axis_name}m
_write_alt_mapping(groups, output_group, alt_axis_name, alt_index, merge)

# Write X

Expand All @@ -634,10 +638,10 @@ def concat_on_disk(
max_loaded_elems=max_loaded_elems,
)

# Write Layers and {dim}m
# Write Layers and {axis_name}m
mapping_names = [
(
f"{dim}m",
f"{axis_name}m",
concat_indices,
0,
None if use_reindexing else [IdentityReindexer()] * len(groups),
Expand Down
Loading
Loading