Skip to content

Commit

Permalink
fix: updated draw_simplices to make it consistent with draw_hyperedges
Browse files Browse the repository at this point in the history
  • Loading branch information
maximelucas committed Oct 16, 2023
1 parent 912a279 commit 27779d3
Showing 1 changed file with 48 additions and 88 deletions.
136 changes: 48 additions & 88 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..algorithms import max_edge_order, unique_edge_sizes
from ..core import DiHypergraph, Hypergraph, SimplicialComplex
from ..exception import XGIError
from ..utils import subfaces
from .draw_utils import (
_CCW_sort,
_color_arg_to_dict,
Expand Down Expand Up @@ -213,20 +214,27 @@ def draw(
max_order = max_edge_order(H)

if isinstance(H, SimplicialComplex):
ax = draw_simplices(
ax, (dyad_collection, edge_collection) = draw_simplices(
SC=H,
pos=pos,
ax=ax,
dyad_color=dyad_color,
dyad_lw=dyad_lw,
dyad_style=dyad_style,
dyad_color_cmap=dyad_color_cmap,
dyad_vmin=dyad_vmin,
dyad_vmax=dyad_vmax,
alpha=alpha,
edge_fc=edge_fc,
edge_fc_cmap=edge_fc_cmap,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
max_order=max_order,
settings=settings,
hyperedge_labels=hyperedge_labels,
rescale_sizes=rescale_sizes,
**kwargs,
)
dyad_collection = None # for compatibility with simplices until update
edge_collection = None # for compatibility with simplices until update

elif isinstance(H, Hypergraph):

ax, (dyad_collection, edge_collection) = draw_hyperedges(
Expand Down Expand Up @@ -685,10 +693,19 @@ def draw_simplices(
ax=None,
dyad_color="black",
dyad_lw=1.5,
dyad_style="solid",
dyad_color_cmap="Greys",
dyad_vmin=None,
dyad_vmax=None,
edge_fc=None,
edge_fc_cmap="crest_r",
edge_vmin=None,
edge_vmax=None,
alpha=0.4,
max_order=None,
settings=None,
params=dict(),
hyperedge_labels=False,
rescale_sizes=True,
**kwargs,
):
"""Draw maximal simplices and pairwise faces.
Expand Down Expand Up @@ -761,94 +778,37 @@ def draw_simplices(
# Plot only the maximal simplices, thus let's convert the SC to H
H_ = convert.from_max_simplices(SC)

# add the projected pairwise interactions
dyads = subfaces(H_.edges.members(), order=1)
H_.add_edges_from(dyads)
H_.cleanup(multiedges=False, isolates=True, connected=False, relabel=False, in_place=True, singletons=True) # remove multi-dyads

if not max_order:
max_order = max_edge_order(H_)

ax, pos = _draw_init(H_, ax, pos)

if edge_fc is None:
edge_fc = H_.edges.size

if settings is None:
settings = {
"min_dyad_lw": 2.0,
"max_dyad_lw": 10.0,
"edge_fc_cmap": cm.Blues,
"dyad_color_cmap": cm.Greys,
}

settings.update(kwargs)

dyad_color = _color_arg_to_dict(dyad_color, H_.edges, settings["dyad_color_cmap"])
dyad_lw = _scalar_arg_to_dict(
dyad_lw,
H_.edges,
settings["min_dyad_lw"],
settings["max_dyad_lw"],
ax, (dyad_collection, edge_collection) = draw_hyperedges(
H_,
pos=pos,
ax=ax,
dyad_color=dyad_color,
dyad_lw=dyad_lw,
dyad_style=dyad_style,
dyad_color_cmap=dyad_color_cmap,
dyad_vmin=dyad_vmin,
dyad_vmax=dyad_vmax,
edge_fc=edge_fc,
edge_fc_cmap=edge_fc_cmap,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
alpha=alpha,
max_order=max_order,
params=params,
hyperedge_labels=hyperedge_labels,
rescale_sizes=rescale_sizes,
**kwargs,
)

edge_fc = _color_arg_to_dict(edge_fc, H_.edges, settings["edge_fc_cmap"])

# Looping over the hyperedges of different order (reversed) -- nodes will be plotted
# separately
for id, he in H_.edges.members(dtype=dict).items():
d = len(he) - 1

if d == 1:
# Drawing the edges
he = list(he)
x_coords = [pos[he[0]][0], pos[he[1]][0]]
y_coords = [pos[he[0]][1], pos[he[1]][1]]

line = plt.Line2D(
x_coords,
y_coords,
color=dyad_color[id],
lw=dyad_lw[id],
zorder=max_order - 1,
)
ax.add_line(line)
else:
# Hyperedges of order d (d=2: triangles, etc.)
# Filling the polygon
coordinates = [[pos[n][0], pos[n][1]] for n in he]
# Sorting the points counterclockwise (needed to have the correct filling)
sorted_coordinates = _CCW_sort(coordinates)
obj = plt.Polygon(
sorted_coordinates,
facecolor=edge_fc[id],
alpha=0.4,
zorder=max_order - d,
)
ax.add_patch(obj)
# Drawing all the edges within
for i, j in combinations(sorted_coordinates, 2):
x_coords = [i[0], j[0]]
y_coords = [i[1], j[1]]
line = plt.Line2D(
x_coords,
y_coords,
color=dyad_color[id],
lw=dyad_lw[id],
zorder=max_order - 1,
)
ax.add_line(line)

if hyperedge_labels:
# Get all valid keywords by inspecting the signatures of draw_node_labels
valid_label_kwds = signature(draw_hyperedge_labels).parameters.keys()
# Remove the arguments of this function (draw_networkx)
valid_label_kwds = valid_label_kwds - {"H", "pos", "ax", "hyperedge_labels"}
if any([k not in valid_label_kwds for k in kwargs]):
invalid_args = ", ".join([k for k in kwargs if k not in valid_label_kwds])
raise ValueError(f"Received invalid argument(s): {invalid_args}")
label_kwds = {k: v for k, v in kwargs.items() if k in valid_label_kwds}
draw_hyperedge_labels(H_, pos, hyperedge_labels, ax_edges=ax, **label_kwds)

# compute axis limits
_update_lims(pos, ax)

return ax
return ax, (dyad_collection, edge_collection)


def draw_node_labels(
Expand Down

0 comments on commit 27779d3

Please sign in to comment.