diff --git a/docs/source/pyro.poutine.txt b/docs/source/pyro.poutine.txt
index f975759586..443b527ce6 100644
--- a/docs/source/pyro.poutine.txt
+++ b/docs/source/pyro.poutine.txt
@@ -54,6 +54,14 @@ ________________
     :undoc-members:
     :show-inheritance:
 
+EqualizeMessenger
+____________________
+
+.. automodule:: pyro.poutine.equalize_messenger
+    :members:
+    :undoc-members:
+    :show-inheritance:
+
 EscapeMessenger
 ________________
 
diff --git a/pyro/poutine/__init__.py b/pyro/poutine/__init__.py
index 6ea794c6bc..78d11a9655 100644
--- a/pyro/poutine/__init__.py
+++ b/pyro/poutine/__init__.py
@@ -8,6 +8,7 @@
     condition,
     do,
     enum,
+    equalize,
     escape,
     infer_config,
     lift,
@@ -36,6 +37,7 @@
     "enable_validation",
     "enum",
     "escape",
+    "equalize",
     "get_mask",
     "infer_config",
     "is_validation_enabled",
diff --git a/pyro/poutine/equalize_messenger.py b/pyro/poutine/equalize_messenger.py
new file mode 100644
index 0000000000..035d96c06b
--- /dev/null
+++ b/pyro/poutine/equalize_messenger.py
@@ -0,0 +1,87 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import re
+from typing import List, Optional, Union
+
+from typing_extensions import Self
+
+from pyro.distributions import Delta
+from pyro.poutine.messenger import Messenger
+from pyro.poutine.runtime import Message
+
+
+class EqualizeMessenger(Messenger):
+    """
+    Given a stochastic function with some primitive statements and a list of names,
+    force the primitive statements at those names to have the same value,
+    with that value being the result of the first primitive statement matching those names.
+
+    Consider the following Pyro program:
+
+        >>> def per_category_model(category):
+        ...     shift = pyro.param(f'{category}_shift', torch.randn(1))
+        ...     mean = pyro.sample(f'{category}_mean', pyro.distributions.Normal(0, 1))
+        ...     std = pyro.sample(f'{category}_std', pyro.distributions.LogNormal(0, 1))
+        ...     return pyro.sample(f'{category}_values', pyro.distributions.Normal(mean + shift, std))
+
+    Running the program for multiple categories can be done by
+
+        >>> def model(categories):
+        ...     return {category: per_category_model(category) for category in categories}
+
+    To make the `std` sample sites have the same value, we can write
+
+        >>> equal_std_model = pyro.poutine.equalize(model, '.+_std')
+
+    If on top of the above we would like to make the 'shift' parameters identical, we can write
+
+        >>> equal_std_param_model = pyro.poutine.equalize(equal_std_model, '.+_shift', 'param')
+
+    :param fn: a stochastic function (callable containing Pyro primitive calls)
+    :param sites: a string or list of strings to match site names (the strings can be regular expressions)
+    :param type: a string specifying the site type (default is 'sample')
+    :returns: stochastic function decorated with a :class:`~pyro.poutine.equalize_messenger.EqualizeMessenger`
+    """
+
+    def __init__(
+        self, sites: Union[str, List[str]], type: Optional[str] = "sample"
+    ) -> None:
+        super().__init__()
+        self.sites = [sites] if isinstance(sites, str) else sites
+        self.type = type
+        # Compile each pattern once so matching does not recompile per message.
+        self._patterns = [re.compile(site) for site in self.sites]
+
+    def __enter__(self) -> Self:
+        # Reset the shared value each time the handler is entered.
+        self.value = None
+        return super().__enter__()
+
+    def _is_matching(self, msg: Message) -> bool:
+        # A site matches when its type equals self.type and its whole name
+        # matches at least one of the configured patterns.
+        if msg["type"] != self.type:
+            return False
+        name = msg["name"]
+        return any(
+            pattern.fullmatch(name) is not None  # type: ignore[arg-type]
+            for pattern in self._patterns
+        )
+
+    def _postprocess_message(self, msg: Message) -> None:
+        # The first matching site to be processed provides the shared value.
+        if self.value is None and self._is_matching(msg):
+            value = msg["value"]
+            assert value is not None
+            self.value = value
+
+    def _process_message(self, msg: Message) -> None:
+        if self.value is not None and self._is_matching(msg):  # type: ignore[unreachable]
+            msg["value"] = self.value  # type: ignore[unreachable]
+            if msg["type"] == "sample":
+                # Pin the site to the shared value: a Delta with the original
+                # event_dim, flagged as deterministic and observed.
+                msg["fn"] = Delta(self.value, event_dim=msg["fn"].event_dim)
+                msg["infer"] = {"_deterministic": True}
+                msg["is_observed"] = True
diff
--git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py
index 278f6a60f2..343b1a1f4b 100644
--- a/pyro/poutine/handlers.py
+++ b/pyro/poutine/handlers.py
@@ -74,6 +74,7 @@
 from pyro.poutine.condition_messenger import ConditionMessenger
 from pyro.poutine.do_messenger import DoMessenger
 from pyro.poutine.enum_messenger import EnumMessenger
+from pyro.poutine.equalize_messenger import EqualizeMessenger
 from pyro.poutine.escape_messenger import EscapeMessenger
 from pyro.poutine.infer_config_messenger import InferConfigMessenger
 from pyro.poutine.lift_messenger import LiftMessenger
@@ -301,6 +302,31 @@ def escape(  # type: ignore[empty-body]
 ) -> Union[EscapeMessenger, Callable[_P, _T]]: ...
 
 
+# Static typing stubs only: the bodies are empty and the decorated
+# definition is handed to ``_make_handler(EqualizeMessenger)``.
+@overload
+def equalize(
+    sites: Union[str, List[str]],
+    type: Optional[str] = "sample",
+) -> EqualizeMessenger: ...
+
+
+@overload
+def equalize(
+    fn: Callable[_P, _T],
+    sites: Union[str, List[str]],
+    type: Optional[str] = "sample",
+) -> Callable[_P, _T]: ...
+
+
+@_make_handler(EqualizeMessenger)
+def equalize(  # type: ignore[empty-body]
+    fn: Callable[_P, _T],
+    sites: Union[str, List[str]],
+    type: Optional[str] = "sample",
+) -> Union[EqualizeMessenger, Callable[_P, _T]]: ...
+
+
 @overload
 def infer_config(
     config_fn: Callable[["Message"], "InferDict"],
diff --git a/tests/poutine/test_poutines.py b/tests/poutine/test_poutines.py
index 944d26988d..7e2f7cfa8e 100644
--- a/tests/poutine/test_poutines.py
+++ b/tests/poutine/test_poutines.py
@@ -755,6 +755,52 @@ def test_infer_config_sample(self):
         assert tr.nodes["p"]["infer"] == {}
 
 
+class EqualizeHandlerTests(TestCase):
+    """Tests for ``poutine.equalize`` on sample sites and on param sites."""
+
+    def setUp(self):
+        def per_category_model(category):
+            shift = pyro.param(f"{category}_shift", torch.randn(1))
+            mean = pyro.sample(f"{category}_mean", pyro.distributions.Normal(0, 1))
+            std = pyro.sample(f"{category}_std", pyro.distributions.LogNormal(0, 1))
+            with pyro.plate(f"{category}_num_samples", 5):
+                return pyro.sample(
+                    f"{category}_values", pyro.distributions.Normal(mean + shift, std)
+                )
+
+        def model(categories=("dogs", "cats")):
+            return {category: per_category_model(category) for category in categories}
+
+        self.model = model
+
+    def test_sample_site_equalization(self):
+        pyro.set_rng_seed(20240616)
+        pyro.clear_param_store()
+        model = poutine.equalize(self.model, ".+_std")
+        tr = pyro.poutine.trace(model).get_trace()
+        assert_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])
+        assert_not_equal(
+            tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"]
+        )
+        guide = pyro.infer.autoguide.AutoNormal(model)
+        guide_sites = [*guide()]
+        # The equalized "cats_std" site is expected to be absent from the guide.
+        assert guide_sites == [
+            "dogs_mean",
+            "dogs_std",
+            "dogs_values",
+            "cats_mean",
+            "cats_values",
+        ]
+
+    def test_param_equalization(self):
+        pyro.set_rng_seed(20240616)
+        pyro.clear_param_store()
+        model = poutine.equalize(self.model, ".+_shift", "param")
+        tr = pyro.poutine.trace(model).get_trace()
+        assert_equal(tr.nodes["cats_shift"]["value"], tr.nodes["dogs_shift"]["value"])
+        assert_not_equal(tr.nodes["cats_std"]["value"], tr.nodes["dogs_std"]["value"])
+
+
 @pytest.mark.parametrize("first_available_dim", [-1, -2, -3])
 @pytest.mark.parametrize("depth", [0, 1, 2])
 def
test_enumerate_poutine(depth, first_available_dim):