Skip to content

Commit

Permalink
Fixed norm const for SBVM (#3411)
Browse files Browse the repository at this point in the history
* added lognorm terms for high conc sbvm

* lint
  • Loading branch information
OlaRonning authored Dec 2, 2024
1 parent 455f7b3 commit b34ba6c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
19 changes: 13 additions & 6 deletions pyro/distributions/sine_bivariate_von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class SineBivariateVonMises(TorchDistribution):
This distribution is a submodel of the Bivariate von Mises distribution, called the Sine Distribution [2] in
directional statistics.
This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains.
To infer parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that
avoid parameterizations where the distribution becomes bimodal; see note below.
Expand All @@ -44,10 +43,12 @@ class SineBivariateVonMises(TorchDistribution):
.. math::
\frac{\rho}{\kappa_1\kappa_2} \rightarrow 1
\frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1
because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation`
parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1].
because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the
`weighted_correlation` parameter with a skew away from one (e.g.,
`TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation`
should be in [-1,1].
.. note:: The correlation and weighted_correlation params are mutually exclusive.
Expand All @@ -65,7 +66,7 @@ class SineBivariateVonMises(TorchDistribution):
:param torch.Tensor psi_concentration: concentration of second angle
:param torch.Tensor correlation: correlation between the two angles
:param torch.Tensor weighted_correlation: set correlation to weighted_corr * sqrt(phi_conc*psi_conc)
to avoid bimodality (see note). The `weighted_correlation` should be in [0,1].
to avoid bimodality (see note). The `weighted_correlation` should be in [-1,1].
"""

arg_constraints = {
Expand Down Expand Up @@ -139,7 +140,13 @@ def norm_const(self):
+ m * torch.log((corr**2).clamp(min=tiny))
- m * torch.log(4 * torch.prod(conc, dim=-1))
)
fs += log_I1(m.max(), conc, 51).sum(-1)
num_I1terms = torch.maximum(
torch.tensor(501),
torch.max(self.phi_concentration) + torch.max(self.psi_concentration),
).int()

fs += log_I1(m.max(), conc, num_I1terms).sum(-1)

mfs = fs.max()
norm_const = 2 * torch.log(torch.tensor(2 * pi)) + mfs + (fs - mfs).logsumexp(0)
return norm_const.reshape(self.phi_loc.shape)
Expand Down
13 changes: 13 additions & 0 deletions tests/distributions/test_sine_bivariate_von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,16 @@ def guide(data):
) # k == 'corr'

assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2)


@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0])
def test_sine_bivariate_von_mises_norm(conc):
dist = SineBivariateVonMises(0, 0, conc, conc, 0.0)
num_samples = 500
x = torch.linspace(-torch.pi, torch.pi, num_samples)
y = torch.linspace(-torch.pi, torch.pi, num_samples)
mesh = torch.stack(torch.meshgrid(x, y, indexing="ij"), axis=-1)
integral_torus = (
torch.exp(dist.log_prob(mesh)) * (2 * torch.pi) ** 2 / num_samples**2
).sum()
assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)

0 comments on commit b34ba6c

Please sign in to comment.