Skip to content

Commit

Permalink
Add Zuko2Pyro test
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 18, 2023
1 parent c04675a commit 492af35
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tests/contrib/test_zuko.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


import pytest
import pyro
import torch

from pyro.contrib.zuko import Zuko2Pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO


@pytest.mark.parametrize("multivariate", [True, False])
def test_Zuko2Pyro(multivariate: bool):
# Distribution
if multivariate:
normal = torch.distributions.MultivariateNormal
mu = torch.zeros(3)
sigma = torch.eye(3)
else:
normal = torch.distributions.Normal
mu = torch.zeros(())
sigma = torch.ones(())

dist = normal(mu, sigma)

# Sample
x1 = pyro.sample("x1", Zuko2Pyro(dist))

assert x1.shape == dist.event_shape

# Sample within plate
with pyro.plate("data", 4):
x2 = pyro.sample("x2", Zuko2Pyro(dist))

assert x2.shape == (4, *dist.event_shape)

# SVI
def model():
pyro.sample("a", Zuko2Pyro(dist))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(dist))

def guide():
mu_ = pyro.param("mu", mu)
sigma_ = pyro.param("sigma", sigma)

pyro.sample("a", Zuko2Pyro(normal(mu_, sigma_)))

with pyro.plate("data", 4):
pyro.sample("b", Zuko2Pyro(normal(mu_, sigma_)))

svi = SVI(model, guide, optim=Adam({"lr": 1e-3}), loss=Trace_ELBO())
svi.step()

0 comments on commit 492af35

Please sign in to comment.