diff --git a/dynamax/hidden_markov_model/models/test_models.py b/dynamax/hidden_markov_model/models/test_models.py index f5810768..a653b09f 100644 --- a/dynamax/hidden_markov_model/models/test_models.py +++ b/dynamax/hidden_markov_model/models/test_models.py @@ -1,5 +1,4 @@ import pytest -from datetime import datetime import jax.numpy as jnp import jax.random as jr from jax import vmap @@ -21,7 +20,7 @@ (models.LowRankGaussianHMM, dict(num_states=4, emission_dim=3, emission_rank=1), None), (models.GaussianMixtureHMM, dict(num_states=4, num_components=2, emission_dim=3, emission_prior_mean_concentration=1.0), None), (models.DiagonalGaussianMixtureHMM, dict(num_states=4, num_components=2, emission_dim=3, emission_prior_mean_concentration=1.0), None), - (models.LinearRegressionHMM, dict(num_states=4, emission_dim=3, input_dim=5), jnp.ones((NUM_TIMESTEPS, 5))), + (models.LinearRegressionHMM, dict(num_states=3, emission_dim=3, input_dim=5), jr.normal(jr.PRNGKey(0),(NUM_TIMESTEPS, 5))), (models.LogisticRegressionHMM, dict(num_states=4, input_dim=5), jnp.ones((NUM_TIMESTEPS, 5))), (models.MultinomialHMM, dict(num_states=4, emission_dim=3, num_classes=5, num_trials=10), None), (models.PoissonHMM, dict(num_states=4, emission_dim=3), None), @@ -31,7 +30,6 @@ @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS) def test_sample_and_fit(cls, kwargs, inputs): hmm = cls(**kwargs) - #key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp()))) key1, key2 = jr.split(jr.PRNGKey(42)) params, param_props = hmm.initialize(key1) states, emissions = hmm.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs)