Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for event weights when calculating percentile cuts #225

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changes/225.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Weights can now be passed as an optional argument to ``cuts.calculate_percentile_cut``.
75 changes: 65 additions & 10 deletions pyirf/cuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,59 @@
from .binning import calculate_bin_indices, bin_center

__all__ = [
'calculate_percentile_cut',
'evaluate_binned_cut',
'compare_irf_cuts',
"calculate_percentile_cut",
"evaluate_binned_cut",
"compare_irf_cuts",
"weighted_quantile",
]


def weighted_quantile(values, weights, quantiles=0.5, interpolate=False):
    """
    Calculate weighted quantiles.

    Entries of ``values`` that are NaN are dropped, together with their
    weights, before the quantiles are computed.

    Parameters
    ----------
    values : np.array
        Sample values.
    weights : np.array
        Weight of each entry in ``values``.
        Has to be of same length as ``values``.
    quantiles : float or array[float], optional, between (0,1)
        By default 0.5, i.e. Median
    interpolate : bool, optional
        If True, interpolate linearly between the sorted values, placing
        each value at the midpoint of its share of the cumulative weight.
        If False, return the first sorted value whose cumulative weight
        reaches the requested fraction of the total weight.
        By default False.

    Returns
    -------
    float or array
        Quantile(s) of ``values``; an array if ``quantiles`` is array-like.

    Raises
    ------
    AssertionError
        If ``values`` and ``weights`` differ in length.
    """
    assert len(values) == len(weights), "values and weights must be of the same length"

    # asanyarray keeps ndarray subclasses intact while also accepting
    # plain sequences.
    values = np.asanyarray(values)
    weights = np.asanyarray(weights)

    # The mask has to be computed once, from the unfiltered values, and then
    # applied to both arrays. Deriving it again after ``values`` was already
    # filtered yields a mask of the wrong length for ``weights``.
    valid = ~np.isnan(values)
    values = values[valid]
    weights = weights[valid]

    order = values.argsort()
    sorted_weights = weights[order]
    sorted_values = values[order]
    cumulative_weights = sorted_weights.cumsum()
    total_weight = cumulative_weights[-1]

    if interpolate:
        # position of each value: center of its weight interval, normalized
        positions = (cumulative_weights - sorted_weights / 2) / total_weight
        return np.interp(quantiles, positions, sorted_values)

    return sorted_values[np.searchsorted(cumulative_weights, quantiles * total_weight)]


def calculate_percentile_cut(
values, bin_values, bins, fill_value, percentile=68, min_value=None, max_value=None,
smoothing=None, min_events=10,
values,
bin_values,
bins,
fill_value,
weights=None,
percentile=68,
min_value=None,
max_value=None,
smoothing=None,
min_events=10,
):
"""
Calculate cuts as the percentile of a given quantity in bins of another
Expand All @@ -31,6 +75,10 @@ def calculate_percentile_cut(
fill_value: float or quantity
Value for bins with less than ``min_events``,
must have same unit as values
weights: ``~numpy.ndarray``
Array containing the weight of each entry in ``values``.
The default value of None corresponds to equal weights.
Must be of same length as ``values``.
percentile: float
The percentile to calculate in each bin as a percentage,
i.e. 0 <= percentile <= 100.
Expand All @@ -44,9 +92,14 @@ def calculate_percentile_cut(
min_events: int
Bins with less events than this number are replaced with ``fill_value``
"""

if weights is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer it (and it's probably also faster) if the weights=None case would use the existing code instead of the weighted_quantile method with weights=1.

weights = np.ones(len(values))
else:
assert len(weights) == len(values)
# create a table to make use of groupby operations
# we use a normal table here to avoid astropy/astropy#13840
table = Table({"values": values}, copy=False)
table = Table({"values": values, "weights": weights}, copy=False)
unit = table["values"].unit

# make sure units match
Expand Down Expand Up @@ -77,17 +130,19 @@ def calculate_percentile_cut(
if n_events < min_events:
cut_table["cut"][bin_idx] = fill_value
else:
value = np.nanpercentile(group["values"], percentile)
value = weighted_quantile(
group["values"], group["weights"], percentile / 100, interpolate=True
)
if min_value is not None or max_value is not None:
value = np.clip(value, min_value, max_value)

cut_table["cut"].value[bin_idx] = value

if smoothing is not None:
cut_table['cut'].value[:] = gaussian_filter1d(
cut_table["cut"].value[:] = gaussian_filter1d(
cut_table["cut"].value,
smoothing,
mode='nearest',
mode="nearest",
)

return cut_table
Expand Down Expand Up @@ -119,7 +174,7 @@ def evaluate_binned_cut(values, bin_values, cut_table, op):
Must support vectorized application.
"""
if not isinstance(cut_table, QTable):
raise ValueError('cut_table needs to be an astropy.table.QTable')
raise ValueError("cut_table needs to be an astropy.table.QTable")

bins = np.append(cut_table["low"], cut_table["high"][-1])
bin_index, valid = calculate_bin_indices(bin_values, bins)
Expand Down
90 changes: 82 additions & 8 deletions pyirf/tests/test_cuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,31 @@ def events():
)


def test_weighted_quantile():
    """Check the weighted median for equal and unequal weights."""
    from pyirf.cuts import weighted_quantile

    vals = np.arange(1, 5)

    # (weights, interpolate, expected median)
    cases = [
        (np.ones(4), False, 2.0),
        (np.ones(4), True, 2.5),
        (np.array([1, 1, 1, 3]), False, 3.0),
        (np.array([1, 1, 1, 3]), True, 3.25),
    ]
    for case_weights, interpolate, expected in cases:
        median = weighted_quantile(
            vals, case_weights, quantiles=0.5, interpolate=interpolate
        )
        assert np.allclose(median, expected)

    # weights with a different length than the values must be rejected
    with pytest.raises(AssertionError):
        weighted_quantile(vals, np.ones(5))


def test_calculate_percentile_cuts():
from pyirf.cuts import calculate_percentile_cut

Expand Down Expand Up @@ -93,7 +118,9 @@ def test_calculate_percentile_cuts_smoothing():
bin_values = np.append(np.zeros(N), np.ones(N)) * u.m
bins = [-0.5, 0.5, 1.5] * u.m

cuts = calculate_percentile_cut(values, bin_values, bins, fill_value=np.nan, smoothing=1)
cuts = calculate_percentile_cut(
values, bin_values, bins, fill_value=np.nan, smoothing=1
)
assert np.all(cuts["low"] == bins[:-1])
assert np.all(cuts["high"] == bins[1:])

Expand All @@ -104,10 +131,50 @@ def test_calculate_percentile_cuts_smoothing():
)


def test_calculate_percentile_cuts_weights():
    """Check percentile cuts computed with per-event weights.

    In each bin the lower half of the sorted sample carries weight 68 and
    the upper half weight 32, so the cut at the default percentile is
    expected close to the center of each distribution (0 and 10).
    """
    from pyirf.cuts import calculate_percentile_cut

    np.random.seed(0)

    n_events = int(1e4)
    half = int(n_events / 2)

    # one sorted gaussian sample per bin; draw order matters for the seed
    first_sample = np.sort(norm(0, 1).rvs(size=n_events))
    second_sample = np.sort(norm(10, 1).rvs(size=n_events))

    values = np.append(first_sample, second_sample) * u.deg
    bin_values = np.append(np.zeros(n_events), np.ones(n_events)) * u.m

    weights_per_bin = np.concatenate((68 * np.ones(half), 32 * np.ones(half)))
    weights = np.append(weights_per_bin, weights_per_bin)

    # add some values outside of binning to test that under/overflow are ignored
    bin_values[10] = 5 * u.m
    bin_values[30] = -1 * u.m

    bins = [-0.5, 0.5, 1.5] * u.m

    cuts = calculate_percentile_cut(
        values, bin_values, bins, weights=weights, fill_value=np.nan * u.deg
    )

    assert np.all(cuts["low"] == bins[:-1])
    assert np.all(cuts["high"] == bins[1:])

    cut_values = cuts["cut"].to_value(u.deg)
    assert np.allclose(cut_values, [0, 10], atol=0.1)


def test_evaluate_binned_cut():
from pyirf.cuts import evaluate_binned_cut

cuts = QTable({"low": [0, 1], "high": [1, 2], "cut": [100, 1000],})
cuts = QTable(
{
"low": [0, 1],
"high": [1, 2],
"cut": [100, 1000],
}
)

survived = evaluate_binned_cut(
np.array([500, 1500, 50, 2000, 25, 800]),
Expand All @@ -119,7 +186,11 @@ def test_evaluate_binned_cut():

# test with quantity
cuts = QTable(
{"low": [0, 1] * u.TeV, "high": [1, 2] * u.TeV, "cut": [100, 1000] * u.m,}
{
"low": [0, 1] * u.TeV,
"high": [1, 2] * u.TeV,
"cut": [100, 1000] * u.m,
}
)

survived = evaluate_binned_cut(
Expand All @@ -135,6 +206,7 @@ def test_compare_irf_cuts():
"""Tests compare_irf_cuts."""

from pyirf.cuts import compare_irf_cuts

# first create some dummy cuts
enbins = np.logspace(-2, 3) * u.TeV
thcuts1 = np.linspace(0.5, 0.1) * u.deg
Expand All @@ -151,7 +223,7 @@ def test_compare_irf_cuts():


def test_calculate_percentile_cuts_table():
'''Test that calculate percentile cuts does not modify input table'''
"""Test that calculate percentile cuts does not modify input table"""
from pyirf.cuts import calculate_percentile_cut

np.random.seed(0)
Expand All @@ -160,10 +232,12 @@ def test_calculate_percentile_cuts_table():
dist2 = norm(10, 1)
N = int(1e4)

table = QTable({
"foo": np.append(dist1.rvs(size=N), dist2.rvs(size=N)) * u.deg,
"bar": np.append(np.zeros(N), np.ones(N)) * u.m,
})
table = QTable(
{
"foo": np.append(dist1.rvs(size=N), dist2.rvs(size=N)) * u.deg,
"bar": np.append(np.zeros(N), np.ones(N)) * u.m,
}
)

bins = [-0.5, 0.5, 1.5] * u.m
cuts = calculate_percentile_cut(
Expand Down