diff --git a/pyro/contrib/zuko.py b/pyro/contrib/zuko.py index 775c2e1b6b..bca120b2c2 100644 --- a/pyro/contrib/zuko.py +++ b/pyro/contrib/zuko.py @@ -21,12 +21,24 @@ class Zuko2Pyro(pyro.distributions.TorchDistribution): :param dist: A distribution instance. :type dist: torch.distributions.Distribution - Example: - >>> flow = zuko.flows.MAF(features=5) - >>> dist = Zuko2Pyro(flow()) - >>> dist((2, 3)).shape - torch.Size([2, 3, 5]) - >>> x = pyro.sample("x", dist) + .. code-block:: python + + flow = zuko.flows.MAF(features=5) + + # flow() is a torch.distributions.Distribution + + dist = flow() + x = dist.sample((2, 3)) + log_p = dist.log_prob(x) + + # Zuko2Pyro(flow()) is a pyro.distributions.Distribution + + dist = Zuko2Pyro(flow()) + x = dist((2, 3)) + log_p = dist.log_prob(x) + + with pyro.plate("data", 42): + z = pyro.sample("z", dist) """ def __init__(self, dist: torch.distributions.Distribution):