Skip to content

Commit

Permalink
Formatting and linting.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed May 29, 2024
1 parent 6fe22b2 commit 2bbbd58
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyro/distributions/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class Stable(TorchDistribution):
This implements a reparametrized sampler :meth:`rsample` , and a relatively
expensive :meth:`log_prob` calculation by numerical integration which makes
inference slow (compared to other distributions) , but with better
inference slow (compared to other distributions) , but with better
convergence properties especially for :math:`\alpha`-stable distributions
that are skewed (see the ``skew`` parameter below). Faster
inference can be performed using either likelihood-free algorithms such as
Expand Down
12 changes: 10 additions & 2 deletions pyro/infer/reparam/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def apply(self, msg):
is_observed = msg["is_observed"]

fn, event_dim = self._unwrap(fn)
assert isinstance(fn, dist.Stable) and fn.coords == "S0" and not isinstance(fn, dist.StableWithLogProb)
assert (
isinstance(fn, dist.Stable)
and fn.coords == "S0"
and not isinstance(fn, dist.StableWithLogProb)
)
if is_observed:
raise NotImplementedError(
f"At pyro.sample({repr(name)},...), "
Expand Down Expand Up @@ -101,7 +105,11 @@ def apply(self, msg):
is_observed = msg["is_observed"]

fn, event_dim = self._unwrap(fn)
assert isinstance(fn, dist.Stable) and fn.coords == "S0" and not isinstance(fn, dist.StableWithLogProb)
assert (
isinstance(fn, dist.Stable)
and fn.coords == "S0"
and not isinstance(fn, dist.StableWithLogProb)
)
if is_validation_enabled():
if not (fn.skew == 0).all():
raise ValueError("SymmetricStableReparam found nonzero skew")
Expand Down
11 changes: 8 additions & 3 deletions tests/distributions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,13 +515,18 @@ def __init__(self, von_loc, von_conc, skewness):
"skew": 0.0,
"scale": 2.0,
"loc": -2.0,
"test_data": [10.0, -10.0]
"test_data": [10.0, -10.0],
},
],
scipy_arg_fn=lambda stability, skew, scale, loc: (
(),
{"alpha": np.array(stability), "beta": np.array(skew), "scale": np.array(scale), "loc": np.array(loc)}
)
{
"alpha": np.array(stability),
"beta": np.array(skew),
"scale": np.array(scale),
"loc": np.array(loc),
},
),
),
Fixture(
pyro_dist=dist.MultivariateStudentT,
Expand Down
2 changes: 1 addition & 1 deletion tests/distributions/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_mean(continuous_dist):
"SineBivariateVonMises",
"VonMises",
"ProjectedNormal",
"Stable"
"Stable",
]:
pytest.xfail(reason="Euclidean mean is not defined")
for i in range(continuous_dist.get_num_test_data()):
Expand Down
3 changes: 1 addition & 2 deletions tests/distributions/test_stable_log_prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import pyro
import pyro.distributions
import pyro.distributions.stable_log_prob
from pyro.distributions import Stable
from pyro.distributions import constraints
from pyro.distributions import Stable, constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from tests.common import assert_close
Expand Down

0 comments on commit 2bbbd58

Please sign in to comment.