From fee68a83e334c464fc8d6c55dda31c16e2d2f8b4 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Wed, 8 Jan 2025 10:49:03 +0100
Subject: [PATCH] RF set_parameter_copy_behavior_ctx, allows to share params

---
Review notes (placed after the "---" separator, so not part of the commit
message):

* This copy of the patch had its line breaks and indentation collapsed.
  The layout below was restored from the hunk-header line counts and from
  Python syntax; the position of blank context lines and the docstring
  indentation are best-effort. TODO: verify against the repository before
  applying.
* The "From:" header carries no e-mail address in this copy.
* The patch adds a module-level `_copy_behavior` (initially "copy_init") and
  the context manager `set_parameter_copy_behavior_ctx(mode)`, which sets it
  for the duration of the `with` body and restores the previous value in a
  `finally` clause.
* With mode "share", `Parameter.__copy__` and `Parameter.__deepcopy__`
  return `self`; otherwise they assert the mode is "copy_init" and continue
  with the pre-existing copy logic.
* NOTE(review): `mode` is validated with `assert`, which Python skips under
  `-O`; consider an explicit `ValueError` in a follow-up.
* NOTE(review): `_copy_behavior` is a single module global; presumably this
  is not meant to be switched concurrently from several threads - confirm.

 returnn/frontend/parameter.py | 34 +++++++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/returnn/frontend/parameter.py b/returnn/frontend/parameter.py
index 1b480a60c..136c00bee 100644
--- a/returnn/frontend/parameter.py
+++ b/returnn/frontend/parameter.py
@@ -4,12 +4,13 @@
 
 from __future__ import annotations
 from typing import Optional, Union, TypeVar, Sequence
+from contextlib import contextmanager
 from returnn.tensor import Tensor, Dim
 import returnn.frontend as rf
 from ._backend import global_backend as _global_backend
 
 
-__all__ = ["Parameter"]
+__all__ = ["Parameter", "set_parameter_copy_behavior_ctx"]
 
 
 T = TypeVar("T")
@@ -97,8 +98,11 @@ def __init__(
         self.initial = initial  # use setter
 
     def __copy__(self):
+        if _copy_behavior == "share":
+            return self
         # Should return new copy. https://github.com/rwth-i6/returnn_common/pull/215#issuecomment-1269651064
         # Note that the values are *not* copied, but rather it will use the same param init scheme.
+        assert _copy_behavior == "copy_init"
         res = type(self)(
             dims=self.dims,
             dtype=self.dtype,
@@ -111,10 +115,14 @@ def __copy__(self):
         return res
 
     def __deepcopy__(self, memo=None):
+        if _copy_behavior == "share":
+            return self
+
         # Should return new copy. https://github.com/rwth-i6/returnn_common/pull/215#issuecomment-1269651064
         # Note that the values are *not* copied, but rather it will use the same param init scheme.
         from copy import deepcopy
 
+        assert _copy_behavior == "copy_init"
         res = self.__copy__()
         if isinstance(self.initial, rf.init.ParamInit):
             res.initial = deepcopy(self.initial, memo=memo)  # noqa
@@ -241,3 +249,27 @@ def non_critical_for_restore(self) -> bool:
     @non_critical_for_restore.setter
     def non_critical_for_restore(self, value: bool):
         self._non_critical_for_restore = value
+
+
+_copy_behavior = "copy_init"
+
+
+@contextmanager
+def set_parameter_copy_behavior_ctx(mode: str):
+    """
+    Set the copy behavior for parameters.
+
+    :param mode: "copy_init" or "share".
+        "copy_init" (default when not set): When copying a parameter, a new parameter will be created.
+            It is not a shared parameter.
+            It does not copy the values, but it just copies the param init scheme.
+        "share": The parameter will not be copied, it will just return itself.
+    """
+    global _copy_behavior
+    assert mode in ("copy_init", "share")
+    old = _copy_behavior
+    try:
+        _copy_behavior = mode
+        yield
+    finally:
+        _copy_behavior = old