From e31cd3f857c34efee0297534e501861308a06f73 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Thu, 26 Sep 2024 17:06:16 +0000 Subject: [PATCH] kwargs passing to _debug_inferer and a test --- src/resp_ode/mechanistic_inferer.py | 15 ++++----------- tests/test_inferer.py | 7 +++++++ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/resp_ode/mechanistic_inferer.py b/src/resp_ode/mechanistic_inferer.py index 831edce..f935dc4 100644 --- a/src/resp_ode/mechanistic_inferer.py +++ b/src/resp_ode/mechanistic_inferer.py @@ -336,14 +336,11 @@ def infer(self, obs_metrics: jax.Array) -> MCMC: self.infer_complete = True return self.inference_algo - def _debug_likelihood(self, obs_metrics) -> bx.Model: + def _debug_likelihood(self, **kwargs) -> bx.Model: """uses Bayeux to recreate the self.likelihood function for purposes of basic sanity checking - Parameters - ---------- - obs_metrics: jnp.array - observed metrics on which likelihood will be calculated on to tune parameters. - See `likelihood()` method for implemented definition of `obs_metrics` + passes all parameters given to it to `self.likelihood`, initializes with `self.INITIAL_STATE` + and passes `self.config.INFERENCE_PRNGKEY` as seed for randomness. Returns ------- @@ -351,11 +348,7 @@ def _debug_likelihood(self, obs_metrics) -> bx.Model: model object used to debug """ bx_model = bx.Model.from_numpyro( - jax.tree_util.Partial( - self.likelihood, - tf=len(obs_metrics), - obs_metrics=obs_metrics, - ), + jax.tree_util.Partial(self.likelihood, **kwargs), # this does not work for non-one/sampled self.INITIAL_INFECTIONS_SCALE initial_state=self.INITIAL_STATE, ) diff --git a/tests/test_inferer.py b/tests/test_inferer.py index bbe75ce..03673bd 100644 --- a/tests/test_inferer.py +++ b/tests/test_inferer.py @@ -176,3 +176,10 @@ def test_random_sampling_across_chains_and_particles(): "Unable to run all tests within test_random_sampling_across_chains_and_particles " "since you have only one chain! check test_config_inferer.json" ) + + +def test_debug_inferer(): + """A simple test to make sure the _debug_likelihood function does not explode and correctly passes kwargs""" + inferer._debug_likelihood( + tf=len(synthetic_hosp_obs), obs_metrics=synthetic_hosp_obs + )