RF RunningMean, update only in train
albertz committed Jan 3, 2025
1 parent 0dd1b3f commit fb9087a
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions returnn/frontend/reduce.py
@@ -203,13 +203,16 @@ def __init__(
alpha: float,
dtype: Optional[str] = None,
is_prob_distribution: Optional[bool] = None,
update_only_in_train: bool = True,
):
"""
:param in_dim: the dim of the mean vector, or the shape.
:param alpha: factor for new_value. 0.0 means no update, 1.0 means always the new value.
Also called momentum. E.g. 0.1 is a common value, or less, like 0.001.
:param dtype: the dtype of the mean vector
:param is_prob_distribution: if True, will initialize the mean vector with 1/in_dim.
:param update_only_in_train: if True (default), will only update the mean vector in training mode.
False means it will always update.
"""
super().__init__()
self.in_dim = in_dim
@@ -221,15 +224,20 @@ def __init__(
if is_prob_distribution:
assert in_dim.dimension is not None
self.mean.initial = 1.0 / in_dim.dimension
self.update_only_in_train = update_only_in_train

def __call__(self, x: Tensor) -> Tensor:
"""
:param x: shape [..., F]
:return: shape [F]
"""
assert all(d in self.shape for d in x.dims)
x_ = rf.reduce_mean(x, axis=[d for d in x.dims if d not in self.shape])
self.mean.assign_add(self.alpha * (x_ - self.mean))

def _update_running_stats():
assert all(d in self.shape for d in x.dims)
x_ = rf.reduce_mean(x, axis=[d for d in x.dims if d not in self.shape])
self.mean.assign_add(self.alpha * (x_ - self.mean))

rf.cond((not self.update_only_in_train) or rf.get_run_ctx().train_flag, _update_running_stats, lambda: None)
return self.mean


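A minimal usage sketch of the changed module (not part of the commit), assuming RunningMean is re-exported under the rf namespace from returnn/frontend/reduce.py; the Normalizer wrapper, feat_dim, and the input handling are purely illustrative:

    import returnn.frontend as rf
    from returnn.tensor import Tensor, Dim

    feat_dim = Dim(128, name="feature")  # hypothetical feature dim


    class Normalizer(rf.Module):
        """Subtracts a running mean from the input; the mean is updated only while training."""

        def __init__(self):
            super().__init__()
            # alpha=0.1 as suggested in the docstring; update_only_in_train=True is the new default,
            # so the running mean stays frozen during evaluation/recognition.
            self.running_mean = rf.RunningMean(in_dim=feat_dim, alpha=0.1)

        def __call__(self, x: Tensor) -> Tensor:
            mean = self.running_mean(x)  # shape [F]; the assign_add runs inside rf.cond on the train flag
            return x - mean

With update_only_in_train left at its default of True, the assign_add on the mean runs only when the run context's train flag is set; passing update_only_in_train=False restores the previous always-update behavior.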
