Skip to content

Commit

Permalink
Fix autoguide latent shape mismatch (#1961)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and fritzo committed Jul 16, 2019
1 parent 503e57f commit e03ec0d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pyro/contrib/autoguide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,12 @@ def _unpack_latent(self, latent):
batch_shape = latent.shape[:-1] # for plates outside of _setup_prototype, e.g. parallel particles
pos = 0
for name, site in self.prototype_trace.iter_stochastic_nodes():
constrained_shape = site["value"].shape
unconstrained_shape = self._unconstrained_shapes[name]
size = _product(unconstrained_shape)
event_dim = site["fn"].event_dim + len(unconstrained_shape) - len(constrained_shape)
unconstrained_shape = broadcast_shape(unconstrained_shape,
batch_shape + (1,) * site["fn"].event_dim)
batch_shape + (1,) * event_dim)
unconstrained_value = latent[..., pos:pos + size].view(unconstrained_shape)
yield site, unconstrained_value
pos += size
Expand Down
10 changes: 10 additions & 0 deletions tests/contrib/autoguide/test_advi.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,13 @@ def model():
guide = AutoDiagonalNormal(model)
with pytest.raises(RuntimeError):
guide()


def test_unpack_latent():
def model():
return pyro.sample('x', dist.LKJCorrCholesky(2, torch.tensor(1.)))

guide = AutoDiagonalNormal(model)
assert guide()['x'].shape == model().shape
latent = guide.sample_latent()
assert list(guide._unpack_latent(latent))[0][1].shape == (1,)

0 comments on commit e03ec0d

Please sign in to comment.