Skip to content

Commit

Permalink
Add documentation for the pyro.infer.predictive.MHResampler weighed s…
Browse files Browse the repository at this point in the history
…amples resampler.
  • Loading branch information
Ben Zickel committed Apr 5, 2024
1 parent 09e6aab commit eae323b
Showing 1 changed file with 25 additions and 2 deletions.
27 changes: 25 additions & 2 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,30 @@ def forward(self, *args, **kwargs):


class MHResampler(torch.nn.Module):
"""
Resampler for weighed samples that is based on the Metropolis-Hastings algorithm.
r"""
Resampler for weighed samples that generates equally weighed samples from the distribution
specified by the weighed samples ``sampler``.
The resampling is based on the Metropolis-Hastings algorithm.
Given an initial sample :math:`x` subsequent samples are generated by:
- Sampling from the ``guide`` a new sample candidate :math:`x'` with probability :math:`g(x')`.
- Calculate an acceptance probability
:math:`A(x', x) = \min\left(1, \frac{P(x')}{P(x)} \frac{g(x)}{g(x')}\right)`
with :math:`P` being the ``model``.
- With probability :math:`A(x', x)` accept the new sample candidate :math:`x'`
as the next sample, otherwise set the current sample :math:`x` as the next sample.
The above is the Metropolis-Hastings algorithm with the new sample candidate
proposal distribution being equal to the ``guide`` and independent of the
current sample such that :math:`g(x')=g(x' \mid x)`.
In case the ``guide`` perfectly tracks the ``model`` this sampler will do nothing
as the acceptance probability :math:`A(x', x)` will always be one.
:param callable sampler: When called returns :class:`WeighedPredictiveResults`.
:param slice source_samples_slice: Select source samples for storage (default is `slice(0)`, i.e. none).
:param slice stored_samples_slice: Select output samples for storage (default is `slice(0)`, i.e. none).
"""

def __init__(
Expand All @@ -475,6 +497,7 @@ def __init__(
def forward(self, *args, **kwargs):
"""
Perform single resampling step.
Returns :class:`WeighedPredictiveResults`
"""
with torch.no_grad():
new_samples = self.sampler(*args, **kwargs)
Expand Down

0 comments on commit eae323b

Please sign in to comment.