Skip to content

Commit

Permalink
Add function for calculating quantiles of weighed samples. (#3340)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenZickel authored Mar 18, 2024
1 parent 8869834 commit 0474cc9
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
61 changes: 61 additions & 0 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import math
import numbers
from typing import List, Tuple, Union

import torch
from torch.fft import irfft, rfft
Expand Down Expand Up @@ -261,6 +262,66 @@ def quantile(input, probs, dim=0):
return quantiles if probs.shape != torch.Size([]) else quantiles.squeeze(dim)


def weighed_quantile(
input: torch.Tensor,
probs: Union[List[float], Tuple[float, ...], torch.Tensor],
log_weights: torch.Tensor,
dim: int = 0,
) -> torch.Tensor:
"""
Computes quantiles of weighed ``input`` samples at ``probs``.
:param torch.Tensor input: the input tensor.
:param list probs: quantile positions.
:param torch.Tensor log_weights: sample weights tensor.
:param int dim: dimension to take quantiles from ``input``.
:returns torch.Tensor: quantiles of ``input`` at ``probs``.
Example:
>>> from pyro.ops.stats import weighed_quantile
>>> import torch
>>> input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
>>> probs = torch.Tensor([0.2, 0.8])
>>> log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
>>> result = weighed_quantile(input, probs, log_weights, -1)
>>> torch.testing.assert_close(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))
"""
dim = dim if dim >= 0 else (len(input.shape) + dim)
if isinstance(probs, (list, tuple)):
probs = torch.tensor(probs, dtype=input.dtype, device=input.device)
assert isinstance(probs, torch.Tensor)
# Calculate normalized weights
weights = (log_weights - torch.logsumexp(log_weights, 0)).exp()
# Sort input and weights
sorted_input, sorting_indices = input.sort(dim)
weights = weights[sorting_indices].cumsum(dim)
# Scale weights to be between zero and one
weights = weights - weights.min(dim, keepdim=True)[0]
weights = weights / weights.max(dim, keepdim=True)[0]
# Calculate indices
indices_above = (
(weights[..., None] <= probs)
.sum(dim, keepdim=True)
.swapaxes(dim, -1)
.clamp(max=input.size(dim) - 1)[..., 0]
)
indices_below = (indices_above - 1).clamp(min=0)
# Calculate below and above qunatiles
quantiles_below = sorted_input.gather(dim, indices_below)
quantiles_above = sorted_input.gather(dim, indices_above)
# Calculate weights for below and above quantiles
probs_shape = [None] * dim + [slice(None)] + [None] * (len(input.shape) - dim - 1)
expanded_probs_shape = list(input.shape)
expanded_probs_shape[dim] = len(probs)
probs = probs[probs_shape].expand(*expanded_probs_shape)
weights_below = weights.gather(dim, indices_below)
weights_above = weights.gather(dim, indices_above)
weights_below = (weights_above - probs) / (weights_above - weights_below)
weights_above = 1 - weights_below
# Return quantiles
return weights_below * quantiles_below + weights_above * quantiles_above


def pi(input, prob, dim=0):
"""
Computes percentile interval which assigns equal probability mass
Expand Down
20 changes: 20 additions & 0 deletions tests/ops/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
resample,
split_gelman_rubin,
waic,
weighed_quantile,
)
from tests.common import assert_close, assert_equal, xfail_if_not_implemented

Expand Down Expand Up @@ -57,6 +58,25 @@ def test_quantile():
assert_equal(quantile(z, probs=0.8413), torch.tensor(1.0), prec=0.02)


@pytest.mark.init(rng_seed=3)
def test_weighed_quantile():
# Fixed values test
input = torch.Tensor([[10, 50, 40], [20, 30, 0]])
probs = [0.2, 0.8]
log_weights = torch.Tensor([0.4, 0.5, 0.1]).log()
result = weighed_quantile(input, probs, log_weights, -1)
assert_equal(result, torch.Tensor([[40.4, 47.6], [9.0, 26.4]]))

# Random values test
dist = torch.distributions.normal.Normal(0, 1)
input = dist.sample((100000,))
probs = [0.1, 0.7, 0.95]
log_weights = dist.log_prob(input)
result = weighed_quantile(input, probs, log_weights)
result_dist = torch.distributions.normal.Normal(0, torch.tensor(0.5).sqrt())
assert_equal(result, result_dist.icdf(torch.Tensor(probs)), prec=0.01)


def test_pi():
x = torch.randn(1000).exp()
assert_equal(pi(x, prob=0.8), quantile(x, probs=[0.1, 0.9]))
Expand Down

0 comments on commit 0474cc9

Please sign in to comment.