Skip to content

Commit

Permalink
Added edge_ec argument in draw to specify edge colors (#575)
Browse files Browse the repository at this point in the history
* feat: added edge_ec color to specify edge colors (mostly useful when hull=True)

* tutos: added example with hull and contours + reran notebook that had still had old plotting defaults

* test: initial for edge_ec

* fix: typo from review
  • Loading branch information
maximelucas authored Aug 30, 2024
1 parent 9492352 commit f25e0ca
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 172 deletions.
27 changes: 24 additions & 3 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,27 @@ def test_draw_hyperedges_fc_cmap(edgelist8):
plt.close()


def test_draw_hyperedges_ec(edgelist8):
# implemented in PR #575

H = xgi.Hypergraph(edgelist8)

colors = np.array([[0.6468274 , 0.80289262, 0.56592265, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4]])

# edge stat color
fig, ax = plt.subplots()
ax, collections = xgi.draw_hyperedges(H,ax=ax, edge_ec=H.edges.size, edge_fc="w")
(_, edge_collection) = collections

assert np.all(edge_collection.get_edgecolor() == colors)
plt.close("all")


def test_draw_simplices(edgelist8):
with pytest.raises(XGIError):
H = xgi.Hypergraph(edgelist8)
Expand Down Expand Up @@ -684,16 +705,16 @@ def test_draw_undirected_dyads(edgelist8):
H = xgi.Hypergraph(edgelist8)

fig, ax = plt.subplots()
ax, dyad_collection = xgi.draw_undirected_dyads(H)
ax, dyad_collection = xgi.draw_undirected_dyads(H, ax=ax)
assert len(dyad_collection._paths) == 26 # number of lines

with pytest.raises(ValueError):
fig, ax = plt.subplots()
ax, dyad_collection = xgi.draw_undirected_dyads(H, dyad_lw=-1)
ax, dyad_collection = xgi.draw_undirected_dyads(H, dyad_lw=-1, ax=ax)

fig, ax = plt.subplots()
ax, dyad_collection = xgi.draw_undirected_dyads(
H, dyad_color=np.random.random(H.num_edges)
H, dyad_color=np.random.random(H.num_edges), ax=ax
)
assert len(np.unique(dyad_collection.get_color())) == 28
plt.close("all")
Expand Down
226 changes: 128 additions & 98 deletions tutorials/focus/Tutorial 5 - Plotting.ipynb

Large diffs are not rendered by default.

45 changes: 36 additions & 9 deletions tutorials/getting_started/XGI in 1 minute.ipynb

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions tutorials/getting_started/XGI in 15 minutes.ipynb

Large diffs are not rendered by default.

71 changes: 49 additions & 22 deletions tutorials/getting_started/XGI in 5 minutes.ipynb

Large diffs are not rendered by default.

102 changes: 80 additions & 22 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def draw(
edge_fc_cmap="crest_r",
edge_vmin=None,
edge_vmax=None,
edge_ec=None,
alpha=0.4,
hull=False,
radius=0.05,
Expand Down Expand Up @@ -166,6 +167,23 @@ def draw(
Colormap used to map the edge colors. By default, "cres_r".
edge_vmin, edge_vmax : float, optional
Minimum and maximum for edge colormap scaling. By default, None.
edge_ec : color or list of colors or array-like or dict or EdgeStat, optional
Color of the hyperedges. The accepted formats are the same as
matplotlib's scatter, with the addition of dict and IDStat.
Formats with colors:
* single color as a string
* single color as 3- or 4-tuple
* list of colors of length len(ids)
* dict of colors containing the `ids` as keys
Formats with numerical values (will be mapped to colors):
* array of floats
* dict of floats containing the `ids` as keys
* IDStat containing the `ids` as keys
If None (default), color by edge size.
Numerical formats will be mapped to colors using edge_vmin, edge_vmax,
and edge_fc_cmap.
alpha : float, optional
The edge transparency. By default, 0.4.
hull : bool, optional
Expand Down Expand Up @@ -262,6 +280,7 @@ def draw(
edge_fc_cmap=edge_fc_cmap,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
edge_ec=edge_ec,
max_order=max_order,
hyperedge_labels=hyperedge_labels,
rescale_sizes=rescale_sizes,
Expand All @@ -285,6 +304,7 @@ def draw(
edge_fc_cmap=edge_fc_cmap,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
edge_ec=edge_ec,
max_order=max_order,
hyperedge_labels=hyperedge_labels,
hull=hull,
Expand Down Expand Up @@ -523,6 +543,7 @@ def draw_hyperedges(
edge_fc_cmap="crest_r",
edge_vmin=None,
edge_vmax=None,
edge_ec=None,
alpha=0.4,
max_order=None,
params=dict(),
Expand Down Expand Up @@ -566,13 +587,13 @@ def draw_hyperedges(
edge_fc : color or list of colors or array-like or dict or EdgeStat, optional
Color of the hyperedges. The accepted formats are the same as
matplotlib's scatter, with the addition of dict and IDStat.
Those with colors:
Formats with colors:
* single color as a string
* single color as 3- or 4-tuple
* list of colors of length len(ids)
* dict of colors containing the `ids` as keys
Those with numerical values (will be mapped to colors):
Formats with numerical values (will be mapped to colors):
* array of floats
* dict of floats containing the `ids` as keys
* IDStat containing the `ids` as keys
Expand All @@ -582,6 +603,23 @@ def draw_hyperedges(
Colormap used to map the edge colors. By default, "crest_r".
edge_vmin, edge_vmax : float, optional
Minimum and maximum for edge colormap scaling. By default, None.
edge_ec : color or list of colors or array-like or dict or EdgeStat, optional
Color of the hyperedges. The accepted formats are the same as
matplotlib's scatter, with the addition of dict and IDStat.
Formats with colors:
* single color as a string
* single color as 3- or 4-tuple
* list of colors of length len(ids)
* dict of colors containing the `ids` as keys
Formats with numerical values (will be mapped to colors):
* array of floats
* dict of floats containing the `ids` as keys
* IDStat containing the `ids` as keys
If None (default), color by edge size.
Numerical formats will be mapped to colors using edge_vmin, edge_vmax,
and edge_fc_cmap.
alpha : float, optional
The edge transparency. By default, 0.4.
max_order : int, optional
Expand Down Expand Up @@ -650,13 +688,18 @@ def draw_hyperedges(

if edge_fc is None: # color is proportional to size
edge_fc = edges.size
if edge_ec is None: # color is proportional to size
edge_ec = edges.size

# convert all formats to ndarray
dyad_lw = _draw_arg_to_arr(dyad_lw)

# parse colors
dyad_color, dyad_c_mapped = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_mapped = _parse_color_arg(edge_fc, list(edges))
dyad_color, dyad_c_to_map = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_to_map = _parse_color_arg(edge_fc, list(edges))
edge_ec, edge_ec_to_map = _parse_color_arg(edge_ec, list(edges))
# edge_c_to_map and dyad_c_to_map are True if the colors
# are input as numeric values that need to be mapped to colors

# check validity of input values
if np.any(dyad_lw < 0):
Expand All @@ -672,7 +715,7 @@ def draw_hyperedges(
dyad_pos = np.asarray([(pos[list(e)[0]], pos[list(e)[1]]) for e in dyads.members()])

# plot dyads
if dyad_c_mapped:
if dyad_c_to_map:
dyad_c_arr = dyad_color
dyad_colors = None
else:
Expand All @@ -682,7 +725,7 @@ def draw_hyperedges(
dyad_collection = LineCollection(
dyad_pos,
colors=dyad_colors,
array=dyad_c_arr, # colors if mapped, ie arr of floats
array=dyad_c_arr, # colors if to be mapped, ie arr of floats
linewidths=dyad_lw,
antialiaseds=(1,),
linestyle=dyad_style,
Expand All @@ -691,7 +734,7 @@ def draw_hyperedges(
)

# dyad_collection.set_cmap(dyad_color_cmap)
if dyad_c_mapped:
if dyad_c_to_map:
dyad_collection.set_clim(dyad_vmin, dyad_vmax)
# dyad_collection.set_zorder(max_order - 1) # edges go behind nodes
ax.add_collection(dyad_collection)
Expand All @@ -700,13 +743,27 @@ def draw_hyperedges(
ids_sorted = np.argsort(edges.size.aslist())[::-1]

# plot other hyperedges
if edge_c_mapped:

# prepare colors for PatchCollection format
if edge_c_to_map:
edge_fc_arr = edge_fc[ids_sorted]
edge_fc_colors = None
else:
edge_fc_arr = None
edge_fc_colors = edge_fc[ids_sorted] if len(edge_fc) > 1 else edge_fc


edge_ec = edge_ec[ids_sorted] if len(edge_ec) > 1 else edge_ec # reorder

if edge_ec_to_map: # edgecolors need to be manually mapped

# create scalarmappable to map floats to colors
# we use the same vmin, vmax, and cmap as for edge_fc
norm = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
sm_edgecolors = cm.ScalarMappable(norm=norm, cmap=edge_fc_cmap)

edge_ec = sm_edgecolors.to_rgba(edge_ec) # map to colors

patches = []
for he in np.array(edges.members())[ids_sorted]:
d = len(he) - 1
Expand All @@ -733,13 +790,14 @@ def draw_hyperedges(
edge_collection = PatchCollection(
patches,
facecolors=edge_fc_colors,
array=edge_fc_arr,
array=edge_fc_arr, # will be mapped by PatchCollection
cmap=edge_fc_cmap,
edgecolors=edge_ec,
alpha=alpha,
zorder=max_order - 2, # below dyads
)
# edge_collection.set_cmap(edge_fc_cmap)
if edge_c_mapped:
if edge_c_to_map:
edge_collection.set_clim(edge_vmin, edge_vmax)
ax.add_collection(edge_collection)

Expand Down Expand Up @@ -1379,9 +1437,9 @@ def draw_multilayer(
raise ValueError("dyad_lw cannot contain negative values.")

# parse colors
dyad_color, dyad_c_mapped = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_mapped = _parse_color_arg(edge_fc, list(edges))
layer_color, layer_c_mapped = _parse_color_arg(layer_color, orders)
dyad_color, dyad_c_to_map = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_to_map = _parse_color_arg(edge_fc, list(edges))
layer_color, layer_c_to_map = _parse_color_arg(layer_color, orders)

node_size = np.array(node_size) ** 2

Expand All @@ -1402,7 +1460,7 @@ def draw_multilayer(
# draw surfaces corresponding to the different orders
zz = np.zeros(xx.shape) + d * sep

if layer_c_mapped:
if layer_c_to_map:
layer_c = None
else:
layer_c = layer_color[jj] if len(layer_color) > 1 else layer_color
Expand All @@ -1427,7 +1485,7 @@ def draw_multilayer(
]

# plot dyads
if dyad_c_mapped:
if dyad_c_to_map:
raise ValueError(
"dyad_color needs to be a color or list of colors, not numerical values."
)
Expand All @@ -1447,7 +1505,7 @@ def draw_multilayer(
ids_sorted = np.argsort(edges.size.aslist())[::-1]

# plot other hyperedges
if edge_c_mapped:
if edge_c_to_map:
edge_fc_arr = edge_fc[ids_sorted]
edge_fc_colors = None
else:
Expand All @@ -1474,7 +1532,7 @@ def draw_multilayer(
zorder=max_order - 2, # below dyads
)
edge_collection.set_cmap(edge_fc_cmap)
if edge_c_mapped:
if edge_c_to_map:
edge_collection.set_clim(edge_vmin, edge_vmax)
ax.add_collection3d(edge_collection)

Expand Down Expand Up @@ -1963,7 +2021,7 @@ def draw_undirected_dyads(
)

# parse colors
dyad_color, dyads_c_mapped = _parse_color_arg(dyad_color, H.edges)
dyad_color, dyads_c_to_map = _parse_color_arg(dyad_color, H.edges)

# The following two list comprehensions map colors assigned to a hyperedge to
# all of the bipartite edges, so that users need not specify colors for every
Expand All @@ -1986,7 +2044,7 @@ def draw_undirected_dyads(
)

# convert numbers to colors for FancyArrowPatch
if dyads_c_mapped:
if dyads_c_to_map:
norm = mpl.colors.Normalize()
m = cm.ScalarMappable(norm=norm, cmap=dyad_color_cmap)
dyad_color = m.to_rgba(dyad_color)
Expand Down Expand Up @@ -2160,10 +2218,10 @@ def draw_directed_dyads(
)

# parse colors
dyad_color, dyads_c_mapped = _parse_color_arg(dyad_color, H.edges)
dyad_color, dyads_c_to_map = _parse_color_arg(dyad_color, H.edges)

# convert numbers to colors for FancyArrowPatch
if dyads_c_mapped:
if dyads_c_to_map:
norm = mpl.colors.Normalize()
m = cm.ScalarMappable(norm=norm, cmap=dyad_color_cmap)
dyad_color = m.to_rgba(dyad_color)
Expand Down Expand Up @@ -2233,7 +2291,7 @@ def to_marker_edge(marker_size, marker):
else:
dlw = dyad_lw

if dyads_c_mapped:
if dyads_c_to_map:
d_color = dyad_color[edge_to_idx[e]]
else:
d_color = dyad_color
Expand Down
14 changes: 8 additions & 6 deletions xgi/drawing/draw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _parse_color_arg(colors, ids, id_kind="edges"):
This function is needed to handle the input formats not naturally
handled by matploltib's Collections: IDStat, dict, and arrays of
floats. All those are converted to arrays of floats and.
floats. All those numerical formats are converted to arrays of floats.
Parameters:
-----------
Expand All @@ -103,8 +103,8 @@ def _parse_color_arg(colors, ids, id_kind="edges"):
--------
colors : single color or ndarray
Processed color values for plotting.
colors_are_mapped : bool
True if the colors are mapped and need special handling. This
colors_to_map : bool
True if the colors need to be mapped and need special handling. This
is used in draw_hyperedges to deal with Collections.
Raises:
Expand All @@ -127,6 +127,7 @@ def _parse_color_arg(colors, ids, id_kind="edges"):

xsize = len(ids)

# convert all dict-like input formats to an array
if isinstance(colors, IDStat):
colors = colors.asdict()
if isinstance(colors, dict):
Expand All @@ -135,13 +136,14 @@ def _parse_color_arg(colors, ids, id_kind="edges"):
values = list(colors.values())
colors = np.array(values)

# see if input format needs to be mapped to colors (if numeric)
try: # see if the input format is compatible with PatchCollection's facecolor
colors = to_rgba_array(colors)
colors_are_mapped = False
colors_to_map = False
except:
try: # in case of array of floats (can be fed to PatchCollection with some care)
colors = np.asanyarray(colors, dtype=float)
colors_are_mapped = True
colors_to_map = True
except:
raise ValueError("Invalid input format for colors.")

Expand All @@ -150,7 +152,7 @@ def _parse_color_arg(colors, ids, id_kind="edges"):
f"The input color argument must be a single color or its length must match the number of plotted elements ({xsize})."
)

return colors, colors_are_mapped
return colors, colors_to_map


def _draw_arg_to_arr(arg):
Expand Down

0 comments on commit f25e0ca

Please sign in to comment.