Skip to content

Commit

Permalink
fix: settings can now be changed with single values
Browse files Browse the repository at this point in the history
  • Loading branch information
maximelucas committed Oct 17, 2023
1 parent b14fd9e commit f178c23
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 56 deletions.
32 changes: 30 additions & 2 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,12 @@ def test_correct_number_of_collections_draw_multilayer(edgelist8):
def test_draw_dihypergraph(diedgelist2, edgelist8):
DH = xgi.DiHypergraph(diedgelist2)

fig, ax1 = plt.subplots()
fig1, ax1 = plt.subplots()
ax1 = xgi.draw_dihypergraph(DH, ax=ax1)
fig2, ax2 = plt.subplots()
ax2 = xgi.draw_dihypergraph(
DH, ax=ax2, node_fc="red", node_ec="blue", node_lw=2, node_size=20
)

# number of elements
assert len(ax1.lines) == 7 # number of source nodes
Expand All @@ -471,6 +475,30 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
DH.edges.filterby("size", 1)
) # hyperedges markers + nodes

# node face colors
assert np.all(
ax1.collections[-1].get_facecolor() == np.array([[1, 1, 1, 1]])
) # white
assert np.all(
ax2.collections[-1].get_facecolor() == np.array([[1, 0, 0, 1]])
) # red

# node edge colors
assert np.all(
ax1.collections[-1].get_edgecolor() == np.array([[0, 0, 0, 1]])
) # black
assert np.all(
ax2.collections[-1].get_edgecolor() == np.array([[0, 0, 1, 1]])
) # blue

# node_lw
assert np.all(ax1.collections[-1].get_linewidth() == np.array([1]))
assert np.all(ax2.collections[-1].get_linewidth() == np.array([2]))

# node_size
assert np.all(ax1.collections[-1].get_sizes() == np.array([15**2]))
assert np.all(ax2.collections[-1].get_sizes() == np.array([20**2]))

# zorder
for line, z in zip(ax1.lines, [1, 1, 1, 1, 0, 0, 0]): # lines for source nodes
assert line.get_zorder() == z
Expand All @@ -479,7 +507,7 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
for collection in ax1.collections:
assert collection.get_zorder() == 3 # nodes and hyperedges markers

plt.close()
plt.close("all")

# test toggle for edges
fig, ax2 = plt.subplots()
Expand Down
119 changes: 65 additions & 54 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,14 +1074,17 @@ def _draw_hull(node_pos, ax, edges_ec, facecolor, alpha, zorder, radius):
Parameters
----------
node_pos : np.array
nx2 dimensional array containing positions of the nodes
Array of dimension (n, 2) containing node positions
ax : matplotlib.pyplot.axes
Axis to plot on
edges_ec : str
Color of the border of the convex hull
facecolor : str
Filling color of the convex hull
alpha : float
Transparency of the convex hull
zorder : float
Vertical order on which to plot
radius : float
Radius of the convex hull in the vicinity of the nodes.
Expand Down Expand Up @@ -1433,6 +1436,13 @@ def draw_multilayer(
-------
ax : matplotlib Axes3DSubplot
The subplot with the multilayer network visualization.
Notes
-----
The effect of the `sep` parameter is limited by the `height` of the figure.
If `sep` is larger than a certain value depending on `height`, no additional
effect will be seen.
"""
settings = {
"min_node_size": 10.0,
Expand Down Expand Up @@ -1602,13 +1612,14 @@ def draw_multilayer(
def draw_dihypergraph(
DH,
ax=None,
lines_fc=None,
lines_lw=1.5,
line_head_width=0.05,
node_fc="white",
node_ec="black",
node_lw=1,
node_size=15,
node_fc_cmap="Reds",
lines_fc=None,
lines_lw=1.5,
line_head_width=0.05,
edge_marker_toggle=True,
edge_marker_fc=None,
edge_marker_ec=None,
Expand All @@ -1619,6 +1630,7 @@ def draw_dihypergraph(
node_labels=False,
hyperedge_labels=False,
settings=None,
rescale_sizes=True,
**kwargs,
):
"""Draw a directed hypergraph
Expand All @@ -1629,18 +1641,6 @@ def draw_dihypergraph(
The directed hypergraph to draw.
ax : matplotlib.pyplot.axes, optional
Axis to draw on. If None (default), get the current axes.
lines_fc : str, dict, iterable, optional
Color of the hyperedges (lines). If str, use the same color for all hyperedges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
lines_lw : int, float, dict, iterable, optional
Line width of the hyperedges (lines). If int or float, use the same width for
all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other
iterable, assume the widths are specified in the same order as the hyperedges
are found in DH.edges. By default, 1.5.
line_head_width : float, optional
Length of arrows' heads. By default, 0.05
node_fc : str, dict, iterable, or NodeStat, optional
Color of the nodes. If str, use the same color for all nodes. If a dict, must
contain (node_id: color_str) pairs. If other iterable, assume the colors are
Expand All @@ -1663,6 +1663,21 @@ def draw_dihypergraph(
the radiuses are specified in the same order as the nodes are found in
H.nodes. If NodeStat, use a monotonic linear interpolation defined between
min_node_size and max_node_size. By default, 15.
node_fc_cmap : colormap
Colormap for mapping node colors. By default, "Reds". Ignored, if `node_fc` is
a str (single color).
lines_fc : str, dict, iterable, optional
Color of the hyperedges (lines). If str, use the same color for all hyperedges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
lines_lw : int, float, dict, iterable, optional
Line width of the hyperedges (lines). If int or float, use the same width for
all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other
iterable, assume the widths are specified in the same order as the hyperedges
are found in DH.edges. By default, 1.5.
line_head_width : float, optional
Length of arrows' heads. By default, 0.05
edge_marker_toggle: bool, optional
If True then marker representing the hyperedges are drawn. By default True.
edge_marker_fc: str, dict, iterable, optional
Expand Down Expand Up @@ -1691,14 +1706,11 @@ def draw_dihypergraph(
* max_node_size
* min_node_lw
* max_node_lw
* node_fc_cmap
* node_ec_cmap
* min_lines_lw
* max_lines_lw
* lines_fc_cmap
* edge_fc_cmap
* edge_marker_fc_cmap
* edge_marker_ec_cmap
Returns
-------
Expand All @@ -1719,36 +1731,39 @@ def draw_dihypergraph(
if not isinstance(DH, DiHypergraph):
raise XGIError("The input must be a DiHypergraph")

if settings is None:
settings = {
"min_node_size": 10.0,
"max_node_size": 30.0,
"min_node_lw": 1.0,
"max_node_lw": 5.0,
"node_fc_cmap": cm.Reds,
"node_ec_cmap": cm.Greys,
"min_lines_lw": 2.0,
"max_lines_lw": 10.0,
"lines_fc_cmap": cm.Blues,
"edge_marker_fc_cmap": cm.Blues,
"edge_marker_ec_cmap": cm.Greys,
}
settings = {
"min_node_size": 5,
"max_node_size": 30,
"min_node_lw": 0,
"max_node_lw": 5,
"min_lines_lw": 2.0,
"max_lines_lw": 10.0,
"lines_fc_cmap": plt.cm.Blues,
"edge_marker_fc_cmap": plt.cm.Blues,
}

settings.update(kwargs)

if ax is None:
ax = plt.gca()

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")

# convert to hypergraph in order to use the augmented projection function
H_conv = convert.convert_to_hypergraph(DH)

(
ax,
_,
) = _draw_init(H_conv, ax, True)

if not max_order:
max_order = max_edge_order(H_conv)

# convert all formats to ndarray
node_size = _draw_arg_to_arr(node_size)

# interpolate if needed
if rescale_sizes and isinstance(node_size, np.ndarray):
node_size = _interp_draw_arg(
node_size, settings["min_node_size"], settings["max_node_size"]
)

lines_lw = _scalar_arg_to_dict(
lines_lw, H_conv.edges, settings["min_lines_lw"], settings["max_lines_lw"]
)
Expand All @@ -1765,17 +1780,6 @@ def draw_dihypergraph(
edge_marker_fc, H_conv.edges, settings["edge_marker_fc_cmap"]
)

if edge_marker_ec is None:
edge_marker_ec = H_conv.edges.size

edge_marker_ec = _color_arg_to_dict(
edge_marker_ec, H_conv.edges, settings["edge_marker_ec_cmap"]
)

node_size = _scalar_arg_to_dict(
node_size, H_conv.nodes, settings["min_node_size"], settings["max_node_size"]
)

G_aug = _augmented_projection(H_conv)
for dyad in H_conv.edges.filterby("size", 2).members():
try:
Expand Down Expand Up @@ -1810,8 +1814,13 @@ def draw_dihypergraph(
# the following to avoid the point of the arrow overlapping the node
distance = np.hypot(dx, dy)
direction_vector = np.array([dx, dy]) / distance
size = (
node_size
if not isinstance(node_size, np.ndarray)
else node_size[node]
)
shortened_distance = (
distance - node_size[node] * 0.003
distance - size * 0.003
) # Calculate the shortened length
dx = direction_vector[0] * shortened_distance
dy = direction_vector[1] * shortened_distance
Expand All @@ -1834,7 +1843,7 @@ def draw_dihypergraph(
marker=edge_marker,
s=edge_marker_size**2,
c=edge_marker_fc[id],
edgecolors=edge_marker_ec[id],
edgecolors=edge_marker_ec,
linewidths=edge_marker_lw,
zorder=max_order,
)
Expand All @@ -1852,17 +1861,19 @@ def draw_dihypergraph(
label_kwds["font_size_edges"] = 6
draw_hyperedge_labels(H_conv, pos, hyperedge_labels, ax_edges=ax, **label_kwds)

draw_nodes(
ax, node_collection = draw_nodes(
H=H_conv,
pos=pos,
ax=ax,
node_fc=node_fc,
node_ec=node_ec,
node_lw=node_lw,
node_size=node_size,
# node_shape=node_shape,
zorder=max_order,
params=settings,
node_labels=node_labels,
# rescale_sizes=rescale_sizes,
**kwargs,
)

Expand Down

0 comments on commit f178c23

Please sign in to comment.