From 2bbbd587638964d9f2ebb2a6a8018e7321e74c2a Mon Sep 17 00:00:00 2001 From: Ben Zickel Date: Wed, 29 May 2024 18:40:18 +0300 Subject: [PATCH] Formatting and linting. --- pyro/distributions/stable.py | 2 +- pyro/infer/reparam/stable.py | 12 ++++++++++-- tests/distributions/conftest.py | 11 ++++++++--- tests/distributions/test_distributions.py | 2 +- tests/distributions/test_stable_log_prob.py | 3 +-- 5 files changed, 21 insertions(+), 9 deletions(-) diff --git a/pyro/distributions/stable.py b/pyro/distributions/stable.py index c9160a1f32..b988b6264c 100644 --- a/pyro/distributions/stable.py +++ b/pyro/distributions/stable.py @@ -107,7 +107,7 @@ class Stable(TorchDistribution): This implements a reparametrized sampler :meth:`rsample` , and a relatively expensive :meth:`log_prob` calculation by numerical integration which makes - inference slow (compared to other distributions) , but with better + inference slow (compared to other distributions) , but with better convergence properties especially for :math:`\alpha`-stable distributions that are skewed (see the ``skew`` parameter below). Faster inference can be performed using either likelihood-free algorithms such as diff --git a/pyro/infer/reparam/stable.py b/pyro/infer/reparam/stable.py index 89e5504c59..a33a4d8255 100644 --- a/pyro/infer/reparam/stable.py +++ b/pyro/infer/reparam/stable.py @@ -44,7 +44,11 @@ def apply(self, msg): is_observed = msg["is_observed"] fn, event_dim = self._unwrap(fn) - assert isinstance(fn, dist.Stable) and fn.coords == "S0" and not isinstance(fn, dist.StableWithLogProb) + assert ( + isinstance(fn, dist.Stable) + and fn.coords == "S0" + and not isinstance(fn, dist.StableWithLogProb) + ) if is_observed: raise NotImplementedError( f"At pyro.sample({repr(name)},...), " @@ -101,7 +105,11 @@ def apply(self, msg): is_observed = msg["is_observed"] fn, event_dim = self._unwrap(fn) - assert isinstance(fn, dist.Stable) and fn.coords == "S0" and not isinstance(fn, dist.StableWithLogProb) + assert ( + isinstance(fn, dist.Stable) + and fn.coords == "S0" + and not isinstance(fn, dist.StableWithLogProb) + ) if is_validation_enabled(): if not (fn.skew == 0).all(): raise ValueError("SymmetricStableReparam found nonzero skew") diff --git a/tests/distributions/conftest.py b/tests/distributions/conftest.py index f544fa54c7..58e308cd73 100644 --- a/tests/distributions/conftest.py +++ b/tests/distributions/conftest.py @@ -515,13 +515,18 @@ def __init__(self, von_loc, von_conc, skewness): "skew": 0.0, "scale": 2.0, "loc": -2.0, - "test_data": [10.0, -10.0] + "test_data": [10.0, -10.0], }, ], scipy_arg_fn=lambda stability, skew, scale, loc: ( (), - {"alpha": np.array(stability), "beta": np.array(skew), "scale": np.array(scale), "loc": np.array(loc)} - ) + { + "alpha": np.array(stability), + "beta": np.array(skew), + "scale": np.array(scale), + "loc": np.array(loc), + }, + ), ), Fixture( pyro_dist=dist.MultivariateStudentT, diff --git a/tests/distributions/test_distributions.py b/tests/distributions/test_distributions.py index fec7730d27..546803ebc7 100644 --- a/tests/distributions/test_distributions.py +++ b/tests/distributions/test_distributions.py @@ -171,7 +171,7 @@ def test_mean(continuous_dist): "SineBivariateVonMises", "VonMises", "ProjectedNormal", - "Stable" + "Stable", ]: pytest.xfail(reason="Euclidean mean is not defined") for i in range(continuous_dist.get_num_test_data()): diff --git a/tests/distributions/test_stable_log_prob.py b/tests/distributions/test_stable_log_prob.py index b2d5a92d25..2e35a6e59b 100644 --- a/tests/distributions/test_stable_log_prob.py +++ b/tests/distributions/test_stable_log_prob.py @@ -10,8 +10,7 @@ import pyro import pyro.distributions import pyro.distributions.stable_log_prob -from pyro.distributions import Stable -from pyro.distributions import constraints +from pyro.distributions import Stable, constraints from pyro.infer import SVI, Trace_ELBO from pyro.infer.autoguide import AutoNormal from tests.common import assert_close