From fb9087a3cf4dd143b47788c306bd528414d64039 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 3 Jan 2025 03:20:04 +0100 Subject: [PATCH] RF RunningMean, update only in train --- returnn/frontend/reduce.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/returnn/frontend/reduce.py b/returnn/frontend/reduce.py index 035c30d5c..55ad8a7aa 100644 --- a/returnn/frontend/reduce.py +++ b/returnn/frontend/reduce.py @@ -203,6 +203,7 @@ def __init__( alpha: float, dtype: Optional[str] = None, is_prob_distribution: Optional[bool] = None, + update_only_in_train: bool = True, ): """ :param in_dim: the dim of the mean vector, or the shape. @@ -210,6 +211,8 @@ def __init__( Also called momentum. E.g. 0.1 is a common value, or less, like 0.001. :param dtype: the dtype of the mean vector :param is_prob_distribution: if True, will initialize the mean vector with 1/in_dim. + :param update_only_in_train: if True (default), will only update the mean vector in training mode. + False means it will always update. """ super().__init__() self.in_dim = in_dim @@ -221,15 +224,20 @@ def __init__( if is_prob_distribution: assert in_dim.dimension is not None self.mean.initial = 1.0 / in_dim.dimension + self.update_only_in_train = update_only_in_train def __call__(self, x: Tensor) -> Tensor: """ :param x: shape [..., F] :return: shape [F] """ - assert all(d in self.shape for d in x.dims) - x_ = rf.reduce_mean(x, axis=[d for d in x.dims if d not in self.shape]) - self.mean.assign_add(self.alpha * (x_ - self.mean)) + + def _update_running_stats(): + assert all(d in self.shape for d in x.dims) + x_ = rf.reduce_mean(x, axis=[d for d in x.dims if d not in self.shape]) + self.mean.assign_add(self.alpha * (x_ - self.mean)) + + rf.cond((not self.update_only_in_train) or rf.get_run_ctx().train_flag, _update_running_stats, lambda: None) return self.mean