From cec3af984b8890c5f078d29f1d856ea2f0ba999d Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Tue, 10 Dec 2024 10:17:24 -0500 Subject: [PATCH] Fix issue with load_xgi_data HIF --- tests/readwrite/test_xgi_data.py | 6 ++++++ xgi/convert/hif_dict.py | 26 +++++++++++++------------- xgi/readwrite/xgi_data.py | 2 +- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/tests/readwrite/test_xgi_data.py b/tests/readwrite/test_xgi_data.py index 976c50e3..84b6e721 100644 --- a/tests/readwrite/test_xgi_data.py +++ b/tests/readwrite/test_xgi_data.py @@ -60,6 +60,12 @@ def test_load_xgi_data(capfd): assert collection["as-you-like-it"].num_nodes == 30 assert collection["as-you-like-it"].num_edges == 80 + # test HIF + H = load_xgi_data("recipe-rec") + assert H.num_nodes == 9271 + assert H.num_edges == 77733 + + @pytest.mark.skipif( sys.version_info != (3, 12) and not platform.system() == "Linux", diff --git a/xgi/convert/hif_dict.py b/xgi/convert/hif_dict.py index 0d288fb3..f137f057 100644 --- a/xgi/convert/hif_dict.py +++ b/xgi/convert/hif_dict.py @@ -108,13 +108,13 @@ def _convert_id(i, idtype): network_type = "undirected" if network_type in {"asc", "undirected"}: - G = Hypergraph() + H = Hypergraph() elif network_type == "directed": - G = DiHypergraph() + H = DiHypergraph() # Import network metadata if "metadata" in data: - G._net_attr.update(data["metadata"]) + H._net_attr.update(data["metadata"]) for record in data["incidences"]: n = _convert_id(record["node"], nodetype) @@ -123,9 +123,9 @@ def _convert_id(i, idtype): if network_type == "directed": d = record["direction"] d = _convert_d(d) # convert from head/tail to in/out - G.add_node_to_edge(e, n, d) + H.add_node_to_edge(e, n, d) else: - G.add_node_to_edge(e, n) + H.add_node_to_edge(e, n) # import node attributes if they exist if "nodes" in data: @@ -136,10 +136,10 @@ def _convert_id(i, idtype): else: attr = {} - if n not in G._node: - G.add_node(n, **attr) + if n not in H._node: + H.add_node(n, **attr) else: - G.set_node_attributes({n: attr}) + H.set_node_attributes({n: attr}) # import edge attributes if they exist if "edges" in data: @@ -149,11 +149,11 @@ def _convert_id(i, idtype): attr = record["attrs"] else: attr = {} - if e not in G._edge: - G.add_edge(_empty_edge(network_type), e, **attr) + if e not in H._edge: + H.add_edge(_empty_edge(network_type), e, **attr) else: - G.set_edge_attributes({e: attr}) + H.set_edge_attributes({e: attr}) if network_type == "asc": - G = SimplicialComplex(G) - return G + H = SimplicialComplex(H) + return H diff --git a/xgi/readwrite/xgi_data.py b/xgi/readwrite/xgi_data.py index e2922525..a35baa33 100644 --- a/xgi/readwrite/xgi_data.py +++ b/xgi/readwrite/xgi_data.py @@ -165,7 +165,7 @@ def _request_from_xgi_data( jsondata = request_json_from_url(url) if "incidences" in jsondata: - H = from_hif_dict(H, nodetype=nodetype, edgetype=edgetype) + H = from_hif_dict(jsondata, nodetype=nodetype, edgetype=edgetype) if max_order: H = cut_to_order(H, order=max_order) return H