diff --git a/pyro/nn/module.py b/pyro/nn/module.py index afa1ac5851..b84a17875b 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -163,6 +163,7 @@ class PyroSample: assert isinstance(my_module, PyroModule) my_module.x = PyroSample(Normal(0, 1)) # independent my_module.y = PyroSample(lambda self: Normal(self.x, 1)) # dependent + my_module.z = PyroSample(lambda self: self.y ** 2) # deterministic dependent or EXPERIMENTALLY as a decorator on lazy initialization methods:: @@ -175,16 +176,22 @@ def x(self): def y(self): return Normal(self.x, 1) # dependent + @PyroSample + def z(self): + return self.y ** 2 # deterministic dependent + def forward(self): - return self.y # accessed like a @property + return self.z # accessed like a @property :param prior: distribution object or function that inputs the :class:`PyroModule` instance ``self`` and returns a distribution - object. + object or a deterministic value. """ prior: Union[ - "TorchDistributionMixin", Callable[["PyroModule"], "TorchDistributionMixin"] + "TorchDistributionMixin", + Callable[["PyroModule"], "TorchDistributionMixin"], + Callable[["PyroModule"], torch.Tensor], ] def __post_init__(self) -> None: @@ -605,13 +612,17 @@ def __getattr__(self, name: str) -> Any: if value is None: if not hasattr(prior, "sample"): # if not a distribution prior = prior(self) - value = pyro.sample(fullname, prior) + value = ( + pyro.deterministic(fullname, prior) + if isinstance(prior, torch.Tensor) + else pyro.sample(fullname, prior) + ) context.set(fullname, value) return value else: # Cannot determine supermodule and hence cannot compute fullname. if not hasattr(prior, "sample"): # if not a distribution prior = prior(self) - return prior() + return prior if isinstance(prior, torch.Tensor) else prior() result = super().__getattr__(name) diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py index dda5fb03e3..07c4daedd1 100644 --- a/tests/nn/test_module.py +++ b/tests/nn/test_module.py @@ -491,9 +491,10 @@ def __init__(self, size): ) self.s = PyroSample(dist.Normal(0, 1)) self.t = PyroSample(lambda self: dist.Normal(self.s, self.z)) + self.u = PyroSample(lambda self: self.t**2) def forward(self): - return self.x + self.y + self.t + return self.x + self.y + self.u class DecoratorModel(PyroModule): @@ -521,8 +522,12 @@ def s(self): def t(self): return dist.Normal(self.s, self.z).to_event(1) + @PyroSample + def u(self): + return self.t**2 + def forward(self): - return self.x + self.y + self.t + return self.x + self.y + self.u @pytest.mark.parametrize("Model", [AttributeModel, DecoratorModel]) @@ -531,19 +536,32 @@ def test_decorator(Model, size): model = Model(size) for i in range(2): trace = poutine.trace(model).get_trace() - assert set(trace.nodes.keys()) == {"_INPUT", "x", "y", "z", "s", "t", "_RETURN"} + assert set(trace.nodes.keys()) == { + "_INPUT", + "x", + "y", + "z", + "s", + "t", + "u", + "_RETURN", + } assert trace.nodes["x"]["type"] == "param" assert trace.nodes["y"]["type"] == "param" assert trace.nodes["z"]["type"] == "param" assert trace.nodes["s"]["type"] == "sample" assert trace.nodes["t"]["type"] == "sample" + assert trace.nodes["u"]["type"] == "sample" assert trace.nodes["x"]["value"].shape == (size,) assert trace.nodes["y"]["value"].shape == (size,) assert trace.nodes["z"]["value"].shape == (size,) assert trace.nodes["s"]["value"].shape == () assert trace.nodes["t"]["value"].shape == (size,) + assert trace.nodes["u"]["value"].shape == (size,) + + assert trace.nodes["u"]["infer"] == {"_deterministic": True} def test_mixin_factory():