From 08efaedf4556195e918548cb12ff33cbaee33197 Mon Sep 17 00:00:00 2001 From: Philip Bontrager Date: Sat, 9 Nov 2024 11:09:17 -0500 Subject: [PATCH] Refactor Recipe State Dict Code (#1964) --- docs/source/api_ref_modules.rst | 1 + docs/source/api_ref_training.rst | 1 + recipes/full_finetune_distributed.py | 4 +- recipes/knowledge_distillation_distributed.py | 10 +- .../knowledge_distillation_single_device.py | 6 +- recipes/lora_dpo_distributed.py | 29 ++--- recipes/lora_dpo_single_device.py | 3 +- recipes/lora_finetune_distributed.py | 31 +++--- recipes/lora_finetune_single_device.py | 3 +- recipes/qat_distributed.py | 4 +- .../recipes/test_full_finetune_distributed.py | 88 ++++++++++++++- .../test_full_finetune_single_device.py | 2 +- tests/torchtune/modules/peft/test_utils.py | 33 ++++-- tests/torchtune/training/test_distributed.py | 8 +- torchtune/modules/peft/__init__.py | 2 + torchtune/modules/peft/_utils.py | 21 ++++ torchtune/training/__init__.py | 4 +- torchtune/training/_distributed.py | 105 ++++++------------ 18 files changed, 226 insertions(+), 129 deletions(-) diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index f360b4f02c..2153e4fd20 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -75,6 +75,7 @@ PEFT Components peft.AdapterModule peft.get_adapter_params peft.set_trainable_params + peft.get_adapter_state_dict peft.validate_missing_and_unexpected_for_lora peft.validate_state_dict_for_lora peft.disable_adapter diff --git a/docs/source/api_ref_training.rst b/docs/source/api_ref_training.rst index 4e7f186523..c904ddc32b 100644 --- a/docs/source/api_ref_training.rst +++ b/docs/source/api_ref_training.rst @@ -56,6 +56,7 @@ Utilities for enabling and working with distributed training. get_world_size_and_rank get_full_finetune_fsdp_wrap_policy lora_fsdp_wrap_policy + gather_cpu_state_dict .. _ac_label: diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 81b6f71b0b..98d34b5f94 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -645,8 +645,8 @@ def save_checkpoint( # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), self._is_rank_zero, device=self._device, ) diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index b40be9e89e..e5f1047923 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -25,6 +25,7 @@ from torchtune.modules.peft import ( DoRALinear, get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -707,8 +708,8 @@ def save_checkpoint(self, epoch: int) -> None: intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), self._is_rank_zero, device=self._device, ) @@ -728,10 +729,7 @@ def save_checkpoint(self, epoch: int) -> None: # Filter out the adapter keys and weights from the model state dict. 
These will # be saved separately - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) - } + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # merge the adapter weights and base weights to create the model checkpoint diff --git a/recipes/knowledge_distillation_single_device.py b/recipes/knowledge_distillation_single_device.py index 1a2c3f0e4b..4521f42da3 100644 --- a/recipes/knowledge_distillation_single_device.py +++ b/recipes/knowledge_distillation_single_device.py @@ -23,6 +23,7 @@ from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -586,10 +587,7 @@ def save_checkpoint(self, epoch: int) -> None: ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) # Construct the adapter weights - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) - } + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) adapter_config = { "r": self._lora_rank, diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index 1ab88deaf8..ee7ca5e729 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -25,6 +25,7 @@ disable_adapter, DoRALinear, get_adapter_params, + get_adapter_state_dict, get_merged_lora_ckpt, load_dora_magnitudes, LoRALinear, @@ -504,8 +505,12 @@ def save_checkpoint( intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + state_dict = self._model.state_dict() + if self._save_adapter_weights_only: + state_dict = get_adapter_state_dict(state_dict, device=None) + + cpu_state_dict = training.gather_cpu_state_dict( + state_dict, self._is_rank_zero, device=self._device, ) @@ -521,23 +526,21 @@ def save_checkpoint( # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: - - # Filter out the adapter keys and weights from the model state dict. These will - # be saved separately - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) - } - checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) - - # merge the adapter weights and base weights to create the model checkpoint - if not self._save_adapter_weights_only: + if self._save_adapter_weights_only: + adapter_state_dict = cpu_state_dict + else: + # Filter out the adapter keys and weights from the model state dict. 
These will + # be saved separately + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) + + # merge the adapter weights and base weights to create the model checkpoint merged_state_dict = get_merged_lora_ckpt( cpu_state_dict, rank=self._lora_rank, alpha=self._lora_alpha, ) checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # if training is in-progress, checkpoint the optimizer state and recipe state # as well. diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index f34694ccc8..26e6a236bc 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -23,6 +23,7 @@ from torchtune.modules.peft import ( disable_adapter, get_adapter_params, + get_adapter_state_dict, get_merged_lora_ckpt, set_trainable_params, validate_missing_and_unexpected_for_lora, @@ -407,7 +408,7 @@ def save_checkpoint(self, epoch: int) -> None: } ) - adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) if not self._save_adapter_weights_only: # Construct the full state dict with LoRA weights merged into base LLM weights diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 6840c557aa..a900cea103 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -26,6 +26,7 @@ from torchtune.modules.peft import ( DoRALinear, get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -452,8 +453,7 @@ def _setup_model( with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(cfg_model) - self.adapter_params = get_adapter_params(model) - set_trainable_params(model, self.adapter_params) + set_trainable_params(model, get_adapter_params(model)) if self._compile: training.compile_model(model, verbose=self._is_rank_zero) @@ -664,11 +664,14 @@ def save_checkpoint( # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + state_dict = self._model.state_dict() + if self._save_adapter_weights_only: + state_dict = get_adapter_state_dict(state_dict, device=None) + + cpu_state_dict = training.gather_cpu_state_dict( + state_dict, self._is_rank_zero, device=self._device, - trainable_only=self._save_adapter_weights_only, ) if self._is_rank_zero: log.info( @@ -694,22 +697,22 @@ def save_checkpoint( # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: start = time.perf_counter() - # Filter out the adapter keys and weights from the model state dict. These will - # be saved separately - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) - } - checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) - # merge the adapter weights and base weights to create the model checkpoint - if not self._save_adapter_weights_only: + if self._save_adapter_weights_only: + adapter_state_dict = cpu_state_dict + else: + # Filter out the adapter keys and weights from the model state dict. 
These will + # be saved separately + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) + + # merge the adapter weights and base weights to create the model checkpoint merged_state_dict = get_merged_lora_ckpt( cpu_state_dict, rank=self._lora_rank, alpha=self._lora_alpha, ) checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # if training is in-progress, checkpoint the optimizer state and recipe state # as well. diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index daf0ea8cdc..fcdb3e4ea5 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -24,6 +24,7 @@ from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -592,7 +593,7 @@ def save_checkpoint(self, epoch: int) -> None: } ) - adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()} + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) if not self._save_adapter_weights_only: diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index b1040880d0..1aa622ba63 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -673,8 +673,8 @@ def save_checkpoint( # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( - self._model, + cpu_state_dict = training.gather_cpu_state_dict( + self._model.state_dict(), self._is_rank_zero, device=self._device, ) diff --git a/tests/recipes/test_full_finetune_distributed.py b/tests/recipes/test_full_finetune_distributed.py index f1f4256411..9c8d0eacd5 100644 --- a/tests/recipes/test_full_finetune_distributed.py +++ b/tests/recipes/test_full_finetune_distributed.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import runpy - import sys from pathlib import Path @@ -113,3 +113,89 @@ def test_loss( torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 ) + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd", + [ + ("llama3/8B_full", "llama3", "tune", 1, 4, False), + ], + ) + @gpu_test(gpu_count=2) + def test_training_state_on_resume( + self, + micro_batch_size, + gradient_accumulation_steps, + config, + model_type, + ckpt_type, + optim_in_bwd, + tmpdir, + monkeypatch, + ): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. 
+ # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + clip_grad_norm=100 \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type] + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \ + --config {config} \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{tmpdir}' \ + checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\ + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + clip_grad_norm=100 \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config + + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values = self._fetch_expected_loss_values(model_type)[2:] + + loss_values = get_loss_values_from_metric_logger(log_file) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + ) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 819c70fdf0..85df960b22 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -181,7 +181,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): checkpointer._component_=torchtune.training.FullModelHFCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\ - checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\ checkpointer.output_dir={tmpdir} \ checkpointer.model_type=LLAMA2 \ tokenizer.path=/tmp/test-artifacts/tokenizer.model \ diff --git a/tests/torchtune/modules/peft/test_utils.py b/tests/torchtune/modules/peft/test_utils.py index 032cd88ec4..de90150195 100644 --- a/tests/torchtune/modules/peft/test_utils.py +++ b/tests/torchtune/modules/peft/test_utils.py @@ -16,6 +16,7 @@ disable_adapter, DoRALinear, get_adapter_params, + get_adapter_state_dict, get_merged_lora_ckpt, LoRALinear, set_trainable_params, @@ -38,30 +39,30 @@ class DummyAdapterModule(nn.Module, AdapterModule): def __init__(self, in_dim, out_dim): super().__init__() - self.adapter = nn.Linear(in_dim, out_dim, bias=False) + self.lora = nn.Linear(in_dim, out_dim, bias=False) self.linear = nn.Linear(in_dim, out_dim) def adapter_params(self): - 
return ["adapter.weight"] + return ["lora.weight"] def forward(self, x): - return self.adapter(x) + self.non_adapter(x) + return self.lora(x) + self.non_adapter(x) class DummyAdapterParentModel(nn.Module, AdapterModule): def __init__(self, in_dim, out_dim): super().__init__() self.dummy_adapter_module = DummyAdapterModule(in_dim, out_dim) - self.parent_adapter = nn.Linear(in_dim, out_dim) + self.parent_lora = nn.Linear(in_dim, out_dim) self.parent_base_model = nn.Linear(in_dim, out_dim) def adapter_params(self): - return ["parent_adapter.weight", "parent_adapter.bias"] + return ["parent_lora.weight", "parent_lora.bias"] def forward(self, x): return ( self.dummy_adapter_module(x) - + self.parent_adapter(x) + + self.parent_lora(x) + self.parent_base_model(x) ) @@ -79,9 +80,9 @@ def dummy_model_expected_adapter_keys(): for i in range(N_LAYERS): keys.extend( [ - f"{i}.parent_adapter.weight", - f"{i}.parent_adapter.bias", - f"{i}.dummy_adapter_module.adapter.weight", + f"{i}.parent_lora.weight", + f"{i}.parent_lora.bias", + f"{i}.dummy_adapter_module.lora.weight", ] ) return keys @@ -204,6 +205,20 @@ def test_get_adapter_params(self, request, model_name, expected_keys): expected = request.getfixturevalue(expected_keys) assert set(expected) == set(adapter_params.keys()) + @pytest.mark.parametrize( + "model_name, expected_keys", + [ + ("dummy_adapter_parent_model", "dummy_model_expected_adapter_keys"), + ("lora_llama2_model", "lora_llama2_expected_adapter_keys"), + ("dora_llama2_model", "dora_llama2_expected_adapter_keys"), + ], + ) + def test_get_adapter_state_dict(self, request, model_name, expected_keys): + model = request.getfixturevalue(model_name) + adapter_state_dict = get_adapter_state_dict(model.state_dict()) + expected = request.getfixturevalue(expected_keys) + assert set(expected) == set(adapter_state_dict.keys()) + @pytest.mark.parametrize( "model_name, expected_trainable_keys, expected_frozen_keys", [ diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py index 1f4b92b4de..830a2ab4a8 100644 --- a/tests/torchtune/training/test_distributed.py +++ b/tests/torchtune/training/test_distributed.py @@ -312,8 +312,8 @@ def test_lora_state_dict(self): fsdp_optim_to_save.zero_grad() expected_model_sd = base_model.state_dict() expected_optim_sd = base_optim.state_dict() - model_full_sd = training.get_full_model_state_dict( - fsdp_model_to_save, is_rank_zero + model_full_sd = training.gather_cpu_state_dict( + fsdp_model_to_save.state_dict(), is_rank_zero ) optim_full_sd = training.get_full_optimizer_state_dict( fsdp_optim_to_save, @@ -467,8 +467,8 @@ def _test_qlora_state_dict(self, enable_activation_checkpointing: bool): fsdp_model_to_save(inp) expected_model_sd = {k: v.cpu() for k, v in base_model.state_dict().items()} - model_full_sd = training.get_full_model_state_dict( - fsdp_model_to_save, is_rank_zero + model_full_sd = training.gather_cpu_state_dict( + fsdp_model_to_save.state_dict(), is_rank_zero ) if is_rank_zero: self.assertEqual(set(model_full_sd.keys()), set(expected_model_sd.keys())) diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py index 44922aa83d..4d678ea6ab 100644 --- a/torchtune/modules/peft/__init__.py +++ b/torchtune/modules/peft/__init__.py @@ -8,6 +8,7 @@ AdapterModule, disable_adapter, get_adapter_params, + get_adapter_state_dict, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -30,6 +31,7 @@ "validate_state_dict_for_lora", "load_dora_magnitudes", "disable_adapter", + 
"get_adapter_state_dict", "get_merged_lora_ckpt", "get_lora_module_names", ] diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 318ab4136a..e0e29bb716 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -109,6 +109,27 @@ def get_lora_module_names( return lora_module_keys +def get_adapter_state_dict( + state_dict: Dict[str, Any], device: Optional[str] = "cpu" +) -> Dict[str, Any]: + """ + Return the subset of the full state_dict from a model that correspond to an adapter. + Assumes that "lora" and "magnitude" are unique names for adapter parameters, and + that the state_dict is not sharded. All returned parameters are moved to CPU. + + Args: + state_dict (Dict[str, Any]): Full model state dict. + device (Optional[str]): device to move adapter parameters to. Default: 'cpu' + + Returns: + Dict[str, Any]: the subset of model's state dict containing + only adapter parameters. + + """ + adapter_key_filter = lambda x: "lora" in x or "magnitude" in x + return {k: v.to(device) for k, v in state_dict.items() if adapter_key_filter(k)} + + def validate_state_dict_for_lora( lora_attn_modules: List[LORA_ATTN_MODULES], apply_lora_to_mlp: bool, diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index a1e1cdbd73..06ec1b5312 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -12,8 +12,8 @@ from torchtune.training._distributed import ( contains_fsdp, FSDPPolicyType, + gather_cpu_state_dict, get_full_finetune_fsdp_wrap_policy, - get_full_model_state_dict, get_full_optimizer_state_dict, get_shard_conditions, get_world_size_and_rank, @@ -120,7 +120,7 @@ "FSDPPolicyType", "get_full_finetune_fsdp_wrap_policy", "lora_fsdp_wrap_policy", - "get_full_model_state_dict", + "gather_cpu_state_dict", "get_full_optimizer_state_dict", "load_from_full_model_state_dict", "load_from_full_optimizer_state_dict", diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py index 0184fd7af6..3511662442 100644 --- a/torchtune/training/_distributed.py +++ b/torchtune/training/_distributed.py @@ -17,9 +17,6 @@ from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard from torch.distributed._tensor import distribute_tensor, DTensor from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - _CHECKPOINT_WRAPPED_MODULE, -) from torch.distributed.checkpoint.state_dict import _init_optim_state from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.wrap import ModuleWrapPolicy @@ -348,88 +345,58 @@ def load_from_full_model_state_dict( return model.load_state_dict(sharded_sd, strict=strict, assign=True) -def get_full_model_state_dict( - model: "FSDPModule", # noqa +def gather_cpu_state_dict( + sharded_sd: Dict[str, DTensor], # noqa is_rank_zero: bool, device: Optional[torch.device] = None, - trainable_only: bool = False, ) -> Dict[str, Any]: """ Converting sharded state dict into a full state dict on CPU - Returning non-empty result on rank0 to avoid peaking CPU memory + Returning non-empty result only on rank0 to avoid peaking CPU memory Args: - model (FSDPModule): wrapped module + sharded_sd (Dict[str, DTensor]): Sharded state dict of DTensors is_rank_zero (bool): flag to check if the process is on rank 0 device (Optional[torch.device]): device to use for sharded tensors. 
Default: None - trainable_only (bool): flag to check if only trainable parameters should be returned. Default: False - - Raises: - AssertionError: if the model contains NF4Tensor and the model is not wrapped with FSDP Returns: Dict[str, Any]: State dict on CPU """ - # [Warning] FSDPModel.state_dict converts all Parameter Tensors to DTensors - sharded_sd = model.state_dict() cpu_state_dict = {} - has_nf4 = any( - isinstance(param._local_tensor, NF4Tensor) for param in model.parameters() - ) - if has_nf4: - from torch.distributed._composable.fsdp.fully_shard import FSDPModule - - # Iterating from lowerer modules to higher - # Unsharding lora adapters before unsharding transformer block - for module_name, module in reversed(list(model.named_modules())): - if not isinstance(module, FSDPModule): - continue - module.unshard(async_op=False) - if is_rank_zero: - module_name = module_name.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "") - for local_fqn, param in module.named_parameters(): - local_fqn = local_fqn.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "") - if len(module_name) > 0: - full_fqn = module_name + "." + local_fqn - else: - full_fqn = local_fqn - if trainable_only and not param.requires_grad: - # skip trainable params when trainable_only is True - continue - if full_fqn in cpu_state_dict: - # Iterate over every param in every module bottoms-up - # When lower TransformerBlock gets unsharded, - # we insert (full_fqn, full_tensor) into cpu_state_dict. - # When higher Transformer gets unsharded, we avoid updating - # params from lower TransformerBlockonly again. Instead, only updating - # tok_embeddings etc that belongs to Transformer - continue - if isinstance(param, NF4Tensor): - # upcasting NF4 to original dtype - param = param.to(param.dtype) - if isinstance(param, DTensor): - raise AssertionError( - f"Internal error: expect unsharded {full_fqn} in plain torch.Tensor but got DTensor." - " Might be a bug in get_full_model_state_dict" - ) - cpu_state_dict[full_fqn] = param.cpu() - module.reshard() - else: - for param_name, sharded_param in sharded_sd.items(): - # without this, it may hang forever for +70B models. - torch.distributed.barrier() - if sharded_param.is_cpu: - assert device is not None and device.type == "cuda", ( - f"Expect cuda but got device={device}. " - "Please call get_full_model_state_dict(..., device=self._device)," - " so DTensor can communicate over NCCL." 
+ for param_name, sharded_param in sharded_sd.items(): + if sharded_param.is_cpu: + # Move back to device if offloaded to CPU + sharded_param = sharded_param.to(device) + if isinstance(sharded_param._local_tensor, NF4Tensor): + # NF4Tensor does not support all_gather from DTensor + # so we need to manually all_gather + mesh = sharded_param.device_mesh + nf4_tensor = sharded_param._local_tensor + quant_params, metadata = nf4_tensor.fsdp_pre_all_gather(mesh) + full_quant_params = [] + for quant_param in quant_params: + d0, *dn = quant_param.shape + shape = (d0 * mesh.get_group().size(), *dn) + full_quant_param = torch.empty( + shape, device=quant_param.device, dtype=quant_param.dtype + ) + dist.all_gather_into_tensor( + full_quant_param, quant_param, mesh.get_group(), async_op=False ) - sharded_param = sharded_param.to(device) + full_quant_params.append(full_quant_param) + full_param, _ = nf4_tensor.fsdp_post_all_gather( + full_quant_params, metadata, nf4_tensor.dtype + ) + # upcasting NF4 to original dtype + full_param = full_param.to(full_param.dtype) + else: + # Gather DTensor full_param = sharded_param.full_tensor() - if is_rank_zero: - cpu_state_dict[param_name] = full_param.cpu() - else: - del full_param + if is_rank_zero: + cpu_state_dict[param_name] = full_param.cpu() + else: + del full_param + torch.distributed.barrier() return cpu_state_dict
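
For reference, the sketch below shows how the two helpers introduced by this patch compose in a distributed LoRA recipe's checkpoint save. It is an illustrative sketch, not code from the patch: the save_lora_checkpoint wrapper and its model/is_rank_zero/device/lora_rank/lora_alpha arguments are hypothetical placeholders, while gather_cpu_state_dict, get_adapter_state_dict, get_merged_lora_ckpt, MODEL_KEY, and ADAPTER_KEY are the torchtune APIs exercised in the diff above.

from torchtune import training
from torchtune.modules.peft import get_adapter_state_dict, get_merged_lora_ckpt


def save_lora_checkpoint(model, is_rank_zero, device, lora_rank, lora_alpha):
    # Hypothetical wrapper for illustration only; argument names are placeholders.
    # Gather the sharded (DTensor) state dict onto CPU; only rank 0 receives a
    # non-empty result, which keeps peak CPU memory bounded during save.
    cpu_state_dict = training.gather_cpu_state_dict(
        model.state_dict(),
        is_rank_zero,
        device=device,
    )
    if not is_rank_zero:
        return {}
    # Adapter weights are selected by key ("lora" / "magnitude") rather than by
    # tracking self.adapter_params inside the recipe.
    adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
    # Merge adapter and base weights to produce a standalone model checkpoint.
    merged_state_dict = get_merged_lora_ckpt(
        cpu_state_dict, rank=lora_rank, alpha=lora_alpha
    )
    return {
        training.MODEL_KEY: merged_state_dict,
        training.ADAPTER_KEY: adapter_state_dict,
    }

Gathering the full state dict on CPU for rank 0 only, then filtering and merging from that single gathered copy, is the pattern this refactor standardizes across the distributed recipes.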