Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ignore_nan argument to concordance_cc() #43

Merged
merged 15 commits into from
May 22, 2023
47 changes: 35 additions & 12 deletions audmetric/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def accuracy(
def concordance_cc(
truth: typing.Sequence[float],
prediction: typing.Sequence[float],
*,
ignore_nan: bool = False,
) -> float:
r"""Concordance correlation coefficient.

Expand All @@ -92,6 +94,10 @@ def concordance_cc(
Args:
truth: ground truth values
prediction: predicted values
ignore_nan: if ``True``
all samples that contain ``NaN``
in ``truth`` or ``prediction``
are ignored

Returns:
concordance correlation coefficient :math:`\in [-1, 1]`
Expand All @@ -101,7 +107,7 @@ def concordance_cc(

Examples:
>>> concordance_cc([0, 1, 2], [0, 1, 1])
0.6666666666666666
0.6666666666666665

"""
assert_equal_length(truth, prediction)
Expand All @@ -114,20 +120,37 @@ def concordance_cc(
if len(prediction) < 2:
return np.NaN

r = pearson_cc(prediction, truth)
x_mean = prediction.mean()
y_mean = truth.mean()
x_std = prediction.std()
y_std = truth.std()
denominator = (
x_std * x_std
+ y_std * y_std
+ (x_mean - y_mean) * (x_mean - y_mean)
)
# Handle the masked NaN case separately
# to be as fast as possible
consider_nan = False
if ignore_nan:
mask = ~(np.isnan(truth) | np.isnan(prediction))
length = mask.sum()
if length < prediction.size:
consider_nan = True
# Replace NaN values,
# otherwise mask * x would return NaN
prediction[~mask] = 0
truth[~mask] = 0
mean_y = np.sum(truth) / length
mean_x = np.sum(prediction) / length
a = mask * (prediction - mean_x)
b = mask * (truth - mean_y)

if not consider_nan:
length = prediction.size
mean_y = np.mean(truth)
mean_x = np.mean(prediction)
a = prediction - mean_x
b = truth - mean_y

numerator = 2 * np.dot(a, b)
denominator = np.dot(a, a) + np.dot(b, b) + length * (mean_x - mean_y) ** 2

if denominator == 0:
ccc = np.nan
else:
ccc = 2 * r * x_std * y_std / denominator
ccc = numerator / denominator

return float(ccc)

Expand Down
51 changes: 36 additions & 15 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,30 +265,51 @@ def test_event_error_rate(truth, prediction, eer):
])
def test_concordance_cc(truth, prediction):

def expected(truth, prediction):
if len(prediction) < 2:
ccc_expected = np.NaN
else:
denominator = (
prediction.std() ** 2
+ truth.std() ** 2
+ (prediction.mean() - truth.mean()) ** 2
)
if denominator == 0:
ccc_expected = np.NaN
else:
r = np.corrcoef(list(prediction), list(truth))[0][1]
ccc_expected = (
2 * r * prediction.std() * truth.std()
/ denominator
)
return ccc_expected

ccc = audmetric.concordance_cc(truth, prediction)

prediction = np.array(list(prediction))
truth = np.array(list(truth))

if len(prediction) < 2:
ccc_expected = np.NaN
else:
denominator = (
prediction.std() ** 2
+ truth.std() ** 2
+ (prediction.mean() - truth.mean()) ** 2
)
if denominator == 0:
ccc_expected = np.NaN
else:
r = np.corrcoef(list(prediction), list(truth))[0][1]
ccc_expected = 2 * r * prediction.std() * truth.std() / denominator

np.testing.assert_almost_equal(
ccc,
ccc_expected,
expected(truth, prediction),
)

# Handle NaN in prediction and truth
frankenjoe marked this conversation as resolved.
Show resolved Hide resolved
frankenjoe marked this conversation as resolved.
Show resolved Hide resolved
if len(prediction) > 1:
prediction = prediction.astype('float')
truth = truth.astype('float')
prediction[0] = np.NaN
truth[-1] = np.NaN

ccc = audmetric.concordance_cc(truth, prediction)
assert np.isnan(ccc)

ccc = audmetric.concordance_cc(truth, prediction, ignore_nan=True)
np.testing.assert_almost_equal(
ccc,
expected(truth[1:-1], prediction[1:-1]),
)


@pytest.mark.parametrize('class_range,num_elements,to_string,percentage', [
([0, 10], 5, False, False),
Expand Down