Skip to content

Commit

Permalink
Merge pull request #134 from SarahOuologuem/fix_132
Browse files Browse the repository at this point in the history
Small improvements in mu.pl.scatter()
  • Loading branch information
gtca authored Oct 17, 2024
2 parents c7461aa + 9b98f51 commit eee8df0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
20 changes: 8 additions & 12 deletions muon/_core/plot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, List, Optional, Iterable, Sequence, Dict
from typing import Dict, Iterable, List, Optional, Sequence, Union
import warnings

from matplotlib.axes import Axes
Expand Down Expand Up @@ -43,7 +43,7 @@ def scatter(
y : Optional[str]
y coordinate
color : Optional[Union[str, Sequence[str]]], optional (default: None)
Keys for variables or annotations of observations (.obs columns),
Keys or a single key for variables or annotations of observations (.obs columns),
or a hex colour specification.
use_raw : Optional[bool], optional (default: None)
Use `.raw` attribute of the modality where a feature (from `color`) is derived from.
Expand All @@ -53,9 +53,7 @@ def scatter(
No layer is used by default. A single layer value will be expanded to [layer, layer, layer].
"""
if isinstance(data, AnnData):
return sc.pl.embedding(
data, x=x, y=y, color=color, use_raw=use_raw, layers=layers, **kwargs
)
return sc.pl.scatter(data, x=x, y=y, color=color, use_raw=use_raw, layers=layers, **kwargs)

if isinstance(layers, str) or layers is None:
layers = [layers, layers, layers]
Expand All @@ -72,10 +70,9 @@ def scatter(
if isinstance(color, str):
color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])
color_obs = pd.DataFrame({color: color_obs})
color = [color]
else:
# scanpy#311 / scanpy#1497 has to be fixed for this to work
color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])

color_obs.index = data.obs_names
obs = pd.concat([obs, color_obs], axis=1, ignore_index=False)

Expand All @@ -86,11 +83,10 @@ def scatter(
# and are now stored in .obs
retval = sc.pl.scatter(ad, x=x, y=y, color=color, **kwargs)
if color is not None:
for col in color:
try:
data.uns[f"{col}_colors"] = ad.uns[f"{col}_colors"]
except KeyError:
pass
try:
data.uns[f"{color}_colors"] = ad.uns[f"{color}_colors"]
except KeyError:
pass
return retval


Expand Down
31 changes: 31 additions & 0 deletions tests/test_muon_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

import numpy as np
from scipy import sparse
import pandas as pd
from anndata import AnnData
import muon as mu
from muon import MuData
import matplotlib

matplotlib.use("Agg")


@pytest.fixture()
def mdata():
mdata = MuData(
{
"mod1": AnnData(np.arange(0, 100, 0.1).reshape(-1, 10)),
"mod2": AnnData(np.arange(101, 2101, 1).reshape(-1, 20)),
}
)
mdata.var_names_make_unique()
yield mdata


class TestScatter:
def test_pl_scatter(self, mdata):
mdata = mdata.copy()
np.random.seed(42)
mdata.obs["condition"] = np.random.choice(["a", "b"], mdata.n_obs)
mu.pl.scatter(mdata, x="mod1:0", y="mod2:0", color="condition")

0 comments on commit eee8df0

Please sign in to comment.