From b34ba6c266cfabce56963bac6a559c3126dd14c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ola=20R=C3=B8nning?= Date: Mon, 2 Dec 2024 14:17:37 +0100 Subject: [PATCH] Fixed norm const for SBVM (#3411) * added lognorm terms for high conc sbvm * lint --- .../distributions/sine_bivariate_von_mises.py | 19 +++++++++++++------ .../test_sine_bivariate_von_mises.py | 13 +++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/pyro/distributions/sine_bivariate_von_mises.py b/pyro/distributions/sine_bivariate_von_mises.py index ea0a7e05e9..40be29ec9a 100644 --- a/pyro/distributions/sine_bivariate_von_mises.py +++ b/pyro/distributions/sine_bivariate_von_mises.py @@ -35,7 +35,6 @@ class SineBivariateVonMises(TorchDistribution): This distribution is a submodel of the Bivariate von Mises distribution, called the Sine Distribution [2] in directional statistics. - This distribution is helpful for modeling coupled angles such as torsion angles in peptide chains. To infer parameters, use :class:`~pyro.infer.NUTS` or :class:`~pyro.infer.HMC` with priors that avoid parameterizations where the distribution becomes bimodal; see note below. @@ -44,10 +43,12 @@ class SineBivariateVonMises(TorchDistribution): .. math:: - \frac{\rho}{\kappa_1\kappa_2} \rightarrow 1 + \frac{\rho^2}{\kappa_1\kappa_2} \rightarrow 1 - because the distribution becomes increasingly bimodal. To avoid bimodality use the `weighted_correlation` - parameter with a skew away from one (e.g., Beta(1,3)). The `weighted_correlation` should be in [0,1]. + because the distribution becomes increasingly bimodal. To avoid inefficient sampling use the + `weighted_correlation` parameter with a skew away from one (e.g., + `TransformedDistribution(Beta(5,5), AffineTransform(loc=-1, scale=2))`). The `weighted_correlation` + should be in [-1,1]. .. note:: The correlation and weighted_correlation params are mutually exclusive. @@ -65,7 +66,7 @@ class SineBivariateVonMises(TorchDistribution): :param torch.Tensor psi_concentration: concentration of second angle :param torch.Tensor correlation: correlation between the two angles :param torch.Tensor weighted_correlation: set correlation to weighted_corr * sqrt(phi_conc*psi_conc) - to avoid bimodality (see note). The `weighted_correlation` should be in [0,1]. + to avoid bimodality (see note). The `weighted_correlation` should be in [-1,1]. """ arg_constraints = { @@ -139,7 +140,13 @@ def norm_const(self): + m * torch.log((corr**2).clamp(min=tiny)) - m * torch.log(4 * torch.prod(conc, dim=-1)) ) - fs += log_I1(m.max(), conc, 51).sum(-1) + num_I1terms = torch.maximum( + torch.tensor(501), + torch.max(self.phi_concentration) + torch.max(self.psi_concentration), + ).int() + + fs += log_I1(m.max(), conc, num_I1terms).sum(-1) + mfs = fs.max() norm_const = 2 * torch.log(torch.tensor(2 * pi)) + mfs + (fs - mfs).logsumexp(0) return norm_const.reshape(self.phi_loc.shape) diff --git a/tests/distributions/test_sine_bivariate_von_mises.py b/tests/distributions/test_sine_bivariate_von_mises.py index bd676a6690..6212220dcb 100644 --- a/tests/distributions/test_sine_bivariate_von_mises.py +++ b/tests/distributions/test_sine_bivariate_von_mises.py @@ -130,3 +130,16 @@ def guide(data): ) # k == 'corr' assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2) + + +@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0]) +def test_sine_bivariate_von_mises_norm(conc): + dist = SineBivariateVonMises(0, 0, conc, conc, 0.0) + num_samples = 500 + x = torch.linspace(-torch.pi, torch.pi, num_samples) + y = torch.linspace(-torch.pi, torch.pi, num_samples) + mesh = torch.stack(torch.meshgrid(x, y, indexing="ij"), axis=-1) + integral_torus = ( + torch.exp(dist.log_prob(mesh)) * (2 * torch.pi) ** 2 / num_samples**2 + ).sum() + assert torch.allclose(integral_torus, torch.tensor(1.0), rtol=1e-2)