Skip to content

Commit

Permalink
Resampler for weighed samples (#3352)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenZickel authored Apr 16, 2024
1 parent ca4903e commit 91bc2b3
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 15 deletions.
3 changes: 2 additions & 1 deletion pyro/infer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pyro.infer.mcmc.hmc import HMC
from pyro.infer.mcmc.nuts import NUTS
from pyro.infer.mcmc.rwkernel import RandomWalkKernel
from pyro.infer.predictive import Predictive, WeighedPredictive
from pyro.infer.predictive import MHResampler, Predictive, WeighedPredictive
from pyro.infer.renyi_elbo import RenyiELBO
from pyro.infer.rws import ReweightedWakeSleep
from pyro.infer.smcfilter import SMCFilter
Expand Down Expand Up @@ -44,6 +44,7 @@
"JitTraceMeanField_ELBO",
"JitTrace_ELBO",
"MCMC",
"MHResampler",
"NUTS",
"Predictive",
"RandomWalkKernel",
Expand Down
193 changes: 189 additions & 4 deletions pyro/infer/predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from dataclasses import dataclass
from dataclasses import dataclass, fields
from functools import reduce
from typing import List, Union
from typing import Callable, List, Union

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer.importance import LogWeightsMixin
from pyro.infer.util import plate_log_prob_sum
from pyro.infer.util import CloneMixin, plate_log_prob_sum
from pyro.poutine.trace_struct import Trace
from pyro.poutine.util import prune_subsample_sites

Expand Down Expand Up @@ -320,7 +320,7 @@ def get_vectorized_trace(self, *args, **kwargs):


@dataclass(frozen=True, eq=False)
class WeighedPredictiveResults(LogWeightsMixin):
class WeighedPredictiveResults(LogWeightsMixin, CloneMixin):
"""
Return value of call to instance of :class:`WeighedPredictive`.
"""
Expand Down Expand Up @@ -450,3 +450,188 @@ def forward(self, *args, **kwargs):
guide_log_prob=guide_log_prob,
model_log_prob=model_log_prob,
)


class MHResampler(torch.nn.Module):
    r"""
    Resampler for weighed samples that generates equally weighed samples from the distribution
    specified by the weighed samples ``sampler``.

    The resampling is based on the Metropolis-Hastings algorithm.
    Given an initial sample :math:`x` subsequent samples are generated by:

    - Sampling from the ``guide`` a new sample candidate :math:`x'` with probability :math:`g(x')`.
    - Calculate an acceptance probability
      :math:`A(x', x) = \min\left(1, \frac{P(x')}{P(x)} \frac{g(x)}{g(x')}\right)`
      with :math:`P` being the ``model``.
    - With probability :math:`A(x', x)` accept the new sample candidate :math:`x'`
      as the next sample, otherwise set the current sample :math:`x` as the next sample.

    The above is the Metropolis-Hastings algorithm with the new sample candidate
    proposal distribution being equal to the ``guide`` and independent of the
    current sample such that :math:`g(x')=g(x' \mid x)`.

    :param callable sampler: When called returns :class:`WeighedPredictiveResults`.
    :param slice source_samples_slice: Select source samples for storage (default is `slice(0)`, i.e. none).
    :param slice stored_samples_slice: Select output samples for storage (default is `slice(0)`, i.e. none).

    The typical use case of :class:`MHResampler` would be to convert weighed samples
    generated by :class:`WeighedPredictive` into equally weighed samples from the target distribution.
    Each time an instance of :class:`MHResampler` is called it returns a new set of samples, with the
    samples generated by the first call being distributed according to the ``guide``, and with each
    subsequent call the distribution of the samples becomes closer to that of the posterior predictive
    distribution. It might take some experimentation in order to find out in each case how many times one would
    need to call an instance of :class:`MHResampler` in order to be close enough to the posterior
    predictive distribution.

    Example::

        def model():
            ...

        def guide():
            ...

        def conditioned_model():
            ...

        # Fit guide
        elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
        svi = SVI(conditioned_model, guide, optim.Adam(dict(lr=3.0)), elbo)
        for i in range(num_svi_steps):
            svi.step()

        # Create callable that returns weighed samples
        posterior_predictive = WeighedPredictive(model,
                                                 guide=guide,
                                                 num_samples=num_samples,
                                                 parallel=parallel,
                                                 return_sites=["_RETURN"])

        prob = 0.95

        weighed_samples = posterior_predictive(model_guide=conditioned_model)
        # Calculate quantile directly from weighed samples
        weighed_samples_quantile = weighed_quantile(weighed_samples.samples["_RETURN"],
                                                    [prob],
                                                    weighed_samples.log_weights)[0]

        resampler = MHResampler(posterior_predictive)
        num_mh_steps = 10
        for mh_step_count in range(num_mh_steps):
            resampled_weighed_samples = resampler(model_guide=conditioned_model)
        # Calculate quantile from resampled weighed samples (samples are equally weighed)
        resampled_weighed_samples_quantile = quantile(resampled_weighed_samples.samples["_RETURN"],
                                                      [prob])[0]

        # Quantiles calculated using both methods should be identical
        assert_close(weighed_samples_quantile, resampled_weighed_samples_quantile, rtol=0.01)

    .. _mhsampler-behavior:

    **Notes on Sampler Behavior:**

    - In case the ``guide`` perfectly tracks the ``model`` this sampler will do nothing
      as the acceptance probability :math:`A(x', x)` will always be one.
    - Furthermore, if the guide is approximately separable, i.e. :math:`g(z_A, z_B) \approx g_A(z_A) g_B(z_B)`,
      with :math:`g_A(z_A)` perfectly tracking the ``model`` and :math:`g_B(z_B)` poorly tracking the ``model``,
      quantiles of :math:`z_A` calculated from samples taken from :class:`MHResampler`, will have much lower
      variance than quantiles of :math:`z_A` calculated by using :any:`weighed_quantile`, as the effective sample size
      of the calculation using :any:`weighed_quantile` will be low due to :math:`g_B(z_B)` poorly tracking
      the ``model``, whereas when using :class:`MHResampler` the poor ``model`` tracking of :math:`g_B(z_B)` has
      negligible effect on the effective sample size of :math:`z_A` samples.
    """

    def __init__(
        self,
        sampler: Callable,
        source_samples_slice: slice = slice(0),
        stored_samples_slice: slice = slice(0),
    ):
        super().__init__()
        self.sampler = sampler
        # Current state of the chains; populated by the first call to ``forward``.
        self.samples = None
        # Per-sample count of accepted candidates; re-shaped on the first call.
        self.transition_count = torch.tensor(0, dtype=torch.long)
        # History buffers. After every step each list is re-sliced with the
        # corresponding slice, so ``slice(0)`` keeps nothing.
        self.source_samples = []
        self.source_samples_slice = source_samples_slice
        self.stored_samples = []
        self.stored_samples_slice = stored_samples_slice

    def forward(self, *args, **kwargs):
        """
        Perform single resampling step.

        All arguments are forwarded to ``sampler``.

        Returns :class:`WeighedPredictiveResults`. The returned object is the
        resampler's internal state and is updated in place by subsequent calls;
        call ``.clone()`` on it to keep a snapshot.
        """
        with torch.no_grad():
            new_samples = self.sampler(*args, **kwargs)
            # Store samples
            self.source_samples.append(new_samples)
            self.source_samples = self.source_samples[self.source_samples_slice]
            if self.samples is None:
                # First set of samples
                self.samples = new_samples.clone()
                self.transition_count = torch.zeros_like(
                    new_samples.log_weights, dtype=torch.long
                )
            else:
                # Apply Metropolis-Hastings algorithm: the acceptance probability
                # is min(1, w'/w), evaluated in log space.
                prob = torch.clamp(
                    new_samples.log_weights - self.samples.log_weights, max=0.0
                ).exp()
                # ``rand_like`` draws on the same device as ``prob`` so the
                # comparison also works when the log weights are not on the CPU.
                idx = torch.rand_like(prob) <= prob
                self.transition_count[idx] += 1
                # Overwrite accepted entries of every field in place.
                for field_desc in fields(self.samples):
                    field = getattr(self.samples, field_desc.name)
                    new_field = getattr(new_samples, field_desc.name)
                    if isinstance(field, dict):
                        for key in field:
                            field[key][idx] = new_field[key][idx]
                    else:
                        field[idx] = new_field[idx]
            self.stored_samples.append(self.samples.clone())
            self.stored_samples = self.stored_samples[self.stored_samples_slice]
            return self.samples

    def get_min_sample_transition_count(self):
        """
        Return transition count of sample with minimal amount of transitions.
        """
        return self.transition_count.min()

    def get_total_transition_count(self):
        """
        Return total number of transitions.
        """
        return self.transition_count.sum()

    def get_source_samples(self):
        """
        Return source samples that were the input to the Metropolis-Hastings algorithm.
        """
        return self.get_samples(self.source_samples)

    def get_stored_samples(self):
        """
        Return stored samples that were the output of the Metropolis-Hastings algorithm.
        """
        return self.get_samples(self.stored_samples)

    def get_samples(self, samples):
        """
        Return samples that were sampled during execution of the Metropolis-Hastings algorithm.

        :param list samples: Per-step results to concatenate along the first dimension.
        :raises RuntimeError: If ``samples`` is empty, e.g. when the relevant storage
            slice is left at its default of ``slice(0)`` or the resampler was never called.
        """
        if not samples:
            raise RuntimeError(
                "No samples available to concatenate; set `source_samples_slice` / "
                "`stored_samples_slice` to keep samples and call the resampler first."
            )
        # Use the first stored result as the template for field names and dict keys.
        template = samples[0]
        retval = dict()
        for field_desc in fields(template):
            field_name = field_desc.name
            value = getattr(template, field_name)
            if isinstance(value, dict):
                retval[field_name] = {
                    key: torch.cat([getattr(sample, field_name)[key] for sample in samples])
                    for key in value
                }
            else:
                retval[field_name] = torch.cat(
                    [getattr(sample, field_name) for sample in samples]
                )
        return template.__class__(**retval)
20 changes: 20 additions & 0 deletions pyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numbers
from collections import Counter, defaultdict
from contextlib import contextmanager
from dataclasses import fields

import torch
from opt_einsum import shared_intermediates
Expand Down Expand Up @@ -358,3 +359,22 @@ def plate_log_prob_sum(trace: Trace, plate_symbol: str) -> torch.Tensor:
[site["packed"]["log_prob"]],
)
return log_prob_sum


class CloneMixin:
    """
    Mixin that gives a ``@dataclasses.dataclass`` decorated class a ``.clone`` method.

    Every field is expected to be either a ``torch.Tensor`` or a ``dict`` whose
    values are ``torch.Tensor`` objects.
    """

    def clone(self):
        """
        Return a new instance of the same class whose tensors are copies of the
        tensors held by this instance.
        """

        def _copy(value):
            # Dict fields are copied one level deep, tensor by tensor.
            if isinstance(value, dict):
                return {key: value[key].clone() for key in value}
            return value.clone()

        cloned = {
            field_desc.name: _copy(getattr(self, field_desc.name))
            for field_desc in fields(self)
        }
        return self.__class__(**cloned)
73 changes: 63 additions & 10 deletions tests/infer/test_predictive.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import logging

import pytest
import torch

import pyro
import pyro.distributions as dist
import pyro.optim as optim
import pyro.poutine as poutine
from pyro.infer import SVI, Predictive, Trace_ELBO, WeighedPredictive
from pyro.infer import SVI, MHResampler, Predictive, Trace_ELBO, WeighedPredictive
from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal
from pyro.ops.stats import quantile, weighed_quantile
from tests.common import assert_close


Expand Down Expand Up @@ -39,9 +42,18 @@ def beta_guide(num_trials):
pyro.sample("phi", phi_posterior)


@pytest.mark.parametrize("predictive", [Predictive, WeighedPredictive])
@pytest.mark.parametrize(
"predictive, num_svi_steps, test_unweighed_convergence",
[
(Predictive, 5000, None),
(WeighedPredictive, 5000, True),
(WeighedPredictive, 1000, False),
],
)
@pytest.mark.parametrize("parallel", [False, True])
def test_posterior_predictive_svi_manual_guide(parallel, predictive):
def test_posterior_predictive_svi_manual_guide(
parallel, predictive, num_svi_steps, test_unweighed_convergence
):
true_probs = torch.ones(5) * 0.7
num_trials = (
torch.ones(5) * 400
Expand All @@ -51,9 +63,7 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive):
conditioned_model = poutine.condition(model, data={"obs": num_success})
elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)
svi = SVI(conditioned_model, beta_guide, optim.Adam(dict(lr=3.0)), elbo)
for i in range(
5000
): # Increased to 5000 from 1000 in order for guide optimization to converge
for i in range(num_svi_steps):
svi.step(num_trials)
posterior_predictive = predictive(
model,
Expand All @@ -70,10 +80,53 @@ def test_posterior_predictive_svi_manual_guide(parallel, predictive):
)
marginal_return_vals = weighed_samples.samples["_RETURN"]
assert marginal_return_vals.shape[:1] == weighed_samples.log_weights.shape
# Weights should be uniform as the guide has the same distribution as the model
assert weighed_samples.log_weights.std() < 0.6
# Effective sample size should be close to actual number of samples taken from the guide
assert weighed_samples.get_ESS() > 0.8 * num_samples
# Resample weighed samples
resampler = MHResampler(posterior_predictive)
num_mh_steps = 10
for mh_step_count in range(num_mh_steps):
resampled_weighed_samples = resampler(
num_trials, model_guide=conditioned_model
)
resampled_marginal_return_vals = resampled_weighed_samples.samples["_RETURN"]
# Calculate CDF quantiles
quantile_test_point = 0.95
quantile_test_point_value = quantile(
marginal_return_vals, [quantile_test_point]
)[0]
weighed_quantile_test_point_value = weighed_quantile(
marginal_return_vals, [quantile_test_point], weighed_samples.log_weights
)[0]
resampled_quantile_test_point_value = quantile(
resampled_marginal_return_vals, [quantile_test_point]
)[0]
logging.info(
"Unweighed quantile at test point is: " + str(quantile_test_point_value)
)
logging.info(
"Weighed quantile at test point is: "
+ str(weighed_quantile_test_point_value)
)
logging.info(
"Resampled quantile at test point is: "
+ str(resampled_quantile_test_point_value)
)
# Weighed and resampled quantiles should match
assert_close(
weighed_quantile_test_point_value,
resampled_quantile_test_point_value,
rtol=0.01,
)
if test_unweighed_convergence:
# Weights should be uniform as the guide has the same distribution as the model
assert weighed_samples.log_weights.std() < 0.6
# Effective sample size should be close to actual number of samples taken from the guide
assert weighed_samples.get_ESS() > 0.8 * num_samples
# Weighed and unweighed quantiles should match if guide converged to true model
assert_close(
quantile_test_point_value,
resampled_quantile_test_point_value,
rtol=0.01,
)
assert_close(marginal_return_vals.mean(dim=0), torch.ones(5) * 280, rtol=0.1)


Expand Down

0 comments on commit 91bc2b3

Please sign in to comment.