Skip to content

Commit

Permalink
RF set_parameter_copy_behavior_ctx, allows to share params
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Jan 8, 2025
1 parent 7b8fe93 commit fee68a8
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion returnn/frontend/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from __future__ import annotations
from typing import Optional, Union, TypeVar, Sequence
from contextlib import contextmanager
from returnn.tensor import Tensor, Dim
import returnn.frontend as rf
from ._backend import global_backend as _global_backend


__all__ = ["Parameter"]
__all__ = ["Parameter", "set_parameter_copy_behavior_ctx"]


T = TypeVar("T")
Expand Down Expand Up @@ -97,8 +98,11 @@ def __init__(
self.initial = initial # use setter

def __copy__(self):
if _copy_behavior == "share":
return self
# Should return new copy. https://github.com/rwth-i6/returnn_common/pull/215#issuecomment-1269651064
# Note that the values are *not* copied, but rather it will use the same param init scheme.
assert _copy_behavior == "copy_init"
res = type(self)(
dims=self.dims,
dtype=self.dtype,
Expand All @@ -111,10 +115,14 @@ def __copy__(self):
return res

def __deepcopy__(self, memo=None):
if _copy_behavior == "share":
return self

# Should return new copy. https://github.com/rwth-i6/returnn_common/pull/215#issuecomment-1269651064
# Note that the values are *not* copied, but rather it will use the same param init scheme.
from copy import deepcopy

assert _copy_behavior == "copy_init"
res = self.__copy__()
if isinstance(self.initial, rf.init.ParamInit):
res.initial = deepcopy(self.initial, memo=memo) # noqa
Expand Down Expand Up @@ -241,3 +249,27 @@ def non_critical_for_restore(self) -> bool:
@non_critical_for_restore.setter
def non_critical_for_restore(self, value: bool):
self._non_critical_for_restore = value


_copy_behavior = "copy_init"


@contextmanager
def set_parameter_copy_behavior_ctx(mode: str):
"""
Set the copy behavior for parameters.
:param mode: "copy_init" or "share".
"copy_init" (default when not set): When copying a parameter, a new parameter will be created.
It is not a shared parameter.
It does not copy the values, but it just copies the param init scheme.
"share": The parameter will not be copied, it will just return itself.
"""
global _copy_behavior
assert mode in ("copy_init", "share")
old = _copy_behavior
try:
_copy_behavior = mode
yield
finally:
_copy_behavior = old

0 comments on commit fee68a8

Please sign in to comment.