Skip to content

Commit

Permalink
Fixes #3156 (#3157)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Nov 22, 2022
1 parent 44267ff commit 39a4111
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions pyro/infer/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ class HMC(MCMCKernel):
step size, hence the sampling will be slower and more robust. Default to 0.8.
:param callable init_strategy: A per-site initialization function.
See :ref:`autoguide-initialization` section for available functions.
:param min_stepsize (float): Lower bound on stepsize in adaptation strategy.
:param max_stepsize (float): Upper bound on stepsize in adaptation strategy.
.. note:: Internally, the mass matrix will be ordered according to the order
of the names of latent variables, not the order of their appearance in
Expand Down Expand Up @@ -108,6 +110,9 @@ def __init__(
ignore_jit_warnings=False,
target_accept_prob=0.8,
init_strategy=init_to_uniform,
*,
min_stepsize: float = 1e-10,
max_stepsize: float = 1e10,
):
if not ((model is None) ^ (potential_fn is None)):
raise ValueError("Only one of `model` or `potential_fn` must be specified.")
Expand All @@ -119,6 +124,8 @@ def __init__(
self._jit_options = jit_options
self._ignore_jit_warnings = ignore_jit_warnings
self._init_strategy = init_strategy
self._min_stepsize = min_stepsize
self._max_stepsize = max_stepsize

self.potential_fn = potential_fn
if trajectory_length is not None:
Expand Down Expand Up @@ -188,9 +195,11 @@ def _find_reasonable_step_size(self, z):
step_size_scale = 2**direction
direction_new = direction
# keep scale step_size until accept_prob crosses its target
# TODO: make thresholds for too small step_size or too large step_size
t = 0
while direction_new == direction:
while (
direction_new == direction
and self._min_stepsize < step_size < self._max_stepsize
):
t += 1
step_size = step_size_scale * step_size
r, r_unscaled = self._sample_r(name="r_presample_{}".format(t))
Expand All @@ -206,6 +215,8 @@ def _find_reasonable_step_size(self, z):
energy_new = self._kinetic_energy(r_new_unscaled) + potential_energy_new
delta_energy = energy_new - energy_current
direction_new = 1 if self._direction_threshold < -delta_energy else -1
step_size = max(step_size, self._min_stepsize)
step_size = min(step_size, self._max_stepsize)
return step_size

def _sample_r(self, name):
Expand Down

0 comments on commit 39a4111

Please sign in to comment.