From 4dc1f702b4489c5cb8bf599a90eb87a862bd6500 Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Sat, 19 Oct 2024 15:30:00 -0400 Subject: [PATCH 1/3] added the `unique()` method to stats --- tests/stats/test_core_stats_functions.py | 7 +++++++ xgi/stats/__init__.py | 21 +++++++++++++-------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/tests/stats/test_core_stats_functions.py b/tests/stats/test_core_stats_functions.py index 919cb45e1..60b4c638d 100644 --- a/tests/stats/test_core_stats_functions.py +++ b/tests/stats/test_core_stats_functions.py @@ -13,6 +13,7 @@ ### General functionality +import numpy as np import pandas as pd import pytest @@ -422,6 +423,7 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.nodes.degree.mean(), 3) == 1.125 assert round(H.nodes.degree.std(), 3) == 0.331 assert round(H.nodes.degree.var(), 3) == 0.109 + assert np.allclose(H.nodes.degree.unique(), np.array([1, 2])) assert H.edges.order.max() == 2 assert H.edges.order.min() == 0 @@ -429,6 +431,7 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.edges.order.mean(), 3) == 1.25 assert round(H.edges.order.std(), 3) == 0.829 assert round(H.edges.order.var(), 3) == 0.688 + assert np.allclose(H.edges.order.unique(), np.array([0, 1, 2])) H = xgi.Hypergraph(edgelist2) assert H.nodes.degree.max() == 2 @@ -442,6 +445,7 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.nodes.degree.mean(), 3) == 1.167 assert round(H.nodes.degree.std(), 3) == 0.373 assert round(H.nodes.degree.var(), 3) == 0.139 + assert np.allclose(H.nodes.degree.unique(), np.array([1, 2])) assert H.edges.order.max() == 2 assert H.edges.order.min() == 1 @@ -449,6 +453,9 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.edges.order.mean(), 3) == 1.333 assert round(H.edges.order.std(), 3) == 0.471 assert round(H.edges.order.var(), 3) == 0.222 + assert np.allclose(H.edges.order.unique(), np.array([1, 2])) + assert len(H.edges.order.unique(return_counts=True)) == 2 + assert np.allclose(H.edges.order.unique(return_counts=True)[1], np.array([2, 1])) H = xgi.Hypergraph(edgelist8) assert H.nodes.degree.max() == 6 diff --git a/xgi/stats/__init__.py b/xgi/stats/__init__.py index 042573fb2..748e7a085 100644 --- a/xgi/stats/__init__.py +++ b/xgi/stats/__init__.py @@ -203,23 +203,23 @@ def ashist(self, bins=10, bin_edges=False, density=False, log_binning=False): def max(self): """The maximum value of this stat.""" - return self.asnumpy().max(axis=0) + return float(self.asnumpy().max(axis=0)) def min(self): """The minimum value of this stat.""" - return self.asnumpy().min(axis=0) + return float(self.asnumpy().min(axis=0)) def sum(self): """The sum of this stat.""" - return self.asnumpy().sum(axis=0) + return float(self.asnumpy().sum(axis=0)) def mean(self): """The arithmetic mean of this stat.""" - return self.asnumpy().mean(axis=0) + return float(self.asnumpy().mean(axis=0)) def median(self): """The median of this stat.""" - return np.median(self.asnumpy(), axis=0) + return float(np.median(self.asnumpy(), axis=0)) def std(self): """The standard deviation of this stat. @@ -231,7 +231,7 @@ def std(self): See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/ for more details. """ - return self.asnumpy().std(axis=0) + return float(self.asnumpy().std(axis=0)) def var(self): """The variance of this stat. @@ -243,7 +243,7 @@ def var(self): See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/ for more details. """ - return self.asnumpy().var(axis=0) + return float(self.asnumpy().var(axis=0)) def moment(self, order=2, center=False): """The statistical moments of this stat. @@ -257,7 +257,9 @@ def moment(self, order=2, center=False): """ arr = self.asnumpy() - return spmoment(arr, moment=order) if center else np.mean(arr**order) + return ( + float(spmoment(arr, moment=order)) if center else float(np.mean(arr**order)) + ) def argmin(self): """The ID corresponding to the minimum of the stat @@ -306,6 +308,9 @@ def argsort(self, reverse=False): d = self.asdict() return sorted(d, key=d.get, reverse=reverse) + def unique(self, return_counts=False): + return np.unique(self.asnumpy(), return_counts=return_counts) + class NodeStat(IDStat): """An arbitrary node-quantity mapping. From 6ceba6e6fbda3f53f4f234f1570e7e8b8d6bb98e Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Sat, 19 Oct 2024 15:59:59 -0400 Subject: [PATCH 2/3] fix broken test --- xgi/stats/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xgi/stats/__init__.py b/xgi/stats/__init__.py index 748e7a085..92d3cd45b 100644 --- a/xgi/stats/__init__.py +++ b/xgi/stats/__init__.py @@ -203,23 +203,23 @@ def ashist(self, bins=10, bin_edges=False, density=False, log_binning=False): def max(self): """The maximum value of this stat.""" - return float(self.asnumpy().max(axis=0)) + return self.asnumpy().max(axis=0).item() def min(self): """The minimum value of this stat.""" - return float(self.asnumpy().min(axis=0)) + return self.asnumpy().min(axis=0).item() def sum(self): """The sum of this stat.""" - return float(self.asnumpy().sum(axis=0)) + return self.asnumpy().sum(axis=0).item() def mean(self): """The arithmetic mean of this stat.""" - return float(self.asnumpy().mean(axis=0)) + return self.asnumpy().mean(axis=0).item() def median(self): """The median of this stat.""" - return float(np.median(self.asnumpy(), axis=0)) + return np.median(self.asnumpy(), axis=0).item() def std(self): """The standard deviation of this stat. @@ -231,7 +231,7 @@ def std(self): See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/ for more details. """ - return float(self.asnumpy().std(axis=0)) + return self.asnumpy().std(axis=0).item() def var(self): """The variance of this stat. @@ -243,7 +243,7 @@ def var(self): See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/ for more details. """ - return float(self.asnumpy().var(axis=0)) + return self.asnumpy().var(axis=0).item() def moment(self, order=2, center=False): """The statistical moments of this stat. @@ -258,7 +258,7 @@ def moment(self, order=2, center=False): """ arr = self.asnumpy() return ( - float(spmoment(arr, moment=order)) if center else float(np.mean(arr**order)) + spmoment(arr, moment=order) if center else np.mean(arr**order).item() ) def argmin(self): From 3e59dc493bb73ecddc59028147ca8bdf02280693 Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Sat, 19 Oct 2024 16:14:33 -0400 Subject: [PATCH 3/3] Update test_core_stats_functions.py --- tests/stats/test_core_stats_functions.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/stats/test_core_stats_functions.py b/tests/stats/test_core_stats_functions.py index 60b4c638d..d11f03602 100644 --- a/tests/stats/test_core_stats_functions.py +++ b/tests/stats/test_core_stats_functions.py @@ -980,3 +980,15 @@ def test_multi_with_attrs(hyperwithattrs): 5: [2, "blue"], } assert multi.asdict(list) == d + + +def test_aggregate_stats_types(edgelist1): + H = xgi.Hypergraph(edgelist1) + assert isinstance(H.nodes.degree.max(), int) + assert isinstance(H.nodes.degree.min(), int) + assert isinstance(H.nodes.degree.median(), float) + assert isinstance(H.nodes.degree.mean(), float) + assert isinstance(H.nodes.degree.sum(), int) + assert isinstance(H.nodes.degree.std(), float) + assert isinstance(H.nodes.degree.var(), float) + assert isinstance(H.nodes.degree.moment(), float) \ No newline at end of file