From 34923013a98d105894f15522b231b278c868ff2f Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 7 May 2024 09:38:02 -0400 Subject: [PATCH 1/2] test --- tests/nn/test_module.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index 5508bf9cbb..dda5fb03e3 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -1045,3 +1045,24 @@ def test_with_slice_indexing(self) -> None: def test_module_list() -> None: assert PyroModule[torch.nn.ModuleList] is pyro.nn.PyroModuleList + + +@pytest.mark.parametrize("use_module_local_params", [True, False]) +def test_render_constrained_param(use_module_local_params): + + class Model(PyroModule): + + @PyroParam(constraint=constraints.positive) + def x(self): + return torch.tensor(1.234) + + @PyroParam(constraint=constraints.real) + def y(self): + return torch.tensor(0.456) + + def forward(self): + return self.x + self.y + + with pyro.settings.context(module_local_params=use_module_local_params): + model = Model() + pyro.render_model(model) From f8cfb3514b0b0514819ab494235551445a2366d5 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 7 May 2024 09:38:54 -0400 Subject: [PATCH 2/2] add constraint kwarg to fake param statements --- pyro/nn/module.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 553b33d95d..afa1ac5851 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -582,7 +582,12 @@ def __getattr__(self, name: str) -> Any: constrained_value.unconstrained = weakref.ref(unconstrained_value) return pyro.poutine.runtime.effectful(type="param")( lambda *_, **__: constrained_value - )(fullname, event_dim=event_dim, name=fullname) + )( + fullname, + constraint=constraint, + event_dim=event_dim, + name=fullname, + ) else: # Cannot determine supermodule and hence cannot compute fullname. constrained_value = transform_to(constraint)(unconstrained_value) constrained_value.unconstrained = weakref.ref(unconstrained_value) @@ -621,7 +626,7 @@ def __getattr__(self, name: str) -> Any: # even though we don't use the contents of the local parameter store fullname = self._pyro_get_fullname(name) pyro.poutine.runtime.effectful(type="param")(lambda *_, **__: result)( - fullname, result, name=fullname + fullname, result, constraint=constraints.real, name=fullname ) if isinstance(result, torch.nn.Module): @@ -645,7 +650,12 @@ def __getattr__(self, name: str) -> Any: ) pyro.poutine.runtime.effectful(type="param")( lambda *_, **__: param_value - )(fullname_param, param_value, name=fullname_param) + )( + fullname_param, + param_value, + constraint=constraints.real, + name=fullname_param, + ) return result