support for padded arrays in hmm inference
slinderman committed Sep 19, 2023
1 parent 95a09ba commit b1f7249
Showing 3 changed files with 96 additions and 30 deletions.
87 changes: 57 additions & 30 deletions dynamax/hidden_markov_model/inference.py
@@ -101,7 +101,8 @@ def hmm_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
-transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
+transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None,
+num_timesteps: Optional[Int] = None,
) -> HMMPosteriorFiltered:
r"""Forwards filtering
@@ -115,27 +116,32 @@ def hmm_filter(
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
Returns:
filtered posterior distribution
"""
-num_timesteps, num_states = log_likelihoods.shape
+max_num_timesteps = log_likelihoods.shape[0]
+num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

def _step(carry, t):
log_normalizer, predicted_probs = carry

A = get_trans_mat(transition_matrix, transition_fn, t)
ll = log_likelihoods[t]

+# Ignore observations after specified number of timesteps
+ll = jnp.where(t < num_timesteps, ll, 0.0)

filtered_probs, log_norm = _condition_on(predicted_probs, ll)
log_normalizer += log_norm
predicted_probs_next = _predict(filtered_probs, A)

return (log_normalizer, predicted_probs_next), (filtered_probs, predicted_probs)

carry = (0.0, initial_distribution)
-(log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
+(log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(_step, carry, jnp.arange(max_num_timesteps))

post = HMMPosteriorFiltered(marginal_loglik=log_normalizer,
filtered_probs=filtered_probs,
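
A note on why the masking suffices: zeroing a row of log_likelihoods makes _condition_on multiply by a likelihood of one, so each padded step adds log 1 = 0 to the marginal log likelihood and leaves the filtered marginals over the valid prefix unchanged. A minimal usage sketch (toy numbers, importing the inference module the way the new tests below do; not part of this commit):

import jax.numpy as jnp
import jax.random as jr
import dynamax.hidden_markov_model.inference as core

key = jr.PRNGKey(0)
initial_probs = jnp.array([0.6, 0.4])
transition_matrix = jnp.array([[0.9, 0.1], [0.2, 0.8]])
log_lkhds = jr.normal(key, (8, 2))  # 5 valid steps plus 3 rows of padding

post_trunc = core.hmm_filter(initial_probs, transition_matrix, log_lkhds[:5])
post_padded = core.hmm_filter(initial_probs, transition_matrix, log_lkhds, num_timesteps=5)

# Padded steps contribute log(1) = 0, so the two answers agree on the valid prefix.
assert jnp.allclose(post_trunc.marginal_loglik, post_padded.marginal_loglik, atol=1e-5)
assert jnp.allclose(post_trunc.filtered_probs, post_padded.filtered_probs[:5], atol=1e-5)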
@@ -149,7 +155,8 @@ def hmm_backward_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
-transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
+transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+num_timesteps: Optional[Int] = None
) -> Tuple[Float, Float[Array, "num_timesteps num_states"]]:
r"""Run the filter backwards in time. This is the second step of the forward-backward algorithm.
@@ -163,30 +170,33 @@
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
Returns:
marginal log likelihood and backward messages.
"""
-num_timesteps, num_states = log_likelihoods.shape
+max_num_timesteps, num_states = log_likelihoods.shape
+num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

def _step(carry, t):
log_normalizer, backward_pred_probs = carry

A = get_trans_mat(transition_matrix, transition_fn, t)
ll = log_likelihoods[t]

+# Ignore observations after specified number of timesteps
+ll = jnp.where(t < num_timesteps, ll, 0.0)

# Condition on emission at time t, being careful not to overflow.
backward_filt_probs, log_norm = _condition_on(backward_pred_probs, ll)
# Update the log normalizer.
log_normalizer += log_norm

# Predict the next state (going backward in time).
next_backward_pred_probs = _predict(backward_filt_probs, A.T)
return (log_normalizer, next_backward_pred_probs), backward_pred_probs

carry = (0.0, jnp.ones(num_states))
-(log_normalizer, _), rev_backward_pred_probs = lax.scan(_step, carry, jnp.arange(num_timesteps)[::-1])
-backward_pred_probs = rev_backward_pred_probs[::-1]
+(log_normalizer, _), backward_pred_probs = lax.scan(_step, carry, jnp.arange(max_num_timesteps), reverse=True)
return log_normalizer, backward_pred_probs
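
The rewrite also swaps the manual reverse-index idiom for lax.scan's reverse=True, which JAX documents as equivalent to reversing the leading axis of both the inputs xs and the stacked outputs ys. A standalone sketch of that equivalence (generic scan, not dynamax code):

import jax.numpy as jnp
from jax import lax

def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(5.0)

# Old idiom: reverse the inputs, scan, then un-reverse the stacked outputs.
_, rev_ys = lax.scan(step, 0.0, xs[::-1])
ys_old = rev_ys[::-1]

# New idiom: one backward scan; outputs come back in forward order.
_, ys_new = lax.scan(step, 0.0, xs, reverse=True)

assert jnp.allclose(ys_old, ys_new)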


@@ -197,6 +207,7 @@ def hmm_two_filter_smoother(
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+num_timesteps: Optional[Int] = None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using the two-filter
@@ -212,16 +223,19 @@ def hmm_two_filter_smoother(
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
Returns:
posterior distribution
"""
-post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
+# Forward
+post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
ll = post.marginal_loglik
filtered_probs, predicted_probs = post.filtered_probs, post.predicted_probs

-_, backward_pred_probs = hmm_backward_filter(transition_matrix, log_likelihoods, transition_fn)
+# Backward
+_, backward_pred_probs = hmm_backward_filter(transition_matrix, log_likelihoods, transition_fn, num_timesteps)

# Compute smoothed probabilities
smoothed_probs = filtered_probs * backward_pred_probs
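
For reference, the identity behind this product (suppressing the inputs $u$ and parameters $\theta$): the forward filter supplies $p(z_t \mid y_{1:t})$, the backward filter supplies a message proportional to $p(y_{t+1:T} \mid z_t)$, and

$p(z_t \mid y_{1:T}) \propto p(z_t \mid y_{1:t}) \, p(y_{t+1:T} \mid z_t),$

so an elementwise product followed by row normalization (in the collapsed lines below) yields the smoothed marginals. With the masking above, both factors treat steps past num_timesteps as uninformative, so the identity holds for padded inputs as well.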
@@ -251,6 +265,7 @@ def hmm_smoother(
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+num_timesteps: Optional[Int]=None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using a general
@@ -268,15 +283,17 @@
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
Returns:
posterior distribution
"""
-num_timesteps, num_states = log_likelihoods.shape
+max_num_timesteps, _ = log_likelihoods.shape
+num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

# Run the HMM filter
-post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
+post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
ll = post.marginal_loglik
filtered_probs, predicted_probs = post.filtered_probs, post.predicted_probs

@@ -294,16 +311,15 @@ def _step(carry, args):
smoothed_probs_next / predicted_probs_next)
smoothed_probs = filtered_probs * (A @ relative_probs_next)
smoothed_probs /= smoothed_probs.sum()

return smoothed_probs, smoothed_probs

# Run the HMM smoother
carry = filtered_probs[-1]
-args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_probs[:-1][::-1], predicted_probs[1:][::-1])
-_, rev_smoothed_probs = lax.scan(_step, carry, args)
+args = (jnp.arange(max_num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:])
+_, smoothed_probs = lax.scan(_step, carry, args, reverse=True)

# Reverse the arrays and return
-smoothed_probs = jnp.vstack([rev_smoothed_probs[::-1], filtered_probs[-1]])
+smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]])

# Package into a posterior
posterior = HMMPosterior(
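
The recursion vectorized in _step above is the standard backward (RTS-style) update for discrete states,

$p(z_t \mid y_{1:T}) = p(z_t \mid y_{1:t}) \sum_{z_{t+1}} \frac{p(z_{t+1} \mid z_t)\, p(z_{t+1} \mid y_{1:T})}{p(z_{t+1} \mid y_{1:t})},$

where the ratio is relative_probs_next and the sum is the matrix-vector product A @ relative_probs_next. As in the filter, the masked log likelihoods make the padded tail uninformative, so running the recursion over all max_num_timesteps steps leaves the valid prefix correct (the new test below checks exactly this).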
@@ -352,6 +368,7 @@ def hmm_fixed_lag_smoother(
posterior distribution
"""
+# TODO: Update to allow variable length time series
num_timesteps, num_states = log_likelihoods.shape

def _step(carry, t):
@@ -441,7 +458,8 @@ def hmm_posterior_mode(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
-transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
+transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
+num_timesteps: Optional[Int]=None,
) -> Int[Array, "num_timesteps"]:
r"""Compute the most likely state sequence. This is called the Viterbi algorithm.
@@ -450,12 +468,14 @@
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
+num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.
Returns:
most likely state sequence
"""
-num_timesteps, num_states = log_likelihoods.shape
+max_num_timesteps, _ = log_likelihoods.shape
+num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

# Run the backward pass
def _backward_pass(best_next_score, t):
@@ -464,14 +484,19 @@ def _backward_pass(best_next_score, t):
scores = jnp.log(A) + best_next_score + log_likelihoods[t + 1]
best_next_state = jnp.argmax(scores, axis=1)
best_next_score = jnp.max(scores, axis=1)

+# Only update if log_likelihoods[t+1] is valid
+best_next_score = jnp.where(t + 1 < num_timesteps, best_next_score, jnp.zeros(num_states))
+best_next_state = jnp.where(t + 1 < num_timesteps, best_next_state, jnp.zeros(num_states, dtype=int))

return best_next_score, best_next_state

num_states = log_likelihoods.shape[1]
-best_second_score, rev_best_next_states = lax.scan(
-_backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 2, -1, -1)
+best_second_score, best_next_states = lax.scan(
+_backward_pass, jnp.zeros(num_states), jnp.arange(max_num_timesteps - 1),
+reverse=True
)
-best_next_states = rev_best_next_states[::-1]


# Run the forward pass
def _forward_pass(state, best_next_state):
next_state = best_next_state[state]
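
Why zeroing the carry works: once t + 1 falls in the padding, the jnp.where guard resets best_next_score to zeros, so at the last valid step the maximization sees exactly the scores it would see if the array ended there. Entries of the returned mode past num_timesteps are meaningless and should be sliced off. A toy check mirroring the new test (hypothetical values, same import style as the tests):

import jax.numpy as jnp
import jax.random as jr
import dynamax.hidden_markov_model.inference as core

key = jr.PRNGKey(0)
initial_probs = jnp.ones(3) / 3.0
transition_matrix = jnp.full((3, 3), 1.0 / 3.0)
log_lkhds = jr.normal(key, (7, 3))  # 5 valid steps plus 2 rows of padding

mode_trunc = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds[:5])
mode_padded = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds, num_timesteps=5)

# Only the valid prefix is comparable.
assert jnp.all(mode_trunc == mode_padded[:5])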
@@ -490,7 +515,8 @@ def hmm_posterior_sample(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
-transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
+transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None,
+num_timesteps: Optional[Int] = None,
) -> Int[Array, "num_timesteps"]:
r"""Sample a latent sequence from the posterior.
@@ -505,10 +531,11 @@
:sample of the latent states, $z_{1:T}$
"""
-num_timesteps, num_states = log_likelihoods.shape
+max_num_timesteps, num_states = log_likelihoods.shape
+num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

# Run the HMM filter
-post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
+post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
log_normalizer, filtered_probs = post.marginal_loglik, post.filtered_probs

# Run the sampler backward in time
@@ -528,13 +555,13 @@ def _step(carry, args):
return state, state

# Run the HMM smoother
-rngs = jr.split(rng, num_timesteps)
+rngs = jr.split(rng, max_num_timesteps)
last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
-args = (jnp.arange(num_timesteps - 1, 0, -1), rngs[:-1][::-1], filtered_probs[:-1][::-1])
-_, rev_states = lax.scan(_step, last_state, args)
+args = (jnp.arange(max_num_timesteps - 1), rngs[:-1], filtered_probs[:-1])
+_, states = lax.scan(_step, last_state, args, reverse=True)

# Reverse the arrays and return
-states = jnp.concatenate([rev_states[::-1], jnp.array([last_state])])
+states = jnp.concatenate([states, jnp.array([last_state])])
return log_normalizer, states
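
The sampler follows the same calling pattern; note the key is split per scan step up to max_num_timesteps, so for a fixed seed the draws depend on the padded length (hence no exact-equality padding test for it). A hedged usage sketch (toy shapes; the leading rng and initial_distribution arguments are collapsed in this hunk, and the call below assumes the usual dynamax order with the rng first):

import jax.numpy as jnp
import jax.random as jr
import dynamax.hidden_markov_model.inference as core

key = jr.PRNGKey(0)
initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.95, 0.05], [0.10, 0.90]])
log_lkhds = jr.normal(key, (8, 2))  # 6 valid steps plus 2 rows of padding

log_norm, states = core.hmm_posterior_sample(
    key, initial_probs, transition_matrix, log_lkhds, num_timesteps=6)
states = states[:6]  # the tail indexes padded steps; discard it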

def _compute_sum_transition_probs(
38 changes: 38 additions & 0 deletions dynamax/hidden_markov_model/inference_test.py
@@ -5,6 +5,7 @@
import dynamax.hidden_markov_model.inference as core
import dynamax.hidden_markov_model.parallel_inference as parallel

+from jax import vmap
from jax.scipy.special import logsumexp

def big_log_joint(initial_probs, transition_matrix, log_likelihoods):
@@ -259,6 +260,43 @@ def trans_mat_callable(t):
assert jnp.allclose(sample, sample2)


+def test_hmm_padding(key=0, num_timesteps=10, num_states=5, padding=3):
+if isinstance(key, int):
+key = jr.PRNGKey(key)

+initial_probs, transition_matrix, log_lkhds = random_hmm_args(key, num_timesteps + padding, num_states)

+# The filter on the padded array, given num_timesteps, should match the filter on the truncated array
+post = core.hmm_filter(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
+post2 = core.hmm_filter(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
+assert jnp.allclose(post.marginal_loglik, post2.marginal_loglik, atol=1e-4)
+assert jnp.allclose(post.filtered_probs, post2.filtered_probs[:num_timesteps], atol=1e-4)

+# Likewise for the smoother
+post = core.hmm_smoother(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
+post2 = core.hmm_smoother(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
+assert jnp.allclose(post.smoothed_probs, post2.smoothed_probs[:num_timesteps], atol=1e-4)

+# Run Viterbi
+mode = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
+mode2 = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
+assert jnp.allclose(mode, mode2[:num_timesteps])


+# Test vmap
+def test_hmm_variable_length_vmap(key=0, max_num_timesteps=10, num_states=5, num_seqs=10):
+if isinstance(key, int):
+key = jr.PRNGKey(key)

+all_args = vmap(random_hmm_args, in_axes=(0, None, None))(
+jr.split(key, num_seqs), max_num_timesteps, num_states)

+all_num_timesteps = jr.randint(key, (num_seqs,), 1, max_num_timesteps)

+# Just make sure vmap runs without throwing a concretization error
+posteriors = vmap(core.hmm_filter)(*all_args, num_timesteps=all_num_timesteps)
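
The padding is what makes this vmap legal: every sequence shares the shape (max_num_timesteps, num_states), and each sequence's num_timesteps is consumed only by jnp.where, never by a shape computation, so it can be a traced batch value. A hypothetical follow-on with explicit in_axes, continuing from the variables above (the unpacked names are illustrative), that gathers per-sequence marginal log likelihoods:

all_initial_probs, all_transition_matrices, all_log_lkhds = all_args
lls = vmap(core.hmm_filter, in_axes=(0, 0, 0, None, 0))(
    all_initial_probs,        # (num_seqs, num_states)
    all_transition_matrices,  # (num_seqs, num_states, num_states)
    all_log_lkhds,            # (num_seqs, max_num_timesteps, num_states)
    None,                     # transition_fn, shared across sequences
    all_num_timesteps,        # (num_seqs,)
).marginal_loglik             # shape (num_seqs,)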


def test_parallel_filter(key=0, num_timesteps=100, num_states=3):
if isinstance(key, int):
key = jr.PRNGKey(key)
1 change: 1 addition & 0 deletions dynamax/nonlinear_gaussian_ssm/inference_ekf.py
@@ -341,5 +341,6 @@ def _step(carry, _):
smoothed_posterior = extended_kalman_smoother(params, emissions, smoothed_prior, inputs)
return smoothed_posterior, None

+# TODO: Does this even work with None as initial carry?
smoothed_posterior, _ = lax.scan(_step, None, jnp.arange(num_iter))
return smoothed_posterior
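
On the TODO: lax.scan requires the carry to keep the same pytree structure from initialization through every step, so an initial carry of None only traces if the step function also returns None as its carry; returning a posterior object against a None init would be expected to fail the structure check. A standalone illustration of the constraint (generic JAX, not dynamax code):

import jax.numpy as jnp
from jax import lax

def step(carry, _):
    return jnp.zeros(3), None  # carry out is a single array

# Structure mismatch: None in, array out -> TypeError at trace time.
# lax.scan(step, None, jnp.arange(2))

# Matching carry structures trace fine.
final_carry, _ = lax.scan(step, jnp.zeros(3), jnp.arange(2))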
