Skip to content

Commit

Permalink
kwargs passing to _debug_inferer and a test
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Sep 26, 2024
1 parent f76dace commit e31cd3f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
15 changes: 4 additions & 11 deletions src/resp_ode/mechanistic_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,26 +336,19 @@ def infer(self, obs_metrics: jax.Array) -> MCMC:
self.infer_complete = True
return self.inference_algo

def _debug_likelihood(self, obs_metrics) -> bx.Model:
def _debug_likelihood(self, **kwargs) -> bx.Model:
"""uses Bayeux to recreate the self.likelihood function for purposes of basic sanity checking
Parameters
----------
obs_metrics: jnp.array
observed metrics on which likelihood will be calculated on to tune parameters.
See `likelihood()` method for implemented definition of `obs_metrics`
passes all parameters given to it to `self.likelihood`, initializes with `self.INITIAL_STATE`
and passes `self.config.INFERENCE_PRNGKEY` as seed for randomness.
Returns
-------
Bayeux.Model
model object used to debug
"""
bx_model = bx.Model.from_numpyro(
jax.tree_util.Partial(
self.likelihood,
tf=len(obs_metrics),
obs_metrics=obs_metrics,
),
jax.tree_util.Partial(self.likelihood, **kwargs),
# this does not work for non-one/sampled self.INITIAL_INFECTIONS_SCALE
initial_state=self.INITIAL_STATE,
)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,10 @@ def test_random_sampling_across_chains_and_particles():
"Unable to run all tests within test_random_sampling_across_chains_and_particles "
"since you have only one chain! check test_config_inferer.json"
)


def test_debug_inferer():
"""A simple test to make sure the _debug_likelihood function does not explode and correctly passes kwargs"""
inferer._debug_likelihood(
tf=len(synthetic_hosp_obs), obs_metrics=synthetic_hosp_obs
)

0 comments on commit e31cd3f

Please sign in to comment.