Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotate messengers #3309

Merged
merged 6 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyro/distributions/torch_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import warnings
from collections import OrderedDict
from typing import Callable

import torch
from torch.distributions.kl import kl_divergence, register_kl
Expand All @@ -15,7 +16,7 @@
from .util import broadcast_shape, scale_and_mask


class TorchDistributionMixin(Distribution):
class TorchDistributionMixin(Distribution, Callable):
fritzo marked this conversation as resolved.
Show resolved Hide resolved
"""
Mixin to provide Pyro compatibility for PyTorch distributions.

Expand Down
29 changes: 10 additions & 19 deletions pyro/infer/reparam/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,20 @@
from typing import Callable, Optional

import torch
from typing_extensions import TypedDict

try:
from typing import TypedDict
except ImportError:

def TypedDict(*args, **kwargs):
return dict
class ReparamMessage(TypedDict):
name: str
fn: Callable
value: Optional[torch.Tensor]
is_observed: Optional[bool]


ReparamMessage = TypedDict(
"ReparamMessage",
name=str,
fn=Callable,
value=Optional[torch.Tensor],
is_observed=Optional[bool],
)

ReparamResult = TypedDict(
"ReparamResult",
fn=Callable,
value=Optional[torch.Tensor],
is_observed=Optional[bool],
)
class ReparamResult(TypedDict):
fn: Callable
value: Optional[torch.Tensor]
is_observed: bool


class Reparam(ABC):
Expand Down
25 changes: 17 additions & 8 deletions pyro/poutine/messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
from contextlib import contextmanager
from functools import partial
from types import TracebackType
from typing import Any, Callable, Iterator, List, Optional, Type, TypeVar, cast
from typing import (
Any,
Callable,
Iterator,
List,
Optional,
Type,
TypeVar,
)

from typing_extensions import Self
from typing_extensions import ParamSpec, Self

from .runtime import _PYRO_STACK, Message
from pyro.poutine.runtime import _PYRO_STACK, Message

_F = TypeVar("_F", bound=Callable)
_P = ParamSpec("_P")
_T = TypeVar("_T")


def _context_wrap(
Expand Down Expand Up @@ -76,13 +85,13 @@ class Messenger:
Most inference operations are implemented in subclasses of this.
"""

def __call__(self, fn: _F) -> _F:
def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
if not callable(fn):
raise ValueError(
f"{fn!r} is not callable, did you mean to pass it as a keyword arg?"
)
wraps = _bound_partial(partial(_context_wrap, self, fn))
return cast(_F, wraps)
return wraps
ordabayevy marked this conversation as resolved.
Show resolved Hide resolved

def __enter__(self) -> Self:
"""
Expand Down Expand Up @@ -118,8 +127,8 @@ def __enter__(self) -> Self:

def __exit__(
self,
exc_type: Optional[Type[Exception]],
exc_value: Optional[Exception],
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""
Expand Down
25 changes: 17 additions & 8 deletions pyro/poutine/plate_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
# SPDX-License-Identifier: Apache-2.0

from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator, Optional

from .broadcast_messenger import BroadcastMessenger
from .messenger import block_messengers
from .subsample_messenger import SubsampleMessenger
from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.messenger import Messenger, block_messengers
from pyro.poutine.subsample_messenger import SubsampleMessenger

if TYPE_CHECKING:
import torch

from pyro.poutine.runtime import Message


class PlateMessenger(SubsampleMessenger):
Expand All @@ -14,19 +20,21 @@ class PlateMessenger(SubsampleMessenger):
combines shape inference, independence annotation, and subsampling
"""

def _process_message(self, msg):
def _process_message(self, msg: "Message") -> None:
super()._process_message(msg)
return BroadcastMessenger._pyro_sample(msg)
BroadcastMessenger._pyro_sample(msg)

def __enter__(self):
def __enter__(self) -> Optional["torch.Tensor"]: # type: ignore[override]
super().__enter__()
if self._vectorized and self._indices is not None:
return self.indices
return None


@contextmanager
def block_plate(name=None, dim=None, *, strict=True):
def block_plate(
name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True
) -> Iterator[None]:
"""
EXPERIMENTAL Context manager to temporarily block a single enclosing plate.

Expand Down Expand Up @@ -63,13 +71,14 @@ def model_2(data):
assert isinstance(dim, int)
assert dim < 0

def predicate(messenger):
def predicate(messenger: Messenger) -> bool:
if not isinstance(messenger, PlateMessenger):
return False
if name is not None:
return messenger.name == name
if dim is not None:
return messenger.dim == dim
raise ValueError("Unreachable")

with block_messengers(predicate) as matches:
if strict and len(matches) != 1:
Expand Down
22 changes: 17 additions & 5 deletions pyro/poutine/reentrant_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,37 @@
# SPDX-License-Identifier: Apache-2.0

import functools
from types import TracebackType
from typing import Callable, Optional, Type, TypeVar

from .messenger import Messenger
from typing_extensions import ParamSpec, Self

from pyro.poutine.messenger import Messenger

_P = ParamSpec("_P")
_T = TypeVar("_T")


class ReentrantMessenger(Messenger):
def __init__(self):
def __init__(self) -> None:
self._ref_count = 0
super().__init__()

def __call__(self, fn):
def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
return functools.wraps(fn)(super().__call__(fn))

def __enter__(self):
def __enter__(self) -> Self:
self._ref_count += 1
if self._ref_count == 1:
super().__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self._ref_count -= 1
if self._ref_count == 0:
super().__exit__(exc_type, exc_value, traceback)
54 changes: 39 additions & 15 deletions pyro/poutine/reparam_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,33 @@
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Callable, Dict, Union
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generic,
List,
Optional,
TypeVar,
Union,
)

import torch
from typing_extensions import ParamSpec

from .messenger import Messenger
from .runtime import effectful
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import Message, effectful

if TYPE_CHECKING:
from pyro.infer.reparam.reparam import Reparam

_P = ParamSpec("_P")
_T = TypeVar("_T")


@effectful(type="get_init_messengers")
def _get_init_messengers():
def _get_init_messengers() -> List[Messenger]:
return []


Expand All @@ -34,24 +51,29 @@ class ReparamMessenger(Messenger):

:param config: Configuration, either a dict mapping site name to
:class:`~pyro.infer.reparam.reparam.Reparameterizer` , or a function
mapping site to :class:`~pyro.infer.reparam.reparam.Reparameterizer` or
mapping site to :class:`~pyro.infer.reparam.reparam.Reparam` or
None. See :mod:`pyro.infer.reparam.strategies` for built-in
configuration strategies.
:type config: dict or callable
"""

def __init__(self, config: Union[Dict[str, object], Callable]):
def __init__(
self,
config: Union[Dict[str, "Reparam"], Callable[[Message], Optional["Reparam"]]],
) -> None:
super().__init__()
assert isinstance(config, dict) or callable(config)
self.config = config
self._args_kwargs = None

def __call__(self, fn):
def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]:
return ReparamHandler(self, fn)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg: Message) -> None:
if type(msg["fn"]).__name__ == "_Subsample":
return
assert msg["name"] is not None
assert isinstance(msg["fn"], TorchDistributionMixin)
if isinstance(self.config, dict):
reparam = self.config.get(msg["name"])
else:
Expand Down Expand Up @@ -79,11 +101,13 @@ def _pyro_sample(self, msg):
# ReplayMessenger we would need to ensure those messengers can
# similarly be safely applied twice, with the second application
# avoiding overwriting the original application.
for m in _get_init_messengers():
m._pyro_sample(msg)
_get_init_messengers_iter = _get_init_messengers()
assert _get_init_messengers_iter is not None
for m in _get_init_messengers_iter:
m._process_message(msg)

# Pass args_kwargs to the reparam via a side channel.
reparam.args_kwargs = self._args_kwargs
reparam.args_kwargs = self._args_kwargs # type: ignore[attr-defined]
try:
new_msg = reparam.apply(
{
Expand All @@ -94,7 +118,7 @@ def _pyro_sample(self, msg):
}
)
finally:
reparam.args_kwargs = None
reparam.args_kwargs = None # type: ignore[attr-defined]

if new_msg["value"] is not None:
# Validate while the original msg["fn"] is known.
Expand All @@ -121,17 +145,17 @@ def _pyro_sample(self, msg):
msg["is_observed"] = new_msg["is_observed"]


class ReparamHandler(object):
class ReparamHandler(Generic[_P, _T]):
"""
Reparameterization poutine.
"""

def __init__(self, msngr, fn):
def __init__(self, msngr, fn: Callable[_P, _T]) -> None:
self.msngr = msngr
self.fn = fn
super().__init__()

def __call__(self, *args, **kwargs):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
# This saves args,kwargs for optional use by reparameterizers.
self.msngr._args_kwargs = args, kwargs
try:
Expand Down
Loading
Loading