Implement parallelized DiscreteHMM distribution (#1958)
* Sketch log-time HMM distribution

* Revise interface; add tests

* Add DiscreteHMM to examples/hmm.py

* Add more tests

* Simplify examples/hmm.py

* flake8

* Updates per review
fritzo authored and eb8680 committed Jul 16, 2019
1 parent bf2f954 commit 503e57f
Showing 6 changed files with 354 additions and 8 deletions.
7 changes: 7 additions & 0 deletions docs/source/distributions.rst
@@ -77,6 +77,13 @@ DirichletMultinomial
:undoc-members:
:show-inheritance:

DiscreteHMM
--------------------
.. autoclass:: pyro.distributions.DiscreteHMM
:members:
:undoc-members:
:show-inheritance:

EmpiricalDistribution
----------------------
.. autoclass:: pyro.distributions.Empirical
48 changes: 40 additions & 8 deletions examples/hmm.py
@@ -380,17 +380,12 @@ def __init__(self, args, data_dim):
self.relu = nn.ReLU()

def forward(self, x, y):
-        # Check dimension of y so this can be used with and without enumeration.
-        if y.dim() < 2:
-            y = y.unsqueeze(0)
-
# Hidden units depend on two inputs: a one-hot encoded categorical variable x, and
# a bernoulli variable y. Whereas x will typically be enumerated, y will be observed.
# We apply x_to_hidden independently from y_to_hidden, then broadcast the non-enumerated
# y part up to the enumerated x part in the + operation.
-        x_onehot = (torch.zeros(x.shape[:-1] + (self.args.hidden_dim,), dtype=y.dtype, device=y.device)
-                    .scatter_(-1, x, 1))
-        y_conv = self.relu(self.conv(y.unsqueeze(-2))).reshape(y.shape[:-1] + (-1,))
+        x_onehot = y.new_zeros(x.shape[:-1] + (self.args.hidden_dim,)).scatter_(-1, x, 1)
+        y_conv = self.relu(self.conv(y.reshape(-1, 1, self.data_dim))).reshape(y.shape[:-1] + (-1,))
h = self.relu(self.x_to_hidden(x_onehot) + self.y_to_hidden(y_conv))
return self.hidden_to_logits(h)

@@ -493,6 +488,41 @@ def model_6(sequences, lengths, args, batch_size=None, include_prior=False):
obs=sequences[batch, t])


# Next we demonstrate how to parallelize the neural HMM above using Pyro's
# DiscreteHMM distribution. This model is equivalent to model_5 above, but we
# manually unroll loops and fuse ops, leading to a single sample statement.
# DiscreteHMM can lead to over 10x speedup in models where it is applicable.
def model_7(sequences, lengths, args, batch_size=None, include_prior=True):
with ignore_jit_warnings():
num_sequences, max_length, data_dim = map(int, sequences.shape)
assert lengths.shape == (num_sequences,)
assert lengths.max() <= max_length

# Initialize a global module instance if needed.
global tones_generator
if tones_generator is None:
tones_generator = TonesGenerator(args, data_dim)
pyro.module("tones_generator", tones_generator)

with poutine.mask(mask=include_prior):
probs_x = pyro.sample("probs_x",
dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1)
.to_event(1))
with pyro.plate("sequences", num_sequences, batch_size, dim=-1) as batch:
lengths = lengths[batch]
y = sequences[batch] if args.jit else sequences[batch, :lengths.max()]
x = torch.arange(args.hidden_dim)
t = torch.arange(y.size(1))
init_logits = torch.full((args.hidden_dim,), -float('inf'))
init_logits[0] = 0
trans_logits = probs_x.log()
with ignore_jit_warnings():
obs_dist = dist.Bernoulli(logits=tones_generator(x, y.unsqueeze(-2))).to_event(1)
obs_dist = obs_dist.mask((t < lengths.unsqueeze(-1)).unsqueeze(-1))
hmm_dist = dist.DiscreteHMM(init_logits, trans_logits, obs_dist)
pyro.sample("y", hmm_dist, obs=y)


models = {name[len('model_'):]: model
for name, model in globals().items()
if name.startswith('model_')}
@@ -546,7 +576,9 @@ def main(args):
# Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
# All of our models have two plates: "data" and "tones".
Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
-    elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2)
+    elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
+                strict_enumeration_warning=(model is not model_7),
+                jit_options={"optimize": model is model_7})
optim = Adam({'lr': args.learning_rate})
svi = SVI(model, guide, optim, elbo)

2 changes: 2 additions & 0 deletions pyro/distributions/__init__.py
@@ -9,6 +9,7 @@
from pyro.distributions.distribution import Distribution
from pyro.distributions.empirical import Empirical
from pyro.distributions.gaussian_scale_mixture import GaussianScaleMixture
from pyro.distributions.hmm import DiscreteHMM
from pyro.distributions.inverse_gamma import InverseGamma
from pyro.distributions.mixture import MaskedMixture
from pyro.distributions.omt_mvn import OMTMultivariateNormal
@@ -32,6 +33,7 @@
"BetaBinomial",
"Delta",
"DirichletMultinomial",
"DiscreteHMM",
"Distribution",
"Empirical",
"GammaPoisson",
134 changes: 134 additions & 0 deletions pyro/distributions/hmm.py
@@ -0,0 +1,134 @@
from __future__ import absolute_import, division, print_function

import torch
from torch.distributions import constraints

from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.util import broadcast_shape


def _logmatmulexp(x, y):
"""
    Numerically stable version of ``(x.exp() @ y.exp()).log()``.
"""
x_shift = x.max(-1, keepdim=True)[0]
y_shift = y.max(-2, keepdim=True)[0]
xy = torch.matmul((x - x_shift).exp(), (y - y_shift).exp()).log()
return xy + x_shift + y_shift
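
# Illustrative sanity check (an assumed example, not part of the committed file):
# _logmatmulexp(x, y) should agree with the naive computation
# (x.exp() @ y.exp()).log() whenever the naive version does not overflow.
def _example_check_logmatmulexp():
    x = torch.randn(3, 4)
    y = torch.randn(4, 5)
    expected = torch.matmul(x.exp(), y.exp()).log()
    assert torch.allclose(_logmatmulexp(x, y), expected, atol=1e-5)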


def _sequential_logmatmulexp(logits):
"""
    For a tensor ``x`` whose time dimension is -3, computes::

        x[..., 0, :, :] @ x[..., 1, :, :] @ ... @ x[..., T-1, :, :]

    but does so numerically stably in log space.
"""
batch_shape = logits.shape[:-3]
state_dim = logits.size(-1)
while logits.size(-3) > 1:
time = logits.size(-3)
even_time = time // 2 * 2
even_part = logits[..., :even_time, :, :]
x_y = even_part.reshape(batch_shape + (even_time // 2, 2, state_dim, state_dim))
x, y = x_y.unbind(-3)
contracted = _logmatmulexp(x, y)
if time > even_time:
contracted = torch.cat((contracted, logits[..., -1:, :, :]), dim=-3)
logits = contracted
return logits.squeeze(-3)
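
# Illustrative sanity check (an assumed example, not part of the committed file):
# the pairwise contraction above should match a naive left-to-right product
# over the time dimension.
def _example_check_sequential_logmatmulexp():
    logits = torch.randn(7, 4, 4)  # (time, state, state), time at dim -3
    expected = logits[0]
    for t in range(1, logits.size(0)):
        expected = _logmatmulexp(expected, logits[t])
    assert torch.allclose(_sequential_logmatmulexp(logits), expected, atol=1e-5)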


class DiscreteHMM(TorchDistribution):
"""
Hidden Markov Model with discrete latent state and arbitrary observation
distribution. This uses [1] to parallelize over time, achieving
O(log(time)) parallel complexity.

    The event_shape of this distribution includes time on the left::

        event_shape = (num_steps,) + observation_dist.event_shape

    This distribution supports any combination of homogeneous/heterogeneous
    time dependency of ``transition_logits`` and ``observation_dist``. However,
    because time is included in this distribution's event_shape, the
    homogeneous+homogeneous case will have a broadcastable event_shape with
    ``num_steps = 1``, allowing :meth:`log_prob` to work with arbitrary length
    data::

        # homogeneous + homogeneous case:
        event_shape = (1,) + observation_dist.event_shape

    **References:**

    [1] Simo Sarkka, Angel F. Garcia-Fernandez (2019)
        "Temporal Parallelization of Bayesian Filters and Smoothers"
        https://arxiv.org/pdf/1905.13002.pdf

:param torch.Tensor initial_logits: A logits tensor for an initial
categorical distribution over latent states. Should have rightmost size
``state_dim`` and be broadcastable to ``batch_shape + (state_dim,)``.
:param torch.Tensor transition_logits: A logits tensor for transition
conditional distributions between latent states. Should have rightmost
shape ``(state_dim, state_dim)`` (old, new), and be broadcastable to
``batch_shape + (num_steps, state_dim, state_dim)``.
:param torch.distributions.Distribution observation_dist: A conditional
distribution of observed data conditioned on latent state. The
``.batch_shape`` should have rightmost size ``state_dim`` and be
broadcastable to ``batch_shape + (num_steps, state_dim)``. The
``.event_shape`` may be arbitrary.
"""
arg_constraints = {"initial_logits": constraints.real,
"transition_logits": constraints.real}

def __init__(self, initial_logits, transition_logits, observation_dist, validate_args=None):
if initial_logits.dim() < 1:
raise ValueError("expected initial_logits to have at least one dim, "
"actual shape = {}".format(initial_logits.shape))
if transition_logits.dim() < 2:
raise ValueError("expected transition_logits to have at least two dims, "
"actual shape = {}".format(transition_logits.shape))
if len(observation_dist.batch_shape) < 1:
raise ValueError("expected observation_dist to have at least one batch dim, "
"actual .batch_shape = {}".format(observation_dist.batch_shape))
time_shape = broadcast_shape((1,), transition_logits.shape[-3:-2],
observation_dist.batch_shape[-2:-1])
event_shape = time_shape + observation_dist.event_shape
batch_shape = broadcast_shape(initial_logits.shape[:-1],
transition_logits.shape[:-3],
observation_dist.batch_shape[:-2])
self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
self.transition_logits = transition_logits - transition_logits.logsumexp(-1, True)
self.observation_dist = observation_dist
super(DiscreteHMM, self).__init__(batch_shape, event_shape, validate_args=validate_args)

def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(DiscreteHMM, _instance)
batch_shape = torch.Size(broadcast_shape(self.batch_shape, batch_shape))
# We only need to expand one of the inputs, since batch_shape is determined
# by broadcasting all three. To save computation in _sequential_logmatmulexp(),
# we expand only initial_logits, which is applied only after the logmatmulexp.
# This is similar to the ._unbroadcasted_* pattern used elsewhere in distributions.
new.initial_logits = self.initial_logits.expand(batch_shape + (-1,))
new.transition_logits = self.transition_logits
new.observation_dist = self.observation_dist
super(DiscreteHMM, new).__init__(batch_shape, self.event_shape, validate_args=False)
new.validate_args = self.__dict__.get('_validate_args')
return new

def log_prob(self, value):
# Combine observation and transition factors.
value = value.unsqueeze(-1 - self.observation_dist.event_dim)
observation_logits = self.observation_dist.log_prob(value)
result = self.transition_logits + observation_logits.unsqueeze(-2)

# Eliminate time dimension.
result = _sequential_logmatmulexp(result)

# Combine initial factor.
result = _logmatmulexp(self.initial_logits.unsqueeze(-2), result).squeeze(-2)

# Marginalize out final state.
result = result.logsumexp(-1)
return result
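
As a usage sketch of the new distribution (the Normal emissions, shapes, and variable names below are illustrative assumptions, not taken from this commit), DiscreteHMM is constructed from initial logits, transition logits, and a per-state observation distribution, and its log_prob scores an entire observed sequence in one call:

import torch
import pyro.distributions as dist

state_dim, data_dim, num_steps = 3, 4, 10

initial_logits = torch.zeros(state_dim)                # uniform over initial states
transition_logits = torch.randn(state_dim, state_dim)  # homogeneous transitions
# One emission distribution per hidden state: batch_shape (state_dim,), event_shape (data_dim,).
observation_dist = dist.Normal(torch.randn(state_dim, data_dim), 1.).to_event(1)

hmm = dist.DiscreteHMM(initial_logits, transition_logits, observation_dist)
# Both parts are time-homogeneous, so event_shape is (1, data_dim) and
# log_prob broadcasts over the actual sequence length.
data = torch.randn(num_steps, data_dim)
log_likelihood = hmm.log_prob(data)  # scalar: hidden states are marginalized out

This marginalization over hidden states inside log_prob is what lets model_7 replace the per-time-step sample statements of model_5 with a single pyro.sample over the whole sequence.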