diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index b999ab41e..0b12722e4 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -21,6 +21,7 @@ from ..algorithms import max_edge_order, unique_edge_sizes from ..core import DiHypergraph, Hypergraph, SimplicialComplex from ..exception import XGIError +from ..utils import subfaces from .draw_utils import ( _CCW_sort, _color_arg_to_dict, @@ -213,20 +214,27 @@ def draw( max_order = max_edge_order(H) if isinstance(H, SimplicialComplex): - ax = draw_simplices( + ax, (dyad_collection, edge_collection) = draw_simplices( SC=H, pos=pos, ax=ax, dyad_color=dyad_color, dyad_lw=dyad_lw, + dyad_style=dyad_style, + dyad_color_cmap=dyad_color_cmap, + dyad_vmin=dyad_vmin, + dyad_vmax=dyad_vmax, + alpha=alpha, edge_fc=edge_fc, + edge_fc_cmap=edge_fc_cmap, + edge_vmin=edge_vmin, + edge_vmax=edge_vmax, max_order=max_order, - settings=settings, hyperedge_labels=hyperedge_labels, + rescale_sizes=rescale_sizes, **kwargs, ) - dyad_collection = None # for compatibility with simplices until update - edge_collection = None # for compatibility with simplices until update + elif isinstance(H, Hypergraph): ax, (dyad_collection, edge_collection) = draw_hyperedges( @@ -685,10 +693,19 @@ def draw_simplices( ax=None, dyad_color="black", dyad_lw=1.5, + dyad_style="solid", + dyad_color_cmap="Greys", + dyad_vmin=None, + dyad_vmax=None, edge_fc=None, + edge_fc_cmap="crest_r", + edge_vmin=None, + edge_vmax=None, + alpha=0.4, max_order=None, - settings=None, + params=dict(), hyperedge_labels=False, + rescale_sizes=True, **kwargs, ): """Draw maximal simplices and pairwise faces. @@ -761,94 +778,37 @@ def draw_simplices( # Plot only the maximal simplices, thus let's convert the SC to H H_ = convert.from_max_simplices(SC) + # add the projected pairwise interactions + dyads = subfaces(H_.edges.members(), order=1) + H_.add_edges_from(dyads) + H_.cleanup(multiedges=False, isolates=True, connected=False, relabel=False, in_place=True, singletons=True) # remove multi-dyads + if not max_order: max_order = max_edge_order(H_) - ax, pos = _draw_init(H_, ax, pos) - - if edge_fc is None: - edge_fc = H_.edges.size - - if settings is None: - settings = { - "min_dyad_lw": 2.0, - "max_dyad_lw": 10.0, - "edge_fc_cmap": cm.Blues, - "dyad_color_cmap": cm.Greys, - } - - settings.update(kwargs) - - dyad_color = _color_arg_to_dict(dyad_color, H_.edges, settings["dyad_color_cmap"]) - dyad_lw = _scalar_arg_to_dict( - dyad_lw, - H_.edges, - settings["min_dyad_lw"], - settings["max_dyad_lw"], + ax, (dyad_collection, edge_collection) = draw_hyperedges( + H_, + pos=pos, + ax=ax, + dyad_color=dyad_color, + dyad_lw=dyad_lw, + dyad_style=dyad_style, + dyad_color_cmap=dyad_color_cmap, + dyad_vmin=dyad_vmin, + dyad_vmax=dyad_vmax, + edge_fc=edge_fc, + edge_fc_cmap=edge_fc_cmap, + edge_vmin=edge_vmin, + edge_vmax=edge_vmax, + alpha=alpha, + max_order=max_order, + params=params, + hyperedge_labels=hyperedge_labels, + rescale_sizes=rescale_sizes, + **kwargs, ) - edge_fc = _color_arg_to_dict(edge_fc, H_.edges, settings["edge_fc_cmap"]) - - # Looping over the hyperedges of different order (reversed) -- nodes will be plotted - # separately - for id, he in H_.edges.members(dtype=dict).items(): - d = len(he) - 1 - - if d == 1: - # Drawing the edges - he = list(he) - x_coords = [pos[he[0]][0], pos[he[1]][0]] - y_coords = [pos[he[0]][1], pos[he[1]][1]] - - line = plt.Line2D( - x_coords, - y_coords, - color=dyad_color[id], - lw=dyad_lw[id], - zorder=max_order - 1, - ) - ax.add_line(line) - else: - # Hyperedges of order d (d=2: triangles, etc.) - # Filling the polygon - coordinates = [[pos[n][0], pos[n][1]] for n in he] - # Sorting the points counterclockwise (needed to have the correct filling) - sorted_coordinates = _CCW_sort(coordinates) - obj = plt.Polygon( - sorted_coordinates, - facecolor=edge_fc[id], - alpha=0.4, - zorder=max_order - d, - ) - ax.add_patch(obj) - # Drawing all the edges within - for i, j in combinations(sorted_coordinates, 2): - x_coords = [i[0], j[0]] - y_coords = [i[1], j[1]] - line = plt.Line2D( - x_coords, - y_coords, - color=dyad_color[id], - lw=dyad_lw[id], - zorder=max_order - 1, - ) - ax.add_line(line) - - if hyperedge_labels: - # Get all valid keywords by inspecting the signatures of draw_node_labels - valid_label_kwds = signature(draw_hyperedge_labels).parameters.keys() - # Remove the arguments of this function (draw_networkx) - valid_label_kwds = valid_label_kwds - {"H", "pos", "ax", "hyperedge_labels"} - if any([k not in valid_label_kwds for k in kwargs]): - invalid_args = ", ".join([k for k in kwargs if k not in valid_label_kwds]) - raise ValueError(f"Received invalid argument(s): {invalid_args}") - label_kwds = {k: v for k, v in kwargs.items() if k in valid_label_kwds} - draw_hyperedge_labels(H_, pos, hyperedge_labels, ax_edges=ax, **label_kwds) - - # compute axis limits - _update_lims(pos, ax) - - return ax + return ax, (dyad_collection, edge_collection) def draw_node_labels(