Skip to content

Commit

Permalink
typo, scope_iter and scope fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Emile Mathieu committed Aug 11, 2017
1 parent 1cdf507 commit 8a031e6
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions edward/inferences/hmcda.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def initialize(self, n_adapt, delta=0.65, Lambda=0.15, *args, **kwargs):
Lambda : float, optional
Target leapfrog length
"""
self.scope_iter = 0 # a convenient counter for log joint calculations
# store global scope for log joint calculations
self._scope = tf.get_default_graph().unique_name("inference") + '/'

# Find initial epsilon
step_size = self.find_good_eps()
Expand Down Expand Up @@ -137,7 +138,7 @@ def build_update(self):
should_adapt = self.t <= self.n_adapt
assign_ops = tf.cond(should_adapt,
lambda: self._adapt_step_size(alpha),
lambda: self._do_not__adapt_step_size(alpha))
lambda: self._do_not_adapt_step_size(alpha))

# Update Empirical random variables.
for z, qz in six.iteritems(self.latent_vars):
Expand All @@ -148,7 +149,7 @@ def build_update(self):
assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
return tf.group(*assign_ops)

def _do_not__adapt_step_size(self, alpha):
def _do_not_adapt_step_size(self, alpha):
# Do not adapt step size but assign last running averaged epsilon to epsilon
assign_ops = []
assign_ops.append(tf.assign(self.H_B, self.H_B).op)
Expand Down Expand Up @@ -188,7 +189,7 @@ def find_good_eps(self):
old_r[z] = normal.sample()

# Initialize espilon at 1.0
epsilon = tf.Variable(1.0, trainable=False)
epsilon = tf.constant(1.0)

# Calculate log joint probability
old_z = {z: tf.gather(qz.params, 0)
Expand Down Expand Up @@ -257,8 +258,7 @@ def _log_joint(self, z_sample):
z_sample : dict
Latent variable keys to samples.
"""
self.scope_iter += 1
scope = 'inference_' + str(id(self)) + '/' + str(self.scope_iter)
scope = self._scope + tf.get_default_graph().unique_name("sample")
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
dict_swap = z_sample.copy()
Expand Down

0 comments on commit 8a031e6

Please sign in to comment.