Type annotate pyro.nn.dense_nn and pyro.nn.auto_reg_nn (#3342)
ordabayevy authored Mar 18, 2024
1 parent 0474cc9 commit 3acd77d
Showing 2 changed files with 56 additions and 35 deletions.
61 changes: 37 additions & 24 deletions pyro/nn/auto_reg_nn.py
@@ -2,13 +2,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import warnings
+from typing import List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
 
-def sample_mask_indices(input_dim, hidden_dim, simple=True):
+def sample_mask_indices(
+    input_dim: int, hidden_dim: int, simple: bool = True
+) -> torch.Tensor:
     """
     Samples the indices assigned to hidden units during the construction of MADE masks
@@ -33,8 +36,12 @@ def sample_mask_indices(input_dim, hidden_dim, simple=True):
 
 
 def create_mask(
-    input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier
-):
+    input_dim: int,
+    context_dim: int,
+    hidden_dims: List[int],
+    permutation: torch.LongTensor,
+    output_dim_multiplier: int,
+) -> Tuple[List[torch.Tensor], torch.Tensor]:
     """
     Creates MADE masks for a conditional distribution
@@ -109,11 +116,13 @@ class MaskedLinear(nn.Linear):
     :type bias: bool
     """
 
-    def __init__(self, in_features, out_features, mask, bias=True):
+    def __init__(
+        self, in_features: int, out_features: int, mask: torch.Tensor, bias: bool = True
+    ) -> None:
         super().__init__(in_features, out_features, bias)
         self.register_buffer("mask", mask.data)
 
-    def forward(self, _input):
+    def forward(self, _input: torch.Tensor) -> torch.Tensor:
         masked_weight = self.weight * self.mask
         return F.linear(_input, masked_weight, self.bias)
 
@@ -166,14 +175,14 @@ class ConditionalAutoRegressiveNN(nn.Module):
 
     def __init__(
         self,
-        input_dim,
-        context_dim,
-        hidden_dims,
-        param_dims=[1, 1],
-        permutation=None,
-        skip_connections=False,
-        nonlinearity=nn.ReLU(),
-    ):
+        input_dim: int,
+        context_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        permutation: Optional[torch.LongTensor] = None,
+        skip_connections: bool = False,
+        nonlinearity: torch.nn.Module = nn.ReLU(),
+    ) -> None:
         super().__init__()
         if input_dim == 1:
             warnings.warn(
Expand Down Expand Up @@ -206,6 +215,7 @@ def __init__(
else:
# The permutation is chosen by the user
P = permutation.type(dtype=torch.int64)
self.permutation: torch.LongTensor
self.register_buffer("permutation", P)

# Create masks
@@ -230,6 +240,7 @@ def __init__(
         )
         self.layers = nn.ModuleList(layers)
 
+        self.skip_layer: Optional[MaskedLinear]
         if skip_connections:
             self.skip_layer = MaskedLinear(
                 input_dim + context_dim,
@@ -243,13 +254,15 @@ def __init__(
         # Save the nonlinearity
         self.f = nonlinearity
 
-    def get_permutation(self):
+    def get_permutation(self) -> torch.LongTensor:
         """
         Get the permutation applied to the inputs (by default this is chosen at random)
         """
         return self.permutation
 
-    def forward(self, x, context=None):
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None
+    ) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         # We must be able to broadcast the size of the context over the input
         if context is None:
             context = self.context
@@ -258,7 +271,7 @@ def forward(self, x, context=None):
         x = torch.cat([context, x], dim=-1)
         return self._forward(x)
 
-    def _forward(self, x):
+    def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         h = x
         for layer in self.layers[:-1]:
             h = self.f(layer(h))
@@ -328,13 +341,13 @@ class AutoRegressiveNN(ConditionalAutoRegressiveNN):
 
     def __init__(
         self,
-        input_dim,
-        hidden_dims,
-        param_dims=[1, 1],
-        permutation=None,
-        skip_connections=False,
-        nonlinearity=nn.ReLU(),
-    ):
+        input_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        permutation: Optional[torch.LongTensor] = None,
+        skip_connections: bool = False,
+        nonlinearity: torch.nn.Module = nn.ReLU(),
+    ) -> None:
         super(AutoRegressiveNN, self).__init__(
             input_dim,
             0,
@@ -345,5 +358,5 @@ def __init__(
             nonlinearity=nonlinearity,
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:  # type: ignore[override]
         return self._forward(x)
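
A minimal usage sketch (not part of the diff) of the signatures annotated above. The import path pyro.nn.auto_reg_nn matches the file changed here; the example dimensions and the output shapes in the comments are assumptions that follow from param_dims=[1, 1], i.e. one mean and one log-scale parameter per input dimension.

    import torch

    from pyro.nn.auto_reg_nn import AutoRegressiveNN, ConditionalAutoRegressiveNN

    # Unconditional variant: forward takes only x, per the overridden signature.
    arn = AutoRegressiveNN(input_dim=10, hidden_dims=[40, 40], param_dims=[1, 1])
    x = torch.randn(100, 10)
    mean, log_scale = arn(x)  # two tensors, each expected to have shape (100, 10)

    # Conditional variant: context is Optional[torch.Tensor] and defaults to None.
    carn = ConditionalAutoRegressiveNN(
        input_dim=10, context_dim=5, hidden_dims=[40, 40], param_dims=[1, 1]
    )
    z = torch.randn(100, 5)
    mean, log_scale = carn(x, context=z)  # same shapes as above

Because param_dims has more than one entry, the return value is a sequence of tensors rather than a single tensor, which is why the annotated return type is Union[Sequence[torch.Tensor], torch.Tensor].
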
30 changes: 19 additions & 11 deletions pyro/nn/dense_nn.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+from typing import List, Sequence, Union
+
 import torch
 
 
@@ -35,12 +37,12 @@ class ConditionalDenseNN(torch.nn.Module):
 
     def __init__(
         self,
-        input_dim,
-        context_dim,
-        hidden_dims,
-        param_dims=[1, 1],
-        nonlinearity=torch.nn.ReLU(),
-    ):
+        input_dim: int,
+        context_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        nonlinearity: torch.nn.Module = torch.nn.ReLU(),
+    ) -> None:
         super().__init__()
 
         self.input_dim = input_dim
@@ -65,14 +67,16 @@ def __init__(
         # Save the nonlinearity
         self.f = nonlinearity
 
-    def forward(self, x, context):
+    def forward(
+        self, x: torch.Tensor, context: torch.Tensor
+    ) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         # We must be able to broadcast the size of the context over the input
         context = context.expand(x.size()[:-1] + (context.size(-1),))
 
         x = torch.cat([context, x], dim=-1)
         return self._forward(x)
 
-    def _forward(self, x):
+    def _forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:
         """
         The forward method
         """
@@ -122,11 +126,15 @@ class DenseNN(ConditionalDenseNN):
     """
 
     def __init__(
-        self, input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU()
-    ):
+        self,
+        input_dim: int,
+        hidden_dims: List[int],
+        param_dims: List[int] = [1, 1],
+        nonlinearity: torch.nn.Module = torch.nn.ReLU(),
+    ) -> None:
         super(DenseNN, self).__init__(
             input_dim, 0, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> Union[Sequence[torch.Tensor], torch.Tensor]:  # type: ignore[override]
         return self._forward(x)
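
Similarly, a small sketch (again not part of the diff) of the annotated DenseNN and ConditionalDenseNN constructors; the hidden sizes and the per-output shapes in the comments are assumptions derived from the chosen param_dims rather than documented guarantees.

    import torch

    from pyro.nn.dense_nn import ConditionalDenseNN, DenseNN

    # DenseNN emits one output tensor per entry of param_dims; here two
    # parameters of size 10 each, computed from a 5-dimensional input.
    net = DenseNN(input_dim=5, hidden_dims=[64, 64], param_dims=[10, 10])
    x = torch.randn(100, 5)
    loc, scale = net(x)  # two tensors, each expected to have shape (100, 10)

    # ConditionalDenseNN also takes a context tensor; unlike the
    # autoregressive variant, context here has no default and is required.
    cnet = ConditionalDenseNN(
        input_dim=5, context_dim=3, hidden_dims=[64], param_dims=[10, 10]
    )
    z = torch.randn(100, 3)
    loc, scale = cnet(x, context=z)
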
