From 031a7e1a6e87e25485971976463e3014fbdb2edd Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Wed, 15 Dec 2021 15:07:34 +0100 Subject: [PATCH 1/2] Add support for NaN in utils.infer_labels() --- audmetric/core/utils.py | 9 +++++++++ tests/test_utils.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/audmetric/core/utils.py b/audmetric/core/utils.py index 1d6726b..ab99221 100644 --- a/audmetric/core/utils.py +++ b/audmetric/core/utils.py @@ -32,6 +32,8 @@ def infer_labels( labels in sorted order """ + truth = _remove_nan(truth) + prediction = _remove_nan(prediction) return sorted(list(set(truth) | set(prediction))) @@ -104,3 +106,10 @@ def scores_per_subgroup_and_class( zero_division=zero_division, ) return score + + +def _remove_nan(sequence): + return [ + s for s in sequence + if not isinstance(s, float) or not np.isnan(s) + ] diff --git a/tests/test_utils.py b/tests/test_utils.py index 083c136..c94ee5b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,46 @@ import audmetric +@pytest.mark.parametrize( + 'truth, prediction, expected_labels', + [ + ( + ['a', 'b'], + ['c', 'd'], + ['a', 'b', 'c', 'd'], + ), + ( + ['d', 'c'], + ['b', 'a'], + ['a', 'b', 'c', 'd'], + ), + ( + ['a', 'a'], + ['a', 'a'], + ['a'], + ), + ( + [0, 1, 2], + [2, 2, 2], + [0, 1, 2], + ), + ( + [0, 1], + [np.NaN, 1], + [0, 1], + ), + ( + ['a', 'b'], + [np.NaN, 'c'], + ['a', 'b', 'c'], + ), + ] +) +def test_infer_labels(truth, prediction, expected_labels): + labels = audmetric.utils.infer_labels(truth, prediction) + assert expected_labels == labels + + @pytest.mark.parametrize( 'truth,prediction,protected_variable,metric,labels,subgroups,' 'zero_division,expected', From 513244dc1249a5ef3c86bd39b5aced02a0e8e7e1 Mon Sep 17 00:00:00 2001 From: Hagen Wierstorf Date: Wed, 15 Dec 2021 15:08:11 +0100 Subject: [PATCH 2/2] Add support for mixed types in infer_labels() --- audmetric/core/utils.py | 2 +- tests/test_utils.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/audmetric/core/utils.py b/audmetric/core/utils.py index ab99221..5ef6061 100644 --- a/audmetric/core/utils.py +++ b/audmetric/core/utils.py @@ -34,7 +34,7 @@ def infer_labels( """ truth = _remove_nan(truth) prediction = _remove_nan(prediction) - return sorted(list(set(truth) | set(prediction))) + return sorted(list(set(truth) | set(prediction)), key=str) def scores_per_subgroup_and_class( diff --git a/tests/test_utils.py b/tests/test_utils.py index c94ee5b..c75af5d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -27,6 +27,11 @@ [2, 2, 2], [0, 1, 2], ), + ( + ['a', 1], + ['a', 0], + [0, 1, 'a'], + ), ( [0, 1], [np.NaN, 1],