Skip to content

Commit

Permalink
Add Stable distribution with numerically integrated log-probability c…
Browse files Browse the repository at this point in the history
…alculation (StableWithLogProb). (#3369)
  • Loading branch information
BenZickel authored May 28, 2024
1 parent 7511353 commit 0678b35
Show file tree
Hide file tree
Showing 8 changed files with 688 additions and 75 deletions.
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,13 @@ Stable
:undoc-members:
:show-inheritance:

StableWithLogProb
-----------------
.. autoclass:: pyro.distributions.StableWithLogProb
    :members:
    :undoc-members:
    :show-inheritance:

TruncatedPolyaGamma
-------------------
.. autoclass:: pyro.distributions.TruncatedPolyaGamma
Expand Down
3 changes: 2 additions & 1 deletion pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
from pyro.distributions.sine_skewed import SineSkewed
from pyro.distributions.softlaplace import SoftLaplace
from pyro.distributions.spanning_tree import SpanningTree
from pyro.distributions.stable import Stable
from pyro.distributions.stable import Stable, StableWithLogProb
from pyro.distributions.torch import __all__ as torch_dists
from pyro.distributions.torch_distribution import (
ExpandedDistribution,
Expand Down Expand Up @@ -234,6 +234,7 @@
"SoftLaplace",
"SpanningTree",
"Stable",
"StableWithLogProb",
"StudentT",
"TorchDistribution",
"TransformModule",
Expand Down
22 changes: 22 additions & 0 deletions pyro/distributions/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all

from pyro.distributions.stable_log_prob import StableLogProb
from pyro.distributions.torch_distribution import TorchDistribution


Expand Down Expand Up @@ -204,3 +205,24 @@ def mean(self):
def variance(self):
var = self.scale * self.scale
return var.mul(2).masked_fill(self.stability < 2, math.inf)


class StableWithLogProb(StableLogProb, Stable):
    r"""
    Levy :math:`\alpha`-stable distribution that is based on
    :class:`Stable` but with an added method for calculating the
    log probability density using numerical integration.

    This should be used in cases where reparameterization does not work,
    for example when trying to estimate the skew :math:`\beta` parameter.
    Running times are slower than with reparameterization.

    The numerical integration implementation is based on the algorithm
    proposed by Chambers, Mallows and Stuck (CMS) for simulating the
    Levy :math:`\alpha`-stable distribution. The CMS algorithm involves a
    nonlinear transformation of two independent random variables into
    one stable random variable. The first random variable is uniformly
    distributed while the second is exponentially distributed. The numerical
    integration is performed over the first, uniformly distributed, random
    variable.
    """
220 changes: 220 additions & 0 deletions pyro/distributions/stable_log_prob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from functools import partial

import torch

# Tolerances controlling where the density is interpolated instead of being
# evaluated directly.
# Scales alpha / |1 - alpha| to give the half-width of the band around zero
# inside which values are interpolated.
value_near_zero_tolerance_alpha = 0.01
# Caps that half-width at this factor times the reciprocal of the density at
# zero.
value_near_zero_tolerance_density = 0.1
# Stability values within this distance of 1 are interpolated between
# evaluations at 1 - tolerance and 1 + tolerance.
alpha_near_one_tolerance = 0.05


# Clamp bounds applied to log-domain quantities in this module.
# NOTE(review): these are base-10 logarithms of the float64 limits, yet they
# bound values that are subsequently passed through the natural ``exp()``;
# the resulting range is narrower than float64 strictly needs - confirm this
# is intended before changing it.
finfo = torch.finfo(torch.float64)
MAX_LOG = math.log10(finfo.max)
MIN_LOG = math.log10(finfo.tiny)


def create_integrator(num_points):
    """Build a log-space Gauss-Legendre quadrature routine.

    :param int num_points: Number of Gauss-Legendre nodes to use.
    :return: A function ``integrate(fn, domain)``. ``domain`` is a tensor whose
        leading axis has two entries, the lower (``domain[0]``) and upper
        (``domain[1]``) integration limits. ``fn`` receives the evaluation
        points (node axis first) and must return the log of the integrand at
        those points. The function returns the log of the integral.
    """
    from scipy.special import roots_legendre

    roots, weights = roots_legendre(num_points)
    # Convert straight to float64. ``torch.Tensor(ndarray)`` builds a tensor of
    # the default tensor type (float32 unless reconfigured), which would round
    # the nodes and weights to single precision before any later upcast.
    roots = torch.as_tensor(roots, dtype=torch.float64)
    weights = torch.as_tensor(weights, dtype=torch.float64)
    log_weights = weights.log()
    half_roots = roots * 0.5

    def integrate(fn, domain):
        # Keep the node axis first and append one singleton axis per remaining
        # domain axis so that nodes broadcast against the integration limits.
        sl = (slice(None),) + (domain.dim() - 1) * (None,)
        half_roots_sl = half_roots[sl]
        # Affine map of the nodes from [-1, 1] onto [domain[0], domain[1]]
        value = domain[0] * (0.5 - half_roots_sl) + domain[1] * (0.5 + half_roots_sl)
        # Weighted sum in log space plus the log of the interval half-length
        return (
            torch.logsumexp(fn(value) + log_weights[sl], dim=0)
            + ((domain[1] - domain[0]) / 2).log()
        )

    return integrate


def set_integrator(num_points):
    """Install a new module-level ``integrate`` function.

    :param int num_points: Number of quadrature points passed on to
        ``create_integrator``.
    """
    global integrate
    replacement = create_integrator(num_points)
    integrate = replacement


# Placeholder that swaps itself for the default integrator on its first call,
# provided no integrator has been installed beforehand.
def integrate(*args, **kwargs):
    """Install the default 501-point integrator, then forward the call to it."""
    set_integrator(num_points=501)
    # The global name ``integrate`` now refers to the integrator set above
    result = integrate(*args, **kwargs)
    return result


class StableLogProb:
    """Mixin providing ``log_prob`` through numerical integration.

    The method reads ``loc``, ``scale``, ``stability``, ``skew`` and ``coords``
    from the instance.
    """

    def log_prob(self, value):
        """Return the log probability density evaluated at ``value``.

        :param value: Points at which to evaluate the density.
        :return: Log-density in the dtype of the standardized value.
        """
        # Standardize by removing the location and the scale
        standardized = (value - self.loc) / self.scale
        out_dtype = standardized.dtype

        # The density itself is evaluated in double precision
        log_density = _stable_log_prob(
            self.stability.double(),
            self.skew.double(),
            standardized.double(),
            self.coords,
        )

        # Cast back and account for the scaling of the variable
        return log_density.to(dtype=out_dtype) - self.scale.log()


def _stable_log_prob(alpha, beta, value, coords):
    """Log-density of a standardized stable variable, safe for alpha near one.

    :param alpha: Stability parameter.
    :param beta: Skew parameter.
    :param value: Standardized value (location and scale already removed).
    :param str coords: Either ``"S0"`` or ``"S"``.
    :return: Log-density at ``value``.
    :raises ValueError: If ``coords`` is neither ``"S0"`` nor ``"S"``.
    """
    # Convert to Nolan's parametrization S^0 where samples depend
    # continuously on (alpha,beta), allowing interpolation around the hole at
    # alpha=1.
    if coords == "S":
        # The tangent applies to the angle pi/2 * alpha only; this shift is the
        # inverse of the one performed in _unsafe_alpha_stable_log_prob_S0.
        value = torch.where(
            alpha == 1, value, value - beta * (math.pi / 2 * alpha).tan()
        )
    elif coords != "S0":
        raise ValueError("Unknown coords: {}".format(coords))

    # Find near one alpha
    # NOTE(review): the boolean mask below indexes alpha, beta and value alike,
    # so this assumes all three share one shape - TODO confirm against callers.
    idx = (alpha - 1).abs() < alpha_near_one_tolerance

    # Evaluate everywhere, moving near-one alphas to the upper edge of the band
    log_prob = _unsafe_alpha_stable_log_prob_S0(
        torch.where(idx, 1 + alpha_near_one_tolerance, alpha), beta, value
    )

    # Handle alpha near one by interpolation
    if idx.any():
        log_prob_pos = log_prob[idx]
        # Same points evaluated at the lower edge of the band
        log_prob_neg = _unsafe_alpha_stable_log_prob_S0(
            (1 - alpha_near_one_tolerance) * log_prob_pos.new_ones(log_prob_pos.shape),
            beta[idx],
            value[idx],
        )
        # Linear weight in [0, 1]: 0 at the lower edge, 1 at the upper edge
        weights = (alpha[idx] - 1) / (2 * alpha_near_one_tolerance) + 0.5
        # Mix the two densities (not the log-densities) in log space
        log_prob[idx] = torch.logsumexp(
            torch.stack(
                (log_prob_pos + weights.log(), log_prob_neg + (1 - weights).log()),
                dim=0,
            ),
            dim=0,
        )

    return log_prob


def _unsafe_alpha_stable_log_prob_S0(alpha, beta, Z):
    """Log-density of ``Z`` given in Nolan's parametrization S^0.

    Values close to zero are handled by interpolation. This will fail if
    ``alpha`` is close to 1.
    """
    # Leave the S^0 parametrization (in which samples depend continuously on
    # (alpha, beta)) by adding back the shift.
    Z = Z + beta * (math.pi / 2 * alpha).tan()

    # Half-width of the band around zero: grows with alpha / |1 - alpha| but
    # is capped relative to the reciprocal of the zero-skew density at zero.
    alpha_bound = value_near_zero_tolerance_alpha * alpha / (1 - alpha).abs()
    density_bound = (
        value_near_zero_tolerance_density
        * _unsafe_alpha_stable_log_prob_at_zero(alpha, 0).exp().reciprocal()
    )
    tolerance = alpha_bound.clamp(max=density_bound)
    near_zero = Z.abs() < tolerance

    # Evaluate everywhere, moving near-zero values to the upper band edge
    log_prob = _unsafe_stable_log_prob(
        alpha, beta, torch.where(near_zero, tolerance, Z)
    )

    if not near_zero.any():
        return log_prob

    # Interpolate the density linearly between the two band edges
    at_upper_edge = log_prob[near_zero]
    at_lower_edge = _unsafe_stable_log_prob(
        alpha[near_zero], beta[near_zero], -tolerance[near_zero]
    )
    frac = Z[near_zero] / (2 * tolerance[near_zero]) + 0.5
    mixture_terms = torch.stack(
        (at_upper_edge + frac.log(), at_lower_edge + (1 - frac).log()),
        dim=0,
    )
    log_prob[near_zero] = torch.logsumexp(mixture_terms, dim=0)

    return log_prob


def _unsafe_stable_log_prob(alpha, beta, Z):
    """Log-density of ``Z`` obtained by integrating over the uniform variable.

    This will fail if ``alpha`` is close to 1 or if ``Z`` is close to 0.
    """
    half_alpha_pi = math.pi / 2 * alpha
    skew_angle = (beta * half_alpha_pi.tan()).atan()
    lower_limit = -alpha.reciprocal() * skew_angle

    # Negative values are evaluated as positive ones with the skew (and the
    # lower integration limit) mirrored.
    negative = Z < 0
    beta = torch.where(negative, -beta, beta)
    lower_limit = torch.where(negative, -lower_limit, lower_limit)
    Z = torch.where(negative, -Z, Z)

    # Set integration domain: from the lower limit up to pi / 2
    upper_limit = 0.5 * math.pi * lower_limit.new_ones(lower_limit.shape)
    domain = torch.stack((lower_limit, upper_limit), dim=0)

    log_integrand = partial(
        _unsafe_stable_given_uniform_log_prob, alpha=alpha, beta=beta, Z=Z
    )

    return integrate(log_integrand, domain) - math.log(math.pi)


def _unsafe_stable_given_uniform_log_prob(V, alpha, beta, Z):
    """Log-density of ``Z`` conditioned on the uniform variable ``V``.

    This will fail if ``alpha`` is close to 1 or if ``Z`` is close to 0.
    The result is clamped to ``[MIN_LOG, MAX_LOG]``.
    """
    half_pi = math.pi / 2
    inv_alpha_m1 = (alpha - 1).reciprocal()

    # Keep V strictly inside the open interval (-pi/2, pi/2)
    margin = 2 * torch.finfo(V.dtype).eps
    V = V.clamp(min=margin - half_pi, max=half_pi - margin)

    ha = half_pi * alpha
    atan_b = (beta * ha.tan()).atan()
    cos_V = V.cos()

    # +/- `ha` term to keep the precision of alpha * (V + half_pi) when V ~ -half_pi
    v = atan_b - ha + alpha * (V + half_pi)

    # Log of W, accumulated term by term
    W_log = (
        atan_b.cos().log() * inv_alpha_m1
        + (Z * cos_V / v.sin()).log() * alpha * inv_alpha_m1
        + ((v - V).cos() / cos_V).log()
    )

    # Bound the exponent before leaving log space
    W = W_log.clamp(min=MIN_LOG, max=MAX_LOG).exp()

    log_prob = -W + (alpha * W / Z / (alpha - 1)).abs().log()

    # Infinite W means zero-probability
    log_prob = torch.where(W == torch.inf, -torch.inf, log_prob)

    return log_prob.clamp(min=MIN_LOG, max=MAX_LOG)


def _unsafe_alpha_stable_log_prob_at_zero(alpha, beta):
    """Log-density evaluated at a value of zero.

    This will fail if ``alpha`` is close to 1. The result is clamped to
    ``[MIN_LOG, MAX_LOG]``.
    """
    inv_alpha = alpha.reciprocal()
    skew_angle = (beta * (math.pi / 2 * alpha).tan()).atan()

    log_density = (
        (inv_alpha * skew_angle).cos().log()
        - skew_angle.cos().log() * inv_alpha
        + torch.lgamma(1 + inv_alpha)
        - math.log(math.pi)
    )

    return log_density.clamp(min=MIN_LOG, max=MAX_LOG)
6 changes: 5 additions & 1 deletion pyro/infer/reparam/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,11 @@ def apply(self, msg):
is_observed = msg["is_observed"]

fn, event_dim = self._unwrap(fn)
assert isinstance(fn, dist.Stable) and fn.coords == "S0"
assert (
isinstance(fn, dist.Stable)
and fn.coords == "S0"
and not isinstance(fn, dist.StableWithLogProb)
)

# Strategy: Let X ~ S0(a,b,s,m) be the stable variable of interest.
# 1. WLOG scale and shift so s=1 and m=0, additionally shifting to convert
Expand Down
2 changes: 1 addition & 1 deletion pyro/infer/reparam/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def _minimal_reparam(fn, is_observed):
return TransformReparam() # Then reparametrize new sites.
fn = fn.base_dist

if isinstance(fn, dist.Stable):
if isinstance(fn, dist.Stable) and not isinstance(fn, dist.StableWithLogProb):
if not is_observed:
return LatentStableReparam()
elif fn.skew.requires_grad or fn.skew.any():
Expand Down
Loading

0 comments on commit 0678b35

Please sign in to comment.