From f178c2398f3e7545e44f292bfb67c42dd4d0b24b Mon Sep 17 00:00:00 2001 From: Maxime Lucas Date: Tue, 17 Oct 2023 19:59:20 +0200 Subject: [PATCH] fix: settings can now be changed with single values --- tests/drawing/test_draw.py | 32 +++++++++- xgi/drawing/draw.py | 119 ++++++++++++++++++++----------------- 2 files changed, 95 insertions(+), 56 deletions(-) diff --git a/tests/drawing/test_draw.py b/tests/drawing/test_draw.py index 7c365a503..4769c0896 100644 --- a/tests/drawing/test_draw.py +++ b/tests/drawing/test_draw.py @@ -461,8 +461,12 @@ def test_correct_number_of_collections_draw_multilayer(edgelist8): def test_draw_dihypergraph(diedgelist2, edgelist8): DH = xgi.DiHypergraph(diedgelist2) - fig, ax1 = plt.subplots() + fig1, ax1 = plt.subplots() ax1 = xgi.draw_dihypergraph(DH, ax=ax1) + fig2, ax2 = plt.subplots() + ax2 = xgi.draw_dihypergraph( + DH, ax=ax2, node_fc="red", node_ec="blue", node_lw=2, node_size=20 + ) # number of elements assert len(ax1.lines) == 7 # number of source nodes @@ -471,6 +475,30 @@ def test_draw_dihypergraph(diedgelist2, edgelist8): DH.edges.filterby("size", 1) ) # hyperedges markers + nodes + # node face colors + assert np.all( + ax1.collections[-1].get_facecolor() == np.array([[1, 1, 1, 1]]) + ) # white + assert np.all( + ax2.collections[-1].get_facecolor() == np.array([[1, 0, 0, 1]]) + ) # red + + # node edge colors + assert np.all( + ax1.collections[-1].get_edgecolor() == np.array([[0, 0, 0, 1]]) + ) # black + assert np.all( + ax2.collections[-1].get_edgecolor() == np.array([[0, 0, 1, 1]]) + ) # blue + + # node_lw + assert np.all(ax1.collections[-1].get_linewidth() == np.array([1])) + assert np.all(ax2.collections[-1].get_linewidth() == np.array([2])) + + # node_size + assert np.all(ax1.collections[-1].get_sizes() == np.array([15**2])) + assert np.all(ax2.collections[-1].get_sizes() == np.array([20**2])) + # zorder for line, z in zip(ax1.lines, [1, 1, 1, 1, 0, 0, 0]): # lines for source nodes assert line.get_zorder() == z @@ -479,7 +507,7 @@ def test_draw_dihypergraph(diedgelist2, edgelist8): for collection in ax1.collections: assert collection.get_zorder() == 3 # nodes and hyperedges markers - plt.close() + plt.close("all") # test toggle for edges fig, ax2 = plt.subplots() diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index 523156543..4339b2a8b 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -1074,14 +1074,17 @@ def _draw_hull(node_pos, ax, edges_ec, facecolor, alpha, zorder, radius): Parameters ---------- node_pos : np.array - nx2 dimensional array containing positions of the nodes + Array of dimension (n, 2) containing node positions ax : matplotlib.pyplot.axes + Axis to plot on edges_ec : str Color of the border of the convex hull facecolor : str Filling color of the convex hull alpha : float Transparency of the convex hull + zorder : float + Vertical order on which to plot radius : float Radius of the convex hull in the vicinity of the nodes. @@ -1433,6 +1436,13 @@ def draw_multilayer( ------- ax : matplotlib Axes3DSubplot The subplot with the multilayer network visualization. + + + Notes + ----- + The effect of the `sep` parameter is limited by the `height` of the figure. + If `sep` is larger than a certain value depending on `height`, no additional + effect will be seen. """ settings = { "min_node_size": 10.0, @@ -1602,13 +1612,14 @@ def draw_multilayer( def draw_dihypergraph( DH, ax=None, - lines_fc=None, - lines_lw=1.5, - line_head_width=0.05, node_fc="white", node_ec="black", node_lw=1, node_size=15, + node_fc_cmap="Reds", + lines_fc=None, + lines_lw=1.5, + line_head_width=0.05, edge_marker_toggle=True, edge_marker_fc=None, edge_marker_ec=None, @@ -1619,6 +1630,7 @@ def draw_dihypergraph( node_labels=False, hyperedge_labels=False, settings=None, + rescale_sizes=True, **kwargs, ): """Draw a directed hypergraph @@ -1629,18 +1641,6 @@ def draw_dihypergraph( The directed hypergraph to draw. ax : matplotlib.pyplot.axes, optional Axis to draw on. If None (default), get the current axes. - lines_fc : str, dict, iterable, optional - Color of the hyperedges (lines). If str, use the same color for all hyperedges. - If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable, - assume the colors are specified in the same order as the hyperedges are found - in DH.edges. If None (default), use the size of the hyperedges. - lines_lw : int, float, dict, iterable, optional - Line width of the hyperedges (lines). If int or float, use the same width for - all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other - iterable, assume the widths are specified in the same order as the hyperedges - are found in DH.edges. By default, 1.5. - line_head_width : float, optional - Length of arrows' heads. By default, 0.05 node_fc : str, dict, iterable, or NodeStat, optional Color of the nodes. If str, use the same color for all nodes. If a dict, must contain (node_id: color_str) pairs. If other iterable, assume the colors are @@ -1663,6 +1663,21 @@ def draw_dihypergraph( the radiuses are specified in the same order as the nodes are found in H.nodes. If NodeStat, use a monotonic linear interpolation defined between min_node_size and max_node_size. By default, 15. + node_fc_cmap : colormap + Colormap for mapping node colors. By default, "Reds". Ignored, if `node_fc` is + a str (single color). + lines_fc : str, dict, iterable, optional + Color of the hyperedges (lines). If str, use the same color for all hyperedges. + If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable, + assume the colors are specified in the same order as the hyperedges are found + in DH.edges. If None (default), use the size of the hyperedges. + lines_lw : int, float, dict, iterable, optional + Line width of the hyperedges (lines). If int or float, use the same width for + all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other + iterable, assume the widths are specified in the same order as the hyperedges + are found in DH.edges. By default, 1.5. + line_head_width : float, optional + Length of arrows' heads. By default, 0.05 edge_marker_toggle: bool, optional If True then marker representing the hyperedges are drawn. By default True. edge_marker_fc: str, dict, iterable, optional @@ -1691,14 +1706,11 @@ def draw_dihypergraph( * max_node_size * min_node_lw * max_node_lw - * node_fc_cmap - * node_ec_cmap * min_lines_lw * max_lines_lw * lines_fc_cmap * edge_fc_cmap * edge_marker_fc_cmap - * edge_marker_ec_cmap Returns ------- @@ -1719,36 +1731,39 @@ def draw_dihypergraph( if not isinstance(DH, DiHypergraph): raise XGIError("The input must be a DiHypergraph") - if settings is None: - settings = { - "min_node_size": 10.0, - "max_node_size": 30.0, - "min_node_lw": 1.0, - "max_node_lw": 5.0, - "node_fc_cmap": cm.Reds, - "node_ec_cmap": cm.Greys, - "min_lines_lw": 2.0, - "max_lines_lw": 10.0, - "lines_fc_cmap": cm.Blues, - "edge_marker_fc_cmap": cm.Blues, - "edge_marker_ec_cmap": cm.Greys, - } + settings = { + "min_node_size": 5, + "max_node_size": 30, + "min_node_lw": 0, + "max_node_lw": 5, + "min_lines_lw": 2.0, + "max_lines_lw": 10.0, + "lines_fc_cmap": plt.cm.Blues, + "edge_marker_fc_cmap": plt.cm.Blues, + } settings.update(kwargs) - if ax is None: - ax = plt.gca() - - ax.get_xaxis().set_ticks([]) - ax.get_yaxis().set_ticks([]) - ax.axis("off") - # convert to hypergraph in order to use the augmented projection function H_conv = convert.convert_to_hypergraph(DH) + ( + ax, + _, + ) = _draw_init(H_conv, ax, True) + if not max_order: max_order = max_edge_order(H_conv) + # convert all formats to ndarray + node_size = _draw_arg_to_arr(node_size) + + # interpolate if needed + if rescale_sizes and isinstance(node_size, np.ndarray): + node_size = _interp_draw_arg( + node_size, settings["min_node_size"], settings["max_node_size"] + ) + lines_lw = _scalar_arg_to_dict( lines_lw, H_conv.edges, settings["min_lines_lw"], settings["max_lines_lw"] ) @@ -1765,17 +1780,6 @@ def draw_dihypergraph( edge_marker_fc, H_conv.edges, settings["edge_marker_fc_cmap"] ) - if edge_marker_ec is None: - edge_marker_ec = H_conv.edges.size - - edge_marker_ec = _color_arg_to_dict( - edge_marker_ec, H_conv.edges, settings["edge_marker_ec_cmap"] - ) - - node_size = _scalar_arg_to_dict( - node_size, H_conv.nodes, settings["min_node_size"], settings["max_node_size"] - ) - G_aug = _augmented_projection(H_conv) for dyad in H_conv.edges.filterby("size", 2).members(): try: @@ -1810,8 +1814,13 @@ def draw_dihypergraph( # the following to avoid the point of the arrow overlapping the node distance = np.hypot(dx, dy) direction_vector = np.array([dx, dy]) / distance + size = ( + node_size + if not isinstance(node_size, np.ndarray) + else node_size[node] + ) shortened_distance = ( - distance - node_size[node] * 0.003 + distance - size * 0.003 ) # Calculate the shortened length dx = direction_vector[0] * shortened_distance dy = direction_vector[1] * shortened_distance @@ -1834,7 +1843,7 @@ def draw_dihypergraph( marker=edge_marker, s=edge_marker_size**2, c=edge_marker_fc[id], - edgecolors=edge_marker_ec[id], + edgecolors=edge_marker_ec, linewidths=edge_marker_lw, zorder=max_order, ) @@ -1852,7 +1861,7 @@ def draw_dihypergraph( label_kwds["font_size_edges"] = 6 draw_hyperedge_labels(H_conv, pos, hyperedge_labels, ax_edges=ax, **label_kwds) - draw_nodes( + ax, node_collection = draw_nodes( H=H_conv, pos=pos, ax=ax, @@ -1860,9 +1869,11 @@ def draw_dihypergraph( node_ec=node_ec, node_lw=node_lw, node_size=node_size, + # node_shape=node_shape, zorder=max_order, params=settings, node_labels=node_labels, + # rescale_sizes=rescale_sizes, **kwargs, )