diff --git a/audmetric/core/utils.py b/audmetric/core/utils.py
index 1d6726b..5ef6061 100644
--- a/audmetric/core/utils.py
+++ b/audmetric/core/utils.py
@@ -32,7 +32,16 @@ def infer_labels(
         labels in sorted order
 
     """
-    return sorted(list(set(truth) | set(prediction)))
+    truth = _remove_nan(truth)
+    prediction = _remove_nan(prediction)
+    labels = set(truth) | set(prediction)
+    try:
+        # Keep the natural order for comparable labels,
+        # e.g. [2, 10] instead of the lexicographic [10, 2]
+        return sorted(labels)
+    except TypeError:
+        # Labels of mixed type (e.g. int and str) cannot be compared
+        return sorted(labels, key=str)
 
 
 def scores_per_subgroup_and_class(
@@ -104,3 +113,13 @@ def scores_per_subgroup_and_class(
             zero_division=zero_division,
         )
     return score
+
+
+def _remove_nan(sequence):
+    """Return the entries of ``sequence`` that are not a float NaN."""
+    # np.floating also covers NumPy scalars such as np.float32,
+    # which are not instances of the built-in float
+    return [
+        s for s in sequence
+        if not (isinstance(s, (float, np.floating)) and np.isnan(s))
+    ]
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 083c136..c75af5d 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -4,6 +4,61 @@
 import audmetric
 
 
+@pytest.mark.parametrize(
+    'truth, prediction, expected_labels',
+    [
+        (
+            ['a', 'b'],
+            ['c', 'd'],
+            ['a', 'b', 'c', 'd'],
+        ),
+        (
+            ['d', 'c'],
+            ['b', 'a'],
+            ['a', 'b', 'c', 'd'],
+        ),
+        (
+            ['a', 'a'],
+            ['a', 'a'],
+            ['a'],
+        ),
+        (
+            [0, 1, 2],
+            [2, 2, 2],
+            [0, 1, 2],
+        ),
+        (
+            [2, 10],
+            [1, 10],
+            [1, 2, 10],
+        ),
+        (
+            ['a', 1],
+            ['a', 0],
+            [0, 1, 'a'],
+        ),
+        (
+            [0, 1],
+            [np.nan, 1],
+            [0, 1],
+        ),
+        (
+            [0, 1],
+            [np.float32('nan'), 1],
+            [0, 1],
+        ),
+        (
+            ['a', 'b'],
+            [np.nan, 'c'],
+            ['a', 'b', 'c'],
+        ),
+    ]
+)
+def test_infer_labels(truth, prediction, expected_labels):
+    labels = audmetric.utils.infer_labels(truth, prediction)
+    assert expected_labels == labels
+
+
 @pytest.mark.parametrize(
     'truth,prediction,protected_variable,metric,labels,subgroups,'
     'zero_division,expected',