Skip to content

Commit

Permalink
Vectorize LowerCholeskyTransform (#2007)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored and neerajprad committed Aug 9, 2019
1 parent 2e98ca3 commit ac37ad1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
12 changes: 12 additions & 0 deletions pyro/distributions/torch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ def _Transform__getstate__(self):
return attrs


# This can be removed after release of https://github.com/pytorch/pytorch/pull/24131
@patch_dependency('torch.distributions.LowerCholeskyTransform._call')
def _LowerCholeskyTransform_call(self, x):
return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()


# This can be removed after release of https://github.com/pytorch/pytorch/pull/24131
@patch_dependency('torch.distributions.LowerCholeskyTransform._inverse')
def _LowerCholeskyTransform_inverse(self, y):
return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()


# Fixes a shape error in Multinomial.support with inhomogeneous .total_count
@patch_dependency('torch.distributions.Multinomial.support')
@torch.distributions.constraints.dependent_property
Expand Down
18 changes: 17 additions & 1 deletion tests/distributions/test_torch_patch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
import torch

import pyro.distributions as dist
from tests.common import requires_cuda
from tests.common import assert_close, requires_cuda


@requires_cuda
Expand All @@ -14,3 +15,18 @@ def test_dirichlet_grad_cuda():
def test_linspace():
x = torch.linspace(-1., 1., 100, device="cuda")
assert x.device.type == "cuda"


@pytest.mark.parametrize("batch_shape", [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize("dim", [1, 2, 3, 4])
def test_lower_cholesky_transform(batch_shape, dim):
t = torch.distributions.transform_to(torch.distributions.constraints.lower_cholesky)
x = torch.randn(batch_shape + (dim, dim))
y = t(x)
assert y.shape == x.shape
actual = y.matmul(y.transpose(-1, -2)).cholesky()
assert_close(actual, y)
x2 = t.inv(y)
assert x2.shape == x.shape
y2 = t(x2)
assert_close(y2, y)

0 comments on commit ac37ad1

Please sign in to comment.