Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add param_store.py type hints #3271

Merged
merged 9 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 70 additions & 36 deletions pyro/params/param_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,26 @@
import warnings
import weakref
from contextlib import contextmanager
from typing import (
Callable,
Dict,
ItemsView,
Iterator,
KeysView,
Optional,
Tuple,
Union,
)

import torch
from torch.distributions import constraints, transform_to
from torch.serialization import MAP_LOCATION
from typing_extensions import TypedDict


class StateDict(TypedDict):
params: Dict[str, torch.Tensor]
constraints: Dict[str, constraints.Constraint]


class ParamStoreDict:
Expand Down Expand Up @@ -39,80 +56,86 @@ class ParamStoreDict:
# -------------------------------------------------------------------------------
# New dict-like interface

def __init__(self):
def __init__(self) -> None:
"""
initialize ParamStore data structures
"""
self._params = {} # dictionary from param name to param
self._param_to_name = {} # dictionary from unconstrained param to param name
self._constraints = {} # dictionary from param name to constraint object
self._params: Dict[
str, torch.Tensor
] = {} # dictionary from param name to param
self._param_to_name: Dict[
torch.Tensor, str
] = {} # dictionary from unconstrained param to param name
self._constraints: Dict[
str, constraints.Constraint
] = {} # dictionary from param name to constraint object

def clear(self):
def clear(self) -> None:
"""
Clear the ParamStore
"""
self._params = {}
self._param_to_name = {}
self._constraints = {}

def items(self):
def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
"""
Iterate over ``(name, constrained_param)`` pairs. Note that `constrained_param` is
in the constrained (i.e. user-facing) space.
"""
for name in self._params:
yield name, self[name]

def keys(self):
def keys(self) -> KeysView[str]:
"""
Iterate over param names.
"""
return self._params.keys()

def values(self):
def values(self) -> Iterator[torch.Tensor]:
"""
Iterate over constrained parameter values.
"""
for name, constrained_param in self.items():
yield constrained_param

def __bool__(self):
def __bool__(self) -> bool:
return bool(self._params)

def __len__(self):
def __len__(self) -> int:
return len(self._params)

def __contains__(self, name):
def __contains__(self, name: str) -> bool:
return name in self._params

def __iter__(self):
def __iter__(self) -> Iterator[str]:
"""
Iterate over param names.
"""
return iter(self.keys())

def __delitem__(self, name):
def __delitem__(self, name) -> None:
"""
Remove a parameter from the param store.
"""
unconstrained_value = self._params.pop(name)
self._param_to_name.pop(unconstrained_value)
self._constraints.pop(name)

def __getitem__(self, name):
def __getitem__(self, name: str) -> torch.Tensor:
"""
Get the *constrained* value of a named parameter.
"""
unconstrained_value = self._params[name]

# compute the constrained value
constraint = self._constraints[name]
constrained_value = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
constrained_value: torch.Tensor = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value) # type: ignore[attr-defined]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat, I haven't seen the error-specific type: ignore[-] syntax before. Did you add it manually, or via an IDE that automatically finds the narrowest ignore?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an output from mypy error message:

pyro/params/param_store.py:134: error: "Tensor" has no attribute "unconstrained" [attr-defined]

But github copilot is also pretty good at finding those.

Copy link
Member Author

@ordabayevy ordabayevy Oct 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also described here: https://mypy.readthedocs.io/en/stable/common_issues.html#spurious-errors-and-locally-silencing-the-checker

You can use the form # type: ignore[<code>] to only ignore specific errors on the line. This way you are less likely to silence unexpected errors that are not safe to ignore, and this will also document what the purpose of the comment is. See Error codes for more information.


return constrained_value

def __setitem__(self, name, new_constrained_value):
def __setitem__(self, name: str, new_constrained_value: torch.Tensor) -> None:
"""
Set the constrained value of an existing parameter, or the value of a
new *unconstrained* parameter. To declare a new parameter with
Expand All @@ -132,7 +155,12 @@ def __setitem__(self, name, new_constrained_value):
self._params[name] = unconstrained_value
self._param_to_name[unconstrained_value] = name

def setdefault(self, name, init_constrained_value, constraint=constraints.real):
def setdefault(
self,
name: str,
init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]],
constraint: constraints.Constraint = constraints.real,
) -> torch.Tensor:
"""
Retrieve a *constrained* parameter value from the if it exists, otherwise
set the initial value. Note that this is a little fancier than
Expand Down Expand Up @@ -170,32 +198,38 @@ def setdefault(self, name, init_constrained_value, constraint=constraints.real):
# -------------------------------------------------------------------------------
# Old non-dict interface

def named_parameters(self):
def named_parameters(self) -> ItemsView[str, torch.Tensor]:
"""
Returns an iterator over ``(name, unconstrained_value)`` tuples for
each parameter in the ParamStore. Note that, in the event the parameter is constrained,
`unconstrained_value` is in the unconstrained space implicitly used by the constraint.
"""
return self._params.items()

def get_all_param_names(self):
def get_all_param_names(self) -> KeysView[str]:
warnings.warn(
"ParamStore.get_all_param_names() is deprecated; use .keys() instead.",
DeprecationWarning,
)
return self.keys()

def replace_param(self, param_name, new_param, old_param):
def replace_param(
self, param_name: str, new_param: torch.Tensor, old_param: torch.Tensor
) -> None:
warnings.warn(
"ParamStore.replace_param() is deprecated; use .__setitem__() instead.",
DeprecationWarning,
)
assert self._params[param_name] is old_param.unconstrained()
assert self._params[param_name] is old_param.unconstrained() # type: ignore[attr-defined]
self[param_name] = new_param

def get_param(
self, name, init_tensor=None, constraint=constraints.real, event_dim=None
):
self,
name: str,
init_tensor: Optional[torch.Tensor] = None,
constraint: constraints.Constraint = constraints.real,
event_dim: Optional[int] = None,
) -> torch.Tensor:
"""
Get parameter from its name. If it does not yet exist in the
ParamStore, it will be created and stored.
Expand All @@ -216,7 +250,7 @@ def get_param(
else:
return self.setdefault(name, init_tensor, constraint)

def match(self, name):
def match(self, name: str) -> Dict[str, torch.Tensor]:
"""
Get all parameters that match regex. The parameter must exist.

Expand All @@ -227,7 +261,7 @@ def match(self, name):
pattern = re.compile(name)
return {name: self[name] for name in self if pattern.match(name)}

def param_name(self, p):
def param_name(self, p: torch.Tensor) -> Optional[str]:
"""
Get parameter name from parameter

Expand All @@ -239,18 +273,18 @@ def param_name(self, p):
# -------------------------------------------------------------------------------
# Persistence interface

def get_state(self) -> dict:
def get_state(self) -> StateDict:
"""
Get the ParamStore state.
"""
params = self._params.copy()
# Remove weakrefs in preparation for pickling.
for param in params.values():
param.__dict__.pop("unconstrained", None)
state = {"params": params, "constraints": self._constraints.copy()}
state: StateDict = {"params": params, "constraints": self._constraints.copy()}
return state

def set_state(self, state: dict):
def set_state(self, state: StateDict) -> None:
"""
Set the ParamStore state using state from a previous :meth:`get_state` call
"""
Expand All @@ -269,7 +303,7 @@ def set_state(self, state: dict):
constraint = constraints.real
self._constraints[param_name] = constraint

def save(self, filename):
def save(self, filename: str) -> None:
"""
Save parameters to file

Expand All @@ -279,7 +313,7 @@ def save(self, filename):
with open(filename, "wb") as output_file:
torch.save(self.get_state(), output_file)

def load(self, filename, map_location=None):
def load(self, filename: str, map_location: MAP_LOCATION = None) -> None:
"""
Loads parameters from file

Expand All @@ -301,7 +335,7 @@ def load(self, filename, map_location=None):
self.set_state(state)

@contextmanager
def scope(self, state=None) -> dict:
def scope(self, state: Optional[StateDict] = None) -> Iterator[StateDict]:
"""
Context manager for using multiple parameter stores within the same process.

Expand Down Expand Up @@ -343,19 +377,19 @@ def scope(self, state=None) -> dict:
_MODULE_NAMESPACE_DIVIDER = "$$$"


def param_with_module_name(pyro_name, param_name):
def param_with_module_name(pyro_name: str, param_name: str) -> str:
return _MODULE_NAMESPACE_DIVIDER.join([pyro_name, param_name])


def module_from_param_with_module_name(param_name):
def module_from_param_with_module_name(param_name: str) -> str:
return param_name.split(_MODULE_NAMESPACE_DIVIDER)[0]


def user_param_name(param_name):
def user_param_name(param_name: str) -> str:
if _MODULE_NAMESPACE_DIVIDER in param_name:
return param_name.split(_MODULE_NAMESPACE_DIVIDER)[1]
return param_name


def normalize_param_name(name):
def normalize_param_name(name: str) -> str:
return name.replace(_MODULE_NAMESPACE_DIVIDER, ".")
4 changes: 0 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,6 @@ warn_unused_ignores = True
[mypy-pyro.optm.*]
warn_unused_ignores = True

[mypy-pyro.params.*]
ignore_errors = True
warn_unused_ignores = True

[mypy-pyro.poutine.*]
ignore_errors = True

Expand Down