Skip to content

Commit

Permalink
Support DiHypergraph in from_bipartite_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
colltoaction committed Nov 25, 2024
1 parent 71b897f commit 16c2c0e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
16 changes: 16 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,22 @@ def bipartite_graph4():
return G


@pytest.fixture
def bipartite_dihypergraph1():
G = nx.Graph()
G.graph["network-type"] = "directed"
G.add_nodes_from([1, 2, 3, 4], bipartite=0)
G.add_nodes_from(["a", "b", "c"], bipartite=1)
G.add_edges_from([
(1, "a"),
(1, "b", {"direction": "tail"}),
(2, "b"),
(2, "c", {"direction": "tail"}),
(3, "c"),
(4, "a", {"direction": "tail"})])
return G


@pytest.fixture
def attr0():
return {"color": "brown", "name": "camel"}
Expand Down
15 changes: 15 additions & 0 deletions tests/convert/test_bipartite_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,18 @@ def test_from_bipartite_graph(
# not bipartite
with pytest.raises(XGIError):
H = xgi.from_bipartite_graph(bipartite_graph4, dual=True)


def test_from_bipartite_graph_to_dihypergraph(
bipartite_dihypergraph1
):
H = xgi.from_bipartite_graph(bipartite_dihypergraph1)

assert set(H.nodes) == {1, 2, 3, 4}
assert set(H.edges) == {"a", "b", "c"}
assert H.edges.head("a") == {1}
assert H.edges.tail("a") == {4}
assert H.edges.head("b") == {2}
assert H.edges.tail("b") == {1}
assert H.edges.head("c") == {3}
assert H.edges.tail("c") == {2}
19 changes: 15 additions & 4 deletions xgi/convert/bipartite_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from networkx import bipartite

from ..exception import XGIError
from ..generators import empty_hypergraph
from ..generators import empty_hypergraph, empty_dihypergraph

__all__ = ["from_bipartite_graph", "to_bipartite_graph"]

Expand Down Expand Up @@ -76,11 +76,22 @@ def from_bipartite_graph(G, create_using=None, dual=False):
if not bipartite.is_bipartite_node_set(G, nodes):
raise XGIError("The network is not bipartite")

H = empty_hypergraph(create_using)
network_type = G.graph.get("network-type")
if network_type == "directed":
H = empty_dihypergraph(create_using)
else:
H = empty_hypergraph(create_using)

H.add_nodes_from(nodes)
for edge in edges:
nodes_in_edge = list(G.neighbors(edge))
H.add_edge(nodes_in_edge, idx=edge)
for u, v, d in G.edges(edge, data="direction"):
if network_type == "directed":
if d == "tail":
H.add_node_to_edge(u, v, direction="in")
else:
H.add_node_to_edge(u, v, direction="out")
else:
H.add_node_to_edge(u, v)
return H.dual() if dual else H


Expand Down

0 comments on commit 16c2c0e

Please sign in to comment.