From 139a2bc5f1cff7f77090d1e7b882821833b7fe0b Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Sun, 20 Oct 2024 13:21:12 -0400 Subject: [PATCH] Change aggregate stats type and added `unique` method. (#603) * added the `unique()` method to stats * fix broken test * Update test_core_stats_functions.py --- tests/stats/test_core_stats_functions.py | 19 +++++++++++++++++++ xgi/stats/__init__.py | 21 +++++++++++++-------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/tests/stats/test_core_stats_functions.py b/tests/stats/test_core_stats_functions.py index 919cb45e1..d11f03602 100644 --- a/tests/stats/test_core_stats_functions.py +++ b/tests/stats/test_core_stats_functions.py @@ -13,6 +13,7 @@ ### General functionality +import numpy as np import pandas as pd import pytest @@ -422,6 +423,7 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.nodes.degree.mean(), 3) == 1.125 assert round(H.nodes.degree.std(), 3) == 0.331 assert round(H.nodes.degree.var(), 3) == 0.109 + assert np.allclose(H.nodes.degree.unique(), np.array([1, 2])) assert H.edges.order.max() == 2 assert H.edges.order.min() == 0 @@ -429,6 +431,7 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.edges.order.mean(), 3) == 1.25 assert round(H.edges.order.std(), 3) == 0.829 assert round(H.edges.order.var(), 3) == 0.688 + assert np.allclose(H.edges.order.unique(), np.array([0, 1, 2])) H = xgi.Hypergraph(edgelist2) assert H.nodes.degree.max() == 2 @@ -442,6 +445,7 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.nodes.degree.mean(), 3) == 1.167 assert round(H.nodes.degree.std(), 3) == 0.373 assert round(H.nodes.degree.var(), 3) == 0.139 + assert np.allclose(H.nodes.degree.unique(), np.array([1, 2])) assert H.edges.order.max() == 2 assert H.edges.order.min() == 1 @@ -449,6 +453,9 @@ def test_hypergraph_aggregates(edgelist1, edgelist2, edgelist8): assert round(H.edges.order.mean(), 3) == 1.333 assert round(H.edges.order.std(), 3) == 0.471 assert round(H.edges.order.var(), 3) == 0.222 + assert np.allclose(H.edges.order.unique(), np.array([1, 2])) + assert len(H.edges.order.unique(return_counts=True)) == 2 + assert np.allclose(H.edges.order.unique(return_counts=True)[1], np.array([2, 1])) H = xgi.Hypergraph(edgelist8) assert H.nodes.degree.max() == 6 @@ -973,3 +980,15 @@ def test_multi_with_attrs(hyperwithattrs): 5: [2, "blue"], } assert multi.asdict(list) == d + + +def test_aggregate_stats_types(edgelist1): + H = xgi.Hypergraph(edgelist1) + assert isinstance(H.nodes.degree.max(), int) + assert isinstance(H.nodes.degree.min(), int) + assert isinstance(H.nodes.degree.median(), float) + assert isinstance(H.nodes.degree.mean(), float) + assert isinstance(H.nodes.degree.sum(), int) + assert isinstance(H.nodes.degree.std(), float) + assert isinstance(H.nodes.degree.var(), float) + assert isinstance(H.nodes.degree.moment(), float) \ No newline at end of file diff --git a/xgi/stats/__init__.py b/xgi/stats/__init__.py index 042573fb2..92d3cd45b 100644 --- a/xgi/stats/__init__.py +++ b/xgi/stats/__init__.py @@ -203,23 +203,23 @@ def ashist(self, bins=10, bin_edges=False, density=False, log_binning=False): def max(self): """The maximum value of this stat.""" - return self.asnumpy().max(axis=0) + return self.asnumpy().max(axis=0).item() def min(self): """The minimum value of this stat.""" - return self.asnumpy().min(axis=0) + return self.asnumpy().min(axis=0).item() def sum(self): """The sum of this stat.""" - return self.asnumpy().sum(axis=0) + return self.asnumpy().sum(axis=0).item() def mean(self): """The arithmetic mean of this stat.""" - return self.asnumpy().mean(axis=0) + return self.asnumpy().mean(axis=0).item() def median(self): """The median of this stat.""" - return np.median(self.asnumpy(), axis=0) + return np.median(self.asnumpy(), axis=0).item() def std(self): """The standard deviation of this stat. @@ -231,7 +231,7 @@ def std(self): See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/ for more details. """ - return self.asnumpy().std(axis=0) + return self.asnumpy().std(axis=0).item() def var(self): """The variance of this stat. @@ -243,7 +243,7 @@ def var(self): See https://www.allendowney.com/blog/2024/06/08/which-standard-deviation/ for more details. """ - return self.asnumpy().var(axis=0) + return self.asnumpy().var(axis=0).item() def moment(self, order=2, center=False): """The statistical moments of this stat. @@ -257,7 +257,9 @@ def moment(self, order=2, center=False): """ arr = self.asnumpy() - return spmoment(arr, moment=order) if center else np.mean(arr**order) + return ( + spmoment(arr, moment=order) if center else np.mean(arr**order).item() + ) def argmin(self): """The ID corresponding to the minimum of the stat @@ -306,6 +308,9 @@ def argsort(self, reverse=False): d = self.asdict() return sorted(d, key=d.get, reverse=reverse) + def unique(self, return_counts=False): + return np.unique(self.asnumpy(), return_counts=return_counts) + class NodeStat(IDStat): """An arbitrary node-quantity mapping.