Skip to content

Commit

Permalink
Fix doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Jan 4, 2024
1 parent 10e0679 commit 910243c
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions pyro/contrib/zuko.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,24 @@ class Zuko2Pyro(pyro.distributions.TorchDistribution):
:param dist: A distribution instance.
:type dist: torch.distributions.Distribution
Example:
>>> flow = zuko.flows.MAF(features=5)
>>> dist = Zuko2Pyro(flow())
>>> dist((2, 3)).shape
torch.Size([2, 3, 5])
>>> x = pyro.sample("x", dist)
.. code-block:: python
flow = zuko.flows.MAF(features=5)
# flow() is a torch.distributions.Distribution
dist = flow()
x = dist.sample((2, 3))
log_p = dist.log_prob(x)
# Zuko2Pyro(flow()) is a pyro.distributions.Distribution
dist = Zuko2Pyro(flow())
x = dist((2, 3))
log_p = dist.log_prob(x)
with pyro.plate("data", 42):
z = pyro.sample("z", dist)
"""

def __init__(self, dist: torch.distributions.Distribution):
Expand Down

0 comments on commit 910243c

Please sign in to comment.