From 910243c6a0753eeac28e9d74ea23db676678618c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Thu, 4 Jan 2024 17:31:21 +0100 Subject: [PATCH] Fix doctests --- pyro/contrib/zuko.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/pyro/contrib/zuko.py b/pyro/contrib/zuko.py index 775c2e1b6b..bca120b2c2 100644 --- a/pyro/contrib/zuko.py +++ b/pyro/contrib/zuko.py @@ -21,12 +21,24 @@ class Zuko2Pyro(pyro.distributions.TorchDistribution): :param dist: A distribution instance. :type dist: torch.distributions.Distribution - Example: - >>> flow = zuko.flows.MAF(features=5) - >>> dist = Zuko2Pyro(flow()) - >>> dist((2, 3)).shape - torch.Size([2, 3, 5]) - >>> x = pyro.sample("x", dist) + .. code-block:: python + + flow = zuko.flows.MAF(features=5) + + # flow() is a torch.distributions.Distribution + + dist = flow() + x = dist.sample((2, 3)) + log_p = dist.log_prob(x) + + # Zuko2Pyro(flow()) is a pyro.distributions.Distribution + + dist = Zuko2Pyro(flow()) + x = dist((2, 3)) + log_p = dist.log_prob(x) + + with pyro.plate("data", 42): + z = pyro.sample("z", dist) """ def __init__(self, dist: torch.distributions.Distribution):