diff --git a/returnn/torch/util/gradient_checkpoint.py b/returnn/torch/util/gradient_checkpoint.py index 7c566b4a0..55dc9aa0b 100644 --- a/returnn/torch/util/gradient_checkpoint.py +++ b/returnn/torch/util/gradient_checkpoint.py @@ -26,6 +26,7 @@ import contextlib from weakref import ref, WeakSet import threading +import atexit import torch from torch.utils.weak import WeakTensorKeyDictionary # needs Torch >=2.0.0 @@ -178,6 +179,8 @@ def _maybe_exit_saved_tensors_hooks_scope(self): self.exit_saved_tensors_hooks_scope() def __del__(self): + if _python_exit: + return # Note, be very careful what we do in __del__ because it might be called in a different thread! # Note that the __del__ will likely be called very late, # as the reference to the _Graph is kept alive until we used it for backprop, @@ -220,6 +223,8 @@ def _unpack_hook(x: Union[torch.Tensor, _GraphTensor]) -> torch.Tensor: return x def _tensor_del_hook(self): + if _python_exit: + return # Some of the relevant tensors got deleted. # If we are in the right thread, maybe we can do the cleanup now. self._maybe_exit_saved_tensors_hooks_scope() @@ -601,3 +606,12 @@ def _custom_saved_tensors_hooks_call_callbacks(): assert not _custom_saved_tensors_hooks_tls_ctx.callbacks and not _custom_saved_tensors_hooks_tls_ctx.stack finally: _custom_saved_tensors_hooks_tls_ctx.in_callback = False + + +def _python_exit_handler(): + global _python_exit + _python_exit = True + + +_python_exit = False +atexit.register(_python_exit_handler)