From bc4acc19ffab2366a14468c97294992dbb7c50d1 Mon Sep 17 00:00:00 2001
From: ebsmothers
Date: Fri, 1 Nov 2024 15:56:52 -0700
Subject: [PATCH] Fix grad accum + FSDP CPU offload, pass None via CLI (#1941)

---
 torchtune/config/_utils.py         | 5 +++++
 torchtune/training/_grad_scaler.py | 5 +++++
 2 files changed, 10 insertions(+)

diff --git a/torchtune/config/_utils.py b/torchtune/config/_utils.py
index a5d1291802..93c19571c5 100644
--- a/torchtune/config/_utils.py
+++ b/torchtune/config/_utils.py
@@ -173,6 +173,11 @@ def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictC
         # key string to reflect this
         if k in yaml_kwargs and _has_component(yaml_kwargs[k]):
             k += "._component_"
+
+        # None passed via CLI will be parsed as string, but we really want OmegaConf null
+        if v == "None":
+            v = "!!null"
+
         # TODO: this is a hack but otherwise we can't pass strings with leading zeroes
         # to define the checkpoint file format. We manually override OmegaConf behavior
         # by prepending the value with !!str to force a string type
diff --git a/torchtune/training/_grad_scaler.py b/torchtune/training/_grad_scaler.py
index aab938bc90..484cd8f372 100644
--- a/torchtune/training/_grad_scaler.py
+++ b/torchtune/training/_grad_scaler.py
@@ -21,6 +21,11 @@ def scale_grads(model: nn.Module, scaler: torch.Tensor) -> None:
     Outputs:
         None (grad fields are modified in place)
     """
+    device = None
     for p in model.parameters():
+        # First ensure scaler is on the same device as the model
+        if not device:
+            device = p.device
+            scaler = scaler.to(device)
         if p.grad is not None:
             p.grad *= scaler
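
Notes (not part of the patch): two short sketches of why each change matters.

The "None" -> "!!null" rewrite: CLI override values arrive as strings and are YAML-parsed when OmegaConf merges the dotlist; YAML reads the literal "None" as a plain string (only "null" or "~" map to null), so the override never becomes a real None. Rewriting the value to the "!!null" tag makes OmegaConf resolve it to null. A minimal sketch, using the made-up override key "optimizer.fused" purely for illustration:

    from omegaconf import OmegaConf

    # Without the rewrite: the value stays the (truthy) string "None"
    cfg = OmegaConf.from_dotlist(["optimizer.fused=None"])
    assert cfg.optimizer.fused == "None"

    # With the rewrite applied by _merge_yaml_and_cli_args: a real null
    cfg = OmegaConf.from_dotlist(["optimizer.fused=!!null"])
    assert cfg.optimizer.fused is None

The device move in scale_grads: with FSDP CPU offload, parameter gradients live on CPU while the scaling factor used for gradient-accumulation normalization (e.g. a token count) is typically a CUDA tensor, so the in-place "p.grad *= scaler" would hit a device mismatch. A rough usage sketch, assuming the function is re-exported as torchtune.training.scale_grads and that a CUDA device is available:

    import torch
    import torch.nn as nn
    from torchtune.training import scale_grads  # import path assumed

    # Gradients sit on CPU here, mimicking FSDP CPU offload
    model = nn.Linear(4, 4)
    model(torch.randn(2, 4)).sum().backward()

    # The normalization factor lives on GPU
    num_tokens = torch.tensor(128.0, device="cuda")

    # The patched scale_grads moves the scaler to the parameters' device first
    scale_grads(model, 1.0 / num_tokens)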