Skip to content

Commit

Permalink
Linting and formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed May 20, 2024
1 parent 037f094 commit 1f0a696
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 40 deletions.
2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,11 @@

from . import constraints, kl, transforms


class StableWithLogProb(StableLogProb, Stable):
pass


__all__ = [
"AVFMultivariateNormal",
"AffineBeta",
Expand Down
75 changes: 53 additions & 22 deletions pyro/distributions/stable_log_prob.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import torch
import math
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from functools import partial

import torch
from scipy.special import roots_legendre


value_near_zero_tolerance = 0.01
alpha_near_one_tolerance = 0.05

Expand All @@ -21,11 +22,16 @@ def create_integrator(num_points):
weights = torch.Tensor(weights).double()
log_weights = weights.log()
half_roots = roots * 0.5

def integrate(fn, domain):
sl = [slice(None)] + (len(domain.shape) - 1) * [None]
half_roots_sl = half_roots[sl]
value = domain[0] * (0.5 - half_roots_sl) + domain[1] * (0.5 + half_roots_sl)
return torch.logsumexp(fn(value) + log_weights[sl], dim=0) + ((domain[1] - domain[0]) / 2).log()
return (
torch.logsumexp(fn(value) + log_weights[sl], dim=0)
+ ((domain[1] - domain[0]) / 2).log()
)

return integrate


Expand All @@ -46,7 +52,7 @@ def log_prob(self, value):
alpha = self.stability.double()
beta = self.skew.double()
value = value.double()

return _stable_log_prob(alpha, beta, value, self.coords) - self.scale.log()


Expand All @@ -55,23 +61,35 @@ def _stable_log_prob(alpha, beta, value, coords):
# continuously on (alpha,beta), allowing interpolation around the hole at
# alpha=1.
if coords == "S":
value = torch.where(alpha == 1, value, value - beta * (math.pi / 2 * alpha)).tan()
value = torch.where(
alpha == 1, value, value - beta * (math.pi / 2 * alpha)
).tan()
elif coords != "S0":
raise ValueError("Unknown coords: {}".format(coords))

# Find near one alpha
idx = (alpha - 1).abs() < alpha_near_one_tolerance

log_prob = _unsafe_alpha_stable_log_prob_S0(torch.where(idx, 1 + alpha_near_one_tolerance, alpha), beta, value)
log_prob = _unsafe_alpha_stable_log_prob_S0(
torch.where(idx, 1 + alpha_near_one_tolerance, alpha), beta, value
)

# Handle alpha near one by interpolation
if idx.any():
log_prob_pos = log_prob[idx]
log_prob_neg = _unsafe_alpha_stable_log_prob_S0(
(1 - alpha_near_one_tolerance) * log_prob_pos.new_ones(log_prob_pos.shape), beta[idx], value[idx])
(1 - alpha_near_one_tolerance) * log_prob_pos.new_ones(log_prob_pos.shape),
beta[idx],
value[idx],
)
weights = (alpha[idx] - 1) / (2 * alpha_near_one_tolerance) + 0.5
log_prob[idx] = torch.logsumexp(torch.stack((log_prob_pos + weights.log(),
log_prob_neg + (1 - weights).log()), dim=0), dim=0)
log_prob[idx] = torch.logsumexp(
torch.stack(
(log_prob_pos + weights.log(), log_prob_neg + (1 - weights).log()),
dim=0,
),
dim=0,
)

return log_prob

Expand All @@ -83,22 +101,33 @@ def _unsafe_alpha_stable_log_prob_S0(alpha, beta, Z):
# continuously on (alpha,beta), allowing interpolation around the hole at
# alpha=1.
Z = Z + beta * (math.pi / 2 * alpha).tan()

# Find near zero values
per_alpha_value_near_zero_tolerance = value_near_zero_tolerance * alpha / (1 - alpha).abs()
per_alpha_value_near_zero_tolerance = (
value_near_zero_tolerance * alpha / (1 - alpha).abs()
)
idx = Z.abs() < per_alpha_value_near_zero_tolerance

# Calculate log-prob at safe values
log_prob = _unsafe_stable_log_prob(alpha, beta, torch.where(idx, per_alpha_value_near_zero_tolerance, Z))
log_prob = _unsafe_stable_log_prob(
alpha, beta, torch.where(idx, per_alpha_value_near_zero_tolerance, Z)
)

# Handle near zero values by interpolation
if idx.any():
log_prob_pos = log_prob[idx]
log_prob_neg = _unsafe_stable_log_prob(alpha[idx], beta[idx], -per_alpha_value_near_zero_tolerance[idx])
log_prob_neg = _unsafe_stable_log_prob(
alpha[idx], beta[idx], -per_alpha_value_near_zero_tolerance[idx]
)
weights = Z[idx] / (2 * per_alpha_value_near_zero_tolerance[idx]) + 0.5
log_prob[idx] = torch.logsumexp(torch.stack((log_prob_pos + weights.log(),
log_prob_neg + (1 - weights).log()), dim=0), dim=0)

log_prob[idx] = torch.logsumexp(
torch.stack(
(log_prob_pos + weights.log(), log_prob_neg + (1 - weights).log()),
dim=0,
),
dim=0,
)

return log_prob


Expand All @@ -109,7 +138,7 @@ def _unsafe_stable_log_prob(alpha, beta, Z):
b = beta * ha.tan()
atan_b = b.atan()
u_zero = -alpha.reciprocal() * atan_b

# If sample should be negative calculate with flipped beta and flipped value
flip_beta_x = Z < 0
beta = torch.where(flip_beta_x, -beta, beta)
Expand All @@ -119,8 +148,10 @@ def _unsafe_stable_log_prob(alpha, beta, Z):
# Set integration domwin
domain = torch.stack((u_zero, 0.5 * math.pi * u_zero.new_ones(u_zero.shape)), dim=0)

integrand = partial(_unsafe_stable_given_uniform_log_prob, alpha=alpha, beta=beta, Z=Z)

integrand = partial(
_unsafe_stable_given_uniform_log_prob, alpha=alpha, beta=beta, Z=Z
)

return integrate(integrand, domain) - math.log(math.pi)


Expand Down Expand Up @@ -151,8 +182,8 @@ def _unsafe_stable_given_uniform_log_prob(V, alpha, beta, Z):
log_prob = -W + (alpha * W / Z / (alpha - 1)).abs().log()

# Infinite W means zero-probability
log_prob = torch.where(W==torch.inf, -torch.inf, log_prob)
log_prob = torch.where(W == torch.inf, -torch.inf, log_prob)

log_prob = log_prob.clamp(min=MIN_LOG, max=MAX_LOG)

return log_prob
return log_prob
41 changes: 23 additions & 18 deletions tests/distributions/test_stable_with_log_prob.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import logging

import pytest
import pyro
import torch

import pyro
from pyro.distributions import StableWithLogProb as Stable
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal

from tests.common import assert_close


torch.set_default_dtype(torch.float64)


@pytest.mark.parametrize(
"alpha, beta, c, mu",
[
(1.00, 0.8, 2.0, 3.0),
(1.02, -0.8, 2.0, -3.0),
(0.98, 0.5, 1.0, -3.0),
(0.95, -0.5, 1.0, 3.0),
(1.10, 0.0, 1.0, 0.0),
(1.80, -0.5, 1.0, -2.0),
(0.50, 0.0, 1.0, 2.0),
(1.00, 0.8, 2.0, 3.0),
(1.02, -0.8, 2.0, -3.0),
(0.98, 0.5, 1.0, -3.0),
(0.95, -0.5, 1.0, 3.0),
(1.10, 0.0, 1.0, 0.0),
(1.80, -0.5, 1.0, -2.0),
(0.50, 0.0, 1.0, 2.0),
],
)
@pytest.mark.parametrize(
Expand All @@ -39,9 +40,13 @@ def test_stable_with_log_prob_param_fit(alpha, beta, c, mu, alpha_0, beta_0, c_0
pyro.set_rng_seed(20240520)
data = Stable(alpha, beta, c, mu).sample((n,))

def model(data):
alpha = pyro.param("alpha", torch.tensor(alpha_0), constraint=constraints.interval(0, 2))
beta = pyro.param("beta", torch.tensor(beta_0), constraint=constraints.interval(-1, 1))
def model(data):
alpha = pyro.param(
"alpha", torch.tensor(alpha_0), constraint=constraints.interval(0, 2)
)
beta = pyro.param(
"beta", torch.tensor(beta_0), constraint=constraints.interval(-1, 1)
)
c = pyro.param("c", torch.tensor(c_0), constraint=constraints.positive)
mu = pyro.param("mu", torch.tensor(mu_0), constraint=constraints.real)
with pyro.plate("data", data.shape[0]):
Expand Down Expand Up @@ -75,13 +80,13 @@ def log_progress():

# Fit model to data
guide = AutoNormal(model)
train(model, guide);
train(model, guide)

# Verify fit accuracy
assert_close(alpha, pyro.param('alpha').item(), atol=0.03)
assert_close(beta, pyro.param('beta').item(), atol=0.04)
assert_close(c, pyro.param('c').item(), atol=0.2)
assert_close(mu, pyro.param('mu').item(), atol=0.2)
assert_close(alpha, pyro.param("alpha").item(), atol=0.03)
assert_close(beta, pyro.param("beta").item(), atol=0.04)
assert_close(c, pyro.param("c").item(), atol=0.2)
assert_close(mu, pyro.param("mu").item(), atol=0.2)


# # The below tests will be executed:
Expand Down

0 comments on commit 1f0a696

Please sign in to comment.