Skip to content

Commit

Permalink
Equalize effect handler (#3375)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenZickel authored Jul 10, 2024
1 parent b55aa9d commit daea9a6
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/pyro.poutine.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ ________________
:undoc-members:
:show-inheritance:

EqualizeMessenger
____________________

.. automodule:: pyro.poutine.equalize_messenger
:members:
:undoc-members:
:show-inheritance:

EscapeMessenger
________________

Expand Down
2 changes: 2 additions & 0 deletions pyro/poutine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
condition,
do,
enum,
equalize,
escape,
infer_config,
lift,
Expand Down Expand Up @@ -36,6 +37,7 @@
"enable_validation",
"enum",
"escape",
"equalize",
"get_mask",
"infer_config",
"is_validation_enabled",
Expand Down
77 changes: 77 additions & 0 deletions pyro/poutine/equalize_messenger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import re
from typing import List, Optional, Union

from typing_extensions import Self

from pyro.distributions import Delta
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message


class EqualizeMessenger(Messenger):
"""
Given a stochastic function with some primitive statements and a list of names,
force the primitive statements at those names to have the same value,
with that value being the result of the first primitive statement matching those names.
Consider the following Pyro program:
>>> def per_category_model(category):
... shift = pyro.param(f'{category}_shift', torch.randn(1))
... mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
... std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
... return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))
Running the program for multiple categories can be done by
>>> def model(categories):
... return {category:per_category_model(category) for category in categories}
To make the `std` sample sites have the same value, we can write
>>> equal_std_model = pyro.poutine.equalize(model, '.+_std')
If on top of the above we would like to make the 'shift' parameters identical, we can write
>>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')
:param fn: a stochastic function (callable containing Pyro primitive calls)
:param sites: a string or list of strings to match site names (the strings can be regular expressions)
:param type: a string specifying the site type (default is 'sample')
:returns: stochastic function decorated with a :class:`~pyro.poutine.equalize_messenger.EqualizeMessenger`
"""

def __init__(
self, sites: Union[str, List[str]], type: Optional[str] = "sample"
) -> None:
super().__init__()
self.sites = [sites] if isinstance(sites, str) else sites
self.type = type

def __enter__(self) -> Self:
self.value = None
return super().__enter__()

def _is_matching(self, msg: Message) -> bool:
if msg["type"] == self.type:
for site in self.sites:
if re.compile(site).fullmatch(msg["name"]) is not None: # type: ignore[arg-type]
return True
return False

def _postprocess_message(self, msg: Message) -> None:
if self.value is None and self._is_matching(msg):
value = msg["value"]
assert value is not None
self.value = value

def _process_message(self, msg: Message) -> None:
if self.value is not None and self._is_matching(msg): # type: ignore[unreachable]
msg["value"] = self.value # type: ignore[unreachable]
if msg["type"] == "sample":
msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim)
msg["infer"] = {"_deterministic": True}
msg["is_observed"] = True
24 changes: 24 additions & 0 deletions pyro/poutine/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from pyro.poutine.condition_messenger import ConditionMessenger
from pyro.poutine.do_messenger import DoMessenger
from pyro.poutine.enum_messenger import EnumMessenger
from pyro.poutine.equalize_messenger import EqualizeMessenger
from pyro.poutine.escape_messenger import EscapeMessenger
from pyro.poutine.infer_config_messenger import InferConfigMessenger
from pyro.poutine.lift_messenger import LiftMessenger
Expand Down Expand Up @@ -301,6 +302,29 @@ def escape( # type: ignore[empty-body]
) -> Union[EscapeMessenger, Callable[_P, _T]]: ...


@overload
def equalize(
sites: Union[str, List[str]],
type: Optional[str],
) -> ConditionMessenger: ...


@overload
def equalize(
fn: Callable[_P, _T],
sites: Union[str, List[str]],
type: Optional[str],
) -> Callable[_P, _T]: ...


@_make_handler(EqualizeMessenger)
def equalize( # type: ignore[empty-body]
fn: Callable[_P, _T],
sites: Union[str, List[str]],
type: Optional[str],
) -> Union[EqualizeMessenger, Callable[_P, _T]]: ...


@overload
def infer_config(
config_fn: Callable[["Message"], "InferDict"],
Expand Down
44 changes: 44 additions & 0 deletions tests/poutine/test_poutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,50 @@ def test_infer_config_sample(self):
assert tr.nodes["p"]["infer"] == {}


class EqualizeHandlerTests(TestCase):
def setUp(self):
def per_category_model(category):
shift = pyro.param(f"{category}_shift", torch.randn(1))
mean = pyro.sample(f"{category}_mean", pyro.distributions.Normal(0, 1))
std = pyro.sample(f"{category}_std", pyro.distributions.LogNormal(0, 1))
with pyro.plate(f"{category}_num_samples", 5):
return pyro.sample(
f"{category}_values", pyro.distributions.Normal(mean + shift, std)
)

def model(categories=["dogs", "cats"]):
return {category: per_category_model(category) for category in categories}

self.model = model

def test_sample_site_equalization(self):
pyro.set_rng_seed(20240616)
pyro.clear_param_store()
model = poutine.equalize(self.model, ".+_std")
tr = pyro.poutine.trace(model).get_trace()
assert_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])
assert_not_equal(
tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"]
)
guide = pyro.infer.autoguide.AutoNormal(model)
guide_sites = [*guide()]
assert guide_sites == [
"dogs_mean",
"dogs_std",
"dogs_values",
"cats_mean",
"cats_values",
]

def test_param_equalization(self):
pyro.set_rng_seed(20240616)
pyro.clear_param_store()
model = poutine.equalize(self.model, ".+_shift", "param")
tr = pyro.poutine.trace(model).get_trace()
assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"])
assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])


@pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
@pytest.mark.parametrize("depth", [0, 1, 2])
def test_enumerate_poutine(depth, first_available_dim):
Expand Down

0 comments on commit daea9a6

Please sign in to comment.