diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 2e9498e99..57a83151c 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -186,6 +186,11 @@ def potential_fn(z_gibbs, z_hmc): return HMCGibbsState(z, hmc_state, rng_key) + def __getstate__(self): + state = self.__dict__.copy() + state["_prototype_trace"] = None + return state + def _discrete_gibbs_proposal_body_fn( z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index 863418602..726d18c4c 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -307,4 +307,6 @@ def body_fn(i, vals): def __getstate__(self): state = self.__dict__.copy() state["_wa_update"] = None + state["_prototype_trace"] = None + state["_support_sizes_flat"] = None return state