diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst
index 2153e4fd20..36ea3637c8 100644
--- a/docs/source/api_ref_modules.rst
+++ b/docs/source/api_ref_modules.rst
@@ -77,7 +77,6 @@ PEFT Components
     peft.set_trainable_params
     peft.get_adapter_state_dict
     peft.validate_missing_and_unexpected_for_lora
-    peft.validate_state_dict_for_lora
     peft.disable_adapter
diff --git a/docs/source/api_ref_training.rst b/docs/source/api_ref_training.rst
index c904ddc32b..0f0a392efe 100644
--- a/docs/source/api_ref_training.rst
+++ b/docs/source/api_ref_training.rst
@@ -50,12 +50,9 @@ Utilities for enabling and working with distributed training.
     :toctree: generated/
     :nosignatures:

-    FSDPPolicyType
     init_distributed
     is_distributed
     get_world_size_and_rank
-    get_full_finetune_fsdp_wrap_policy
-    lora_fsdp_wrap_policy
     gather_cpu_state_dict

 .. _ac_label:
diff --git a/docs/source/tutorials/lora_finetune.rst b/docs/source/tutorials/lora_finetune.rst
index 8449077b32..ad842f7ab5 100644
--- a/docs/source/tutorials/lora_finetune.rst
+++ b/docs/source/tutorials/lora_finetune.rst
@@ -205,8 +205,7 @@ model without any wrappers or custom checkpoint conversion logic.
 .. note::
     Whenever loading weights with :code:`strict=False`, you should verify that any missing or extra keys in
-    the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via e.g.
-    :func:`validate_state_dict_for_lora() ` or
+    the loaded :code:`state_dict` are as expected. torchtune's LoRA recipes do this by default via
     :func:`validate_missing_and_unexpected_for_lora() `.

 Once we've loaded the base model weights, we also want to set only LoRA parameters to trainable.
diff --git a/docs/source/tutorials/qat_finetune.rst b/docs/source/tutorials/qat_finetune.rst
index 6f83bb6c02..59359bfb88 100644
--- a/docs/source/tutorials/qat_finetune.rst
+++ b/docs/source/tutorials/qat_finetune.rst
@@ -168,11 +168,6 @@ modifications accordingly:
   fake_quant_after_n_steps: 1000
   memory_efficient_fsdp_wrap: False

-.. note::
-
-   QAT in torchtune is currently not compatible with `memory_efficient_fsdp_wrap `_.
-   This is a known issue and will be fixed in a future torchtune version.
-
 Empirically, we observed that disabling fake quantization for the first N steps led to better results,
 presumably because doing so allows the weights to stabilize before we start introducing quantization noise
 to the fine-tuning process.
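[Editor's sketch, not part of the patch] The updated lora_finetune.rst note points readers at validate_missing_and_unexpected_for_lora only. Below is a minimal illustration of that flow, assuming a small LoRA Llama2 built with the same component-builder arguments the tests in this patch use; the split of the state dict into "base" and "adapter" halves is a stand-in for real checkpoint files.

    from torchtune.models.llama2._component_builders import lora_llama2
    from torchtune.modules.peft import validate_missing_and_unexpected_for_lora

    # Build a small LoRA model; only q_proj/v_proj carry LoRA adapters.
    model = lora_llama2(
        lora_attn_modules=["q_proj", "v_proj"],
        vocab_size=64,
        num_layers=1,
        num_heads=4,
        num_kv_heads=4,
        embed_dim=64,
        max_seq_len=128,
        lora_rank=4,
        lora_alpha=1.0,
    )

    # Stand-ins for checkpoints: split the full state dict into adapter and base halves.
    full_sd = model.state_dict()
    lora_sd = {k: v for k, v in full_sd.items() if "lora" in k or "magnitude" in k}
    base_sd = {k: v for k, v in full_sd.items() if k not in lora_sd}

    # Load each half with strict=False, then validate the missing/unexpected keys
    # instead of calling the removed validate_state_dict_for_lora.
    base_missing, base_unexpected = model.load_state_dict(base_sd, strict=False)
    lora_missing, lora_unexpected = model.load_state_dict(lora_sd, strict=False)
    validate_missing_and_unexpected_for_lora(
        lora_attn_modules=["q_proj", "v_proj"],
        apply_lora_to_mlp=False,
        apply_lora_to_output=False,
        base_missing=base_missing,
        base_unexpected=base_unexpected,
        lora_missing=lora_missing,
        lora_unexpected=lora_unexpected,
    )

With the split above the call returns silently; it raises an AssertionError if either load left a LoRA or base key unaccounted for.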
diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py
index 26e6a236bc..d02a221d29 100644
--- a/recipes/lora_dpo_single_device.py
+++ b/recipes/lora_dpo_single_device.py
@@ -27,7 +27,6 @@
     get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
-    validate_state_dict_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
@@ -271,19 +270,6 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

-        validate_state_dict_for_lora(
-            lora_attn_modules=cfg_model.lora_attn_modules,
-            apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
-            apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False),
-            full_model_state_dict_keys=model.state_dict().keys(),
-            lora_state_dict_keys=(
-                lora_weights_state_dict.keys()
-                if lora_weights_state_dict is not None
-                else None
-            ),
-            base_model_state_dict_keys=base_model_state_dict.keys(),
-        )
-
         base_missing, base_unexpected = model.load_state_dict(
             base_model_state_dict, strict=False
         )
diff --git a/tests/torchtune/modules/peft/test_utils.py b/tests/torchtune/modules/peft/test_utils.py
index de90150195..2147af0d6d 100644
--- a/tests/torchtune/modules/peft/test_utils.py
+++ b/tests/torchtune/modules/peft/test_utils.py
@@ -21,7 +21,6 @@
     LoRALinear,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
-    validate_state_dict_for_lora,
 )

 N_LAYERS = 3
@@ -261,9 +260,10 @@ def test_set_trainable_params(
             lora_attn_modules,
             apply_lora_to_mlp,
             apply_lora_to_output,
-            full_model_state_dict_keys,
-            lora_state_dict_keys,
-            base_model_state_dict_keys,
+            base_missing,
+            base_unexpected,
+            lora_missing,
+            lora_unexpected,
             expected
             """
         ),
@@ -272,166 +272,77 @@ def test_set_trainable_params(
                 ["q_proj", "k_proj"],
                 False,
                 False,
-                ["q_proj.lora_a.weight", "dummy_param.weight"],
                 ["q_proj.lora_a.weight"],
+                [],
                 ["dummy_param.weight"],
+                [],
                 "",
             ),
-            (
-                ["v_proj"],
-                False,
-                False,
-                ["param_a", "param_b"],
-                None,
-                ["param_a", "param_b"],
-                "",
-            ),
+            (["v_proj"], False, False, [], [], ["param_a", "param_b"], [], ""),
             (
                 ["output_proj"],
                 False,
                 True,
-                ["output_proj.weight", "output_proj.lora_a.weight"],
                 ["output_proj.lora_a.weight"],
+                [],
                 ["output_proj.weight"],
+                [],
                 "",
             ),
-            (["q_proj"], False, False, ["param_a"], [], [], "Missing non-LoRA"),
             (
-                ["k_proj", "output_proj"],
+                ["q_proj"],
                 False,
-                True,
-                ["k_proj.lora_a.weight", "param_a"],
-                ["k_proj.lora_a.weight", "param_a"],
+                False,
+                ["param_a"],
+                [],
                 ["param_a"],
-                "found in LoRA",
+                [],
+                "Missing non-LoRA",
             ),
             (
-                ["k_proj"],
-                False,
+                ["k_proj", "output_proj"],
                 False,
-                ["k_proj.lora_a.weight"],
+                True,
+                [],
                 [],
                 ["k_proj.lora_a.weight"],
-                "found in base model",
+                [],
+                "Missing LoRA key",
             ),
             (
-                ["k_proj"],
-                False,
+                ["q_proj", "k_proj"],
+                True,
                 False,
-                ["k_proj.lora_a.weight"],
+                ["k_proj.lora"],
+                [],
+                ["q_proj.lora"],
                 [],
-                None,
                 "Missing LoRA",
             ),
-            (["q_proj"], False, False, [], ["a"], ["a"], "overlapping"),
-            (
-                ["v_proj"],
-                False,
-                False,
-                ["dummy_param.weight"],
-                ["v_proj.lora_a.weight"],
-                ["dummy_param.weight"],
-                "Extra",
-            ),
             (
-                ["w1", "w2", "w3"],
+                ["q_proj", "k_proj"],
                 True,
                 False,
-                ["w1.lora_a.weight", "w2.weight", "q_proj.weight"],
-                ["w1.lora_a.weight"],
+                ["k_proj.lora"],
+                [],
+                ["q_proj.magnitude"],
+                [],
-                ["q_proj.weight"],
-                "Missing non-LoRA key",
+                "Missing LoRA",
             ),
             (
-                ["q_proj", "output"],
-                False,
+                ["q_proj", "k_proj"],
                 True,
-                [
-                    "q_proj.lora_a",
-                    "output.weight",
-                    "output.lora_a",
-                    "output_proj.lora_b",
-                ],
-                ["q_proj.lora_a", "output.lora_a", "output_proj.lora_b"],
-                ["output.weight"],
-                "Missing non-LoRA key",
-            ),
-            (
-                ["q_proj", "v_proj"],
-                False,
                 False,
-                "lora_llama2_model_all_keys",
-                "lora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
+                ["output_proj.lora"],
+                [],
+                ["q_proj.lora"],
+                [],
+                "Missing non-LoRA",
             ),
             (
-                ["q_proj", "v_proj"],
-                False,
+                ["q_proj", "k_proj"],
+                True,
                 False,
-                "dora_llama2_model_all_keys",
-                "dora_llama2_expected_adapter_keys",
-                "lora_llama2_expected_base_model_keys",
-                "",
-            ),
-        ],
-    )
-    def test_validate_lora_state_dict(
-        self,
-        request,
-        lora_attn_modules,
-        apply_lora_to_mlp,
-        apply_lora_to_output,
-        full_model_state_dict_keys,
-        lora_state_dict_keys,
-        base_model_state_dict_keys,
-        expected,
-    ):
-        if isinstance(full_model_state_dict_keys, str):
-            full_model_state_dict_keys = request.getfixturevalue(
-                full_model_state_dict_keys
-            )
-        if isinstance(lora_state_dict_keys, str):
-            lora_state_dict_keys = request.getfixturevalue(lora_state_dict_keys)
-        if isinstance(base_model_state_dict_keys, str):
-            base_model_state_dict_keys = request.getfixturevalue(
-                base_model_state_dict_keys
-            )
-        if expected:
-            with pytest.raises(AssertionError, match=expected):
-                validate_state_dict_for_lora(
-                    lora_attn_modules,
-                    apply_lora_to_mlp,
-                    apply_lora_to_output,
-                    full_model_state_dict_keys=full_model_state_dict_keys,
-                    lora_state_dict_keys=lora_state_dict_keys,
-                    base_model_state_dict_keys=base_model_state_dict_keys,
-                )
-        else:
-            validate_state_dict_for_lora(
-                lora_attn_modules,
-                apply_lora_to_mlp,
-                apply_lora_to_output,
-                full_model_state_dict_keys=full_model_state_dict_keys,
-                lora_state_dict_keys=lora_state_dict_keys,
-                base_model_state_dict_keys=base_model_state_dict_keys,
-            )
-
-    @pytest.mark.parametrize(
-        (
-            """
-            base_missing,
-            base_unexpected,
-            lora_missing,
-            lora_unexpected,
-            expected
-            """
-        ),
-        [
-            (["k_proj.lora"], [], ["q_proj.lora"], [], "Missing LoRA"),
-            (["k_proj.lora"], [], ["q_proj.magnitude"], [], "Missing LoRA"),
-            (["output_proj.lora"], [], ["q_proj.lora"], [], "Missing non-LoRA"),
-            (
                 ["k_proj.lora"],
                 ["output.weight"],
                 ["q_proj.base_weight"],
@@ -439,21 +350,39 @@ def test_validate_lora_state_dict(
                 "loading base model",
             ),
             (
+                ["q_proj", "k_proj"],
+                True,
+                False,
                 ["k_proj.lora"],
                 [],
                 ["q_proj.base_weight"],
                 ["output.weight"],
                 "loading adapter",
             ),
-            (["k_proj.lora"], [], ["q_proj.base_weight"], [], ""),
+            (
+                ["q_proj", "k_proj"],
+                True,
+                False,
+                ["k_proj.lora"],
+                [],
+                ["q_proj.base_weight"],
+                [],
+                "",
+            ),
         ],
     )
     def test_validate_missing_and_unexpected_for_lora(
-        self, base_missing, base_unexpected, lora_missing, lora_unexpected, expected
+        self,
+        lora_attn_modules,
+        apply_lora_to_mlp,
+        apply_lora_to_output,
+        base_missing,
+        base_unexpected,
+        lora_missing,
+        lora_unexpected,
+        expected,
     ):
-        lora_attn_modules = ["q_proj", "k_proj"]
-        apply_lora_to_mlp = True
-        apply_lora_to_output = False
+
         if expected:
             with pytest.raises(AssertionError, match=expected):
                 validate_missing_and_unexpected_for_lora(
diff --git a/tests/torchtune/training/test_distributed.py b/tests/torchtune/training/test_distributed.py
index 830a2ab4a8..334a67487e 100644
--- a/tests/torchtune/training/test_distributed.py
+++ b/tests/torchtune/training/test_distributed.py
@@ -7,26 +7,22 @@
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
 import copy
-from itertools import chain

 import pytest
 import torch
 import torch.nn as nn
 from packaging import version
-from tests.test_utils import gpu_test, single_box_init
+from tests.test_utils import gpu_test
 from torch.distributed import launcher
 from torch.distributed._composable.fsdp import fully_shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     CheckpointWrapper,
 )
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-
 from torch.testing._internal.common_fsdp import FSDPTest, MLP
 from torchao.dtypes.nf4tensor import NF4Tensor

 from torchtune import modules, training
-from torchtune.models.llama2._component_builders import llama2, lora_llama2
-from torchtune.models.llama3._component_builders import llama3
+from torchtune.models.llama2._component_builders import lora_llama2
 from torchtune.modules import TransformerSelfAttentionLayer
 from torchtune.modules.peft import (
     DoRALinear,
@@ -110,51 +106,6 @@ def test_validate_no_params_on_meta_device(self) -> None:
         with pytest.raises(RuntimeError, match="Unexpected param or buffer"):
             training.validate_no_params_on_meta_device(model)

-    def test_get_fsdp_wrap_policies(self) -> None:
-        with single_box_init():
-            llama3_policy = training.get_full_finetune_fsdp_wrap_policy(
-                memory_efficient_fsdp_wrap=True,
-                modules_to_wrap={modules.TransformerSelfAttentionLayer},
-            )
-            l3 = llama3(
-                vocab_size=64,
-                num_layers=1,
-                num_heads=4,
-                num_kv_heads=4,
-                embed_dim=64,
-                max_seq_len=128,
-            )
-            wrapped_l3 = FSDP(
-                l3, auto_wrap_policy=llama3_policy, device_id=torch.device("cpu")
-            )
-            # Ensure embedding, output proj, and transformer decoder blocks are wrapped
-            assert isinstance(wrapped_l3.tok_embeddings, FSDP)
-            assert isinstance(wrapped_l3.output, FSDP)
-            for layer in wrapped_l3.layers:
-                assert isinstance(layer, FSDP)
-
-            llama2_policy = training.get_full_finetune_fsdp_wrap_policy(
-                memory_efficient_fsdp_wrap=False,
-                modules_to_wrap={modules.TransformerSelfAttentionLayer},
-            )
-            l2 = llama2(
-                vocab_size=64,
-                num_layers=1,
-                num_heads=4,
-                num_kv_heads=4,
-                embed_dim=64,
-                max_seq_len=128,
-            )
-            wrapped_l2 = FSDP(
-                l2, auto_wrap_policy=llama2_policy, device_id=torch.device("cpu")
-            )
-            # Ensure embedding, output proj, and transformer decoder blocks are not wrapped
-            assert not isinstance(wrapped_l2.tok_embeddings, FSDP)
-            assert not isinstance(wrapped_l2.output, FSDP)
-            # Ensure transformer decoder blocks are wrapped
-            for layer in wrapped_l2.layers:
-                assert isinstance(layer, FSDP)
-
 N_LAYERS = 3
 IN_DIM = 5
@@ -181,84 +132,6 @@ def _get_n_lora_and_tformer_layers(model):
     return num_lora_ab, num_transformer_layers

-# TODO: figure out a permanent home for FSDP + LoRA code
-class TestLoRAFSDP:
-    def test_lora_fsdp_wrap(self):
-        with torch.device("meta"):
-            model = lora_llama2(
-                lora_attn_modules=["q_proj", "v_proj"],
-                vocab_size=VOCAB_SIZE,
-                num_layers=N_LAYERS,
-                num_heads=NUM_HEADS,
-                num_kv_heads=NUM_KV_HEADS,
-                embed_dim=EMBED_DIM,
-                max_seq_len=MAX_SEQ_LEN,
-                lora_rank=4,
-                lora_alpha=1.0,
-            )
-
-        adapter_params = get_adapter_params(model)
-        set_trainable_params(model, adapter_params)
-        num_lora_ab, num_transformer_layers = _get_n_lora_and_tformer_layers(model)
-        with single_box_init():
-            lora_wrap_policy = training.lora_fsdp_wrap_policy(
-                modules_to_wrap={TransformerSelfAttentionLayer}
-            )
-            training.prepare_model_for_fsdp_with_meta_device(model)
-            wrapped_lora = FSDP(
-                model,
-                auto_wrap_policy=lora_wrap_policy,
-                device_id=torch.device("cpu"),
-            )
-
-            # After FSDP wrap, nothing should be left on meta device, and LoRA params
-            # should be initialized.
-            for p in chain(wrapped_lora.parameters(), wrapped_lora.buffers()):
-                assert not p.is_meta
-
-            for m in wrapped_lora.modules():
-                if isinstance(m, LoRALinear) or isinstance(m, DoRALinear):
-                    torch.testing.assert_close(
-                        m.lora_b.weight, torch.zeros_like(m.lora_b.weight)
-                    )
-            # Total # FSDP modules should be num_transformer + num_lora_ab + 1
-            total_fsdp_submodules = len([m for m in FSDP.fsdp_modules(wrapped_lora)])
-            assert total_fsdp_submodules == (num_lora_ab + num_transformer_layers + 1)
-            # LoRA a & b linears should be individually wrapped.
-            # And TransformerSelfAttentionLayers should be individually wrapped.
-            for fsdp_submodule in FSDP.fsdp_modules(wrapped_lora):
-                if isinstance(fsdp_submodule.module, nn.Linear):
-                    num_lora_ab -= 1
-                elif isinstance(fsdp_submodule.module, TransformerSelfAttentionLayer):
-                    num_transformer_layers -= 1
-            assert num_lora_ab == 0
-            assert num_transformer_layers == 0
-
-    def test_lora_meta_device_init_fsdp(self):
-        with torch.device("meta"):
-            lora = lora_llama2(
-                lora_attn_modules=["q_proj", "v_proj"],
-                vocab_size=VOCAB_SIZE,
-                num_layers=N_LAYERS,
-                num_heads=NUM_HEADS,
-                num_kv_heads=NUM_KV_HEADS,
-                embed_dim=EMBED_DIM,
-                max_seq_len=MAX_SEQ_LEN,
-                lora_rank=4,
-                lora_alpha=8,
-            )
-        training.prepare_model_for_fsdp_with_meta_device(lora)
-        for m in lora.modules():
-            m.to_empty(device=torch.device("cpu"), recurse=False)
-            m.reset_parameters()
-        # No params should be left on meta device
-        for n, p in lora.named_parameters():
-            assert not p.is_meta, f"parameter {n} is still on meta device!"
-        # Neither should buffers
-        for n, b in lora.named_buffers():
-            assert not b.is_meta, f"buffer {n} is still on meta device!"
-
-
 class TestFullyShardState(FSDPTest):
     @property
     def world_size(self) -> int:
diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py
index 4d678ea6ab..165559df9c 100644
--- a/torchtune/modules/peft/__init__.py
+++ b/torchtune/modules/peft/__init__.py
@@ -15,7 +15,6 @@
     LORA_ATTN_MODULES,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
-    validate_state_dict_for_lora,
 )
 from .dora import DoRALinear
 from .lora import LoRALinear
@@ -28,7 +27,6 @@
     "get_adapter_params",
     "set_trainable_params",
     "validate_missing_and_unexpected_for_lora",
-    "validate_state_dict_for_lora",
     "load_dora_magnitudes",
     "disable_adapter",
     "get_adapter_state_dict",
diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py
index e0e29bb716..560bfdaf52 100644
--- a/torchtune/modules/peft/_utils.py
+++ b/torchtune/modules/peft/_utils.py
@@ -130,91 +130,6 @@ def get_adapter_state_dict(
     return {k: v.to(device) for k, v in state_dict.items() if adapter_key_filter(k)}

-def validate_state_dict_for_lora(
-    lora_attn_modules: List[LORA_ATTN_MODULES],
-    apply_lora_to_mlp: bool,
-    apply_lora_to_output: bool,
-    full_model_state_dict_keys: List[str],
-    lora_state_dict_keys: Optional[List[str]] = None,
-    base_model_state_dict_keys: Optional[List[str]] = None,
-) -> None:
-    """
-    Validate that the state dict keys for a LoRA model are as expected.
-
-    (1) If lora_state_dict_keys are passed, this function will confirm that they match exactly the
-    LoRA param names from the full model (as determined by lora_modules).
-    (2) If base_model_state_dict_keys are passed, this function will confirm that they are exactly the
-    complement of the LoRA param names from the full model.
-    (3) If both lora_state_dict_keys and base_model_state_dict_keys are passed, this function will
-    confirm that the full model's params are exactly their disjoint union.
-
-    Args:
-        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
-            LoRA should be applied to in each self-attention block. Options are
-            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
-        apply_lora_to_mlp (bool): whether LoRA is applied to each MLP linear.
-        apply_lora_to_output (bool): whether LoRA is applied to the final output projection.
-        full_model_state_dict_keys (List[str]): List of keys in the full model state dict.
-        lora_state_dict_keys (Optional[List[str]]): List of keys in the LoRA state dict.
-            If none, LoRA state dict keys will not be validated.
-        base_model_state_dict_keys (Optional[List[str]]): List of keys in the base model state dict.
-            If none, base model keys will not be validated.
-
-    Returns:
-        None
-
-    Raises:
-        AssertionError: If base model state dict is missing any non-LoRA params from the full model.
-        AssertionError: If LoRA state dict is missing any LoRA params from the full model.
-        AssertionError: If base model state dict has any LoRA params.
-        AssertionError: If LoRA state dict has any non-LoRA params.
-        AssertionError: If base model and LoRA state dicts have overlapping keys.
-        AssertionError: If full model state dict is missing keys from either base model or LoRA state dict.
-
-    """
-    lora_modules = get_lora_module_names(
-        lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output
-    )
-    is_lora_param = lambda x: any(
-        [
-            ".".join([k, "lora"]) in x or ".".join([k, "magnitude"]) in x
-            for k in lora_modules
-        ]
-    )
-    for k in full_model_state_dict_keys:
-        if not is_lora_param(k):
-            if base_model_state_dict_keys is not None:
-                if k not in base_model_state_dict_keys:
-                    raise AssertionError(
-                        f"Missing non-LoRA key {k} from base model state dict"
-                    )
-            if lora_state_dict_keys is not None:
-                if k in lora_state_dict_keys:
-                    raise AssertionError(f"Non-LoRA key {k} found in LoRA state dict")
-        else:
-            if base_model_state_dict_keys is not None:
-                if k in base_model_state_dict_keys:
-                    raise AssertionError(f"LoRA key {k} found in base model state dict")
-            if lora_state_dict_keys is not None:
-                if k not in lora_state_dict_keys:
-                    raise AssertionError(f"Missing LoRA key {k} From LoRA state dict")
-
-    # Full model is disjoint union of base model and LoRA weights
-    if lora_state_dict_keys is not None and base_model_state_dict_keys is not None:
-        combined_state_dict_keys = set(lora_state_dict_keys).union(
-            base_model_state_dict_keys
-        )
-        shared_state_dict_keys = set(lora_state_dict_keys).intersection(
-            base_model_state_dict_keys
-        )
-        assert (
-            shared_state_dict_keys == set()
-        ), "Base model and LoRA state dict have overlapping keys"
-        assert combined_state_dict_keys == set(
-            full_model_state_dict_keys
-        ), "Extra keys not present in full model"
-
-
 def _get_lora_modules(state_dict: Dict[str, Any]) -> Set[str]:
     """
     Get the keys from a state dict that correspond to LoRALinear modules.
@@ -345,10 +260,9 @@ def validate_missing_and_unexpected_for_lora(
     """
     A more memory-efficient way to validate that LoRA state dict loading was done properly.

-    Similar to :func:`validate_state_dict_for_lora`, this function uses a model's LoRA config to
-    check that LoRA and/or base model weights are loaded into the full model correctly.
-    Unlike that function, this method relies only on the values of missing and unexpected
-    as returned by the load_state_dict API with strict=False. This allows us to do the
+    This function uses a model's LoRA config to check that LoRA and/or base model weights
+    are loaded into the full model correctly. This function relies only on the values of missing and
+    unexpected as returned by the load_state_dict API with strict=False. This allows us to do the
     validation without any additional calls to .state_dict(), which use additional memory.

     Args:
diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py
index ee96ef5bc5..138dd0c5ee 100644
--- a/torchtune/modules/peft/lora.py
+++ b/torchtune/modules/peft/lora.py
@@ -91,14 +91,6 @@ def __init__(
         self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
         self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
         self.merged = False
-        # Note: FSDP's meta device initialization contract assumes that a module's
-        # reset_parameters method only initializes its own parameters (i.e. no child
-        # params are initialized, as is done in initialize_parameters below).
-        # For that reason, we patch reset_parameters directly on lora_a and lora_b submodules
-        # when using meta device. This is done in
-        # torchtune.training.prepare_model_for_fsdp_with_meta_device.
-        # See this issue for more details: https://github.com/pytorch/pytorch/issues/104187.
-        # Without meta device, we only need the following:
         self.initialize_parameters()

     def initialize_parameters(self):
diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py
index 06ec1b5312..5b7ae4d8d3 100644
--- a/torchtune/training/__init__.py
+++ b/torchtune/training/__init__.py
@@ -10,10 +10,7 @@
 )
 from torchtune.training._compile import compile_loss, compile_model
 from torchtune.training._distributed import (
-    contains_fsdp,
-    FSDPPolicyType,
     gather_cpu_state_dict,
-    get_full_finetune_fsdp_wrap_policy,
     get_full_optimizer_state_dict,
     get_shard_conditions,
     get_world_size_and_rank,
@@ -21,8 +18,6 @@
     is_distributed,
     load_from_full_model_state_dict,
     load_from_full_optimizer_state_dict,
-    lora_fsdp_wrap_policy,
-    prepare_model_for_fsdp_with_meta_device,
     set_torch_num_threads,
     shard_model,
     validate_no_params_on_meta_device,
@@ -114,12 +109,7 @@
    "set_torch_num_threads",
    "shard_model",
    "get_shard_conditions",
-    "prepare_model_for_fsdp_with_meta_device",
    "validate_no_params_on_meta_device",
-    "contains_fsdp",
-    "FSDPPolicyType",
-    "get_full_finetune_fsdp_wrap_policy",
-    "lora_fsdp_wrap_policy",
    "gather_cpu_state_dict",
    "get_full_optimizer_state_dict",
    "load_from_full_model_state_dict",
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index 14371ed84a..206172b624 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -8,7 +8,7 @@
 import logging
 import os
 from itertools import chain
-from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -19,12 +19,9 @@
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.checkpoint.state_dict import _init_optim_state
 from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.optim import Optimizer
 from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
 from torchtune.modules import TransformerDecoder
-from torchtune.modules.peft import DoRALinear, LoRALinear
-from torchtune.modules.peft.lora import _lora_a_init_params, _lora_b_init_params
 from torchtune.utils import get_logger
 from torchtune.utils._device import get_device
@@ -32,33 +29,6 @@
 _log: logging.Logger = get_logger()

-FSDPPolicyType: Type = Callable[[nn.Module, bool, int], bool]
-
-FSDPPolicyType.__doc__ = """
-
-A datatype for a function that can be used as an FSDP wrapping policy.
-In particular, this type denotes a function that can accept an nn.Module, a boolean flag, and an integer
-and return a boolean indicating whether the module should be wrapped with FSDP. Objects of this type can
-be directly passed into PyTorch FSDP's ``auto_wrap_policy`` argument to specify how FSDP wraps submodules.
-
-The below function serves as an example of creating and returning a function that obeys the contract of
-``FSDPPolicyType``::
-
-    def get_fsdp_policy(module: nn.Module, modules_to_wrap: Set[Type], min_num_params: int):
-
-        def my_fsdp_policy(module: nn.Module, modules_to_wrap: Set[Type], recurse: bool, min_num_params: int) -> bool:
-            if recurse:
-                return True
-            # Wrap layers that are of type in ``modules_to_wrap`` and layers with more than min_num_params
-
-            return isinstance(module, tuple(modules_to_wrap)) or sum(p.numel() for p in module.parameters()) > 1000
-
-        return functools.partial(my_fsdp_policy, modules_to_wrap=modules_to_wrap)
-
-Please see documentation of ``auto_wrap_policy`` at https://pytorch.org/docs/stable/fsdp.html for additional details.
-
-"""
-
 _valid_distributed_single_node_nnodes = ["1:1", "1"]
@@ -177,105 +147,6 @@ def validate_no_params_on_meta_device(model: nn.Module) -> None:
             raise RuntimeError(f"Unexpected param or buffer {n} on meta device.")

-
-def contains_fsdp(model: nn.Module) -> bool:
-    """
-    Checks if the model contains FSDP.
-
-    Args:
-        model (nn.Module): Model to check.
-
-    Returns:
-        bool: True if the model contains FSDP, False otherwise.
-    """
-    return any(
-        isinstance(m, torch.distributed.fsdp.FullyShardedDataParallel)
-        for m in model.modules()
-    )
-
-
-def _dummy_reset_params(x: nn.Module) -> None:
-    """
-    Dummy method for patching no-op reset_parameters() when using
-    FSDP with meta device.
-    """
-    return
-
-
-def prepare_model_for_fsdp_with_meta_device(model: nn.Module) -> nn.Module:
-    """
-    Dynamically define reset_parameters on every submodule of the model. For LoRA models,
-    ensure that the FSDP contract of reset_parameters only modifying a module's directly-owned
-    parameters is satisfied. More details here: https://github.com/pytorch/pytorch/issues/104187.
-
-    Args:
-        model (nn.Module): model class to prepare for usage with FSDP and meta device.
-
-    Returns:
-        nn.Module: Model with reset_parameters defined on every submodule.
-            In the case of a LoRA model, we override the default reset_parameters of nn.Linear.
-
-    Raises:
-        RuntimeError: if model contains submodule with non-callable attribute reset_parameters
-    """
-    for k, v in model.named_modules():
-        # If the module does not have reset_parameters defined, we define
-        # a no-op reset_parameters method to satisfy FSDP's contract.
-        reset_params = getattr(v, "reset_parameters", None)
-
-        if reset_params is not None and not callable(reset_params):
-            raise RuntimeError(
-                f"Cannot override existing reset_parameters variable for FSDP init in {k}"
-            )
-
-        if reset_params is None:
-            v.reset_parameters = _dummy_reset_params.__get__(v)
-
-        # This will define reset_parameters for LoRA weight initialization
-        # directly on any LoRALinear submodules lora_a and lora_b.
-        if isinstance(v, LoRALinear) or isinstance(v, DoRALinear):
-            v.lora_a.reset_parameters = _lora_a_init_params.__get__(v.lora_a)
-            v.lora_b.reset_parameters = _lora_b_init_params.__get__(v.lora_b)
-
-    return model
-
-
-def lora_fsdp_wrap_policy(modules_to_wrap: Set[Type]) -> FSDPPolicyType:
-    """
-    A default policy for wrapping models trained with LoRA using FSDP.
-
-    FSDP's default behavior is to allocate gradients at the level of FSDP-wrapped modules.
-    This means that if any parameter in a given FSDP-wrapped module requires gradients, then memory will be
-    allocated for gradients for the entire module.
-
-    In the case of LoRA, where only the adapters are trainable, this means that
-    we need to wrap the adapter submodules in their own FSDP units to
-    maximize memory savings. After this is done, model will also be hierarchically wrapped
-    based on nn.Module types specified in ``modules_to_wrap``.
-
-    Args:
-        modules_to_wrap (Set[Type]): nn.Module types to recursively wrap
-
-    Returns:
-        FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel``. Please see
-            documentation for :const:`~torchtune.utils.FSDPPolicyType` for additional details.
-    """
-
-    def lora_wrap_fsdp(module: nn.Module, recurse: bool, **kwargs):
-        if recurse:
-            return True
-
-        # Assumes lora_a and lora_b are nn.Linears that are the
-        # only trainable modules in the entire network. Wraps
-        # these in separate FSDP unit to work around FSDP allocating
-        # extra gradient memory when wrapped with other modules.
-        if hasattr(module, "weight") and module.weight.requires_grad:
-            return True
-
-        return isinstance(module, tuple(modules_to_wrap))
-
-    return lora_wrap_fsdp
-
-
 def load_from_full_model_state_dict(
     model: "FSDPModule",  # noqa
     full_sd: Dict[str, Any],
@@ -500,66 +371,6 @@ def load_from_full_optimizer_state_dict(
     )

-
-def get_full_finetune_fsdp_wrap_policy(
-    memory_efficient_fsdp_wrap: bool, modules_to_wrap: Set[Type]
-) -> FSDPPolicyType:
-    """
-    Retrieves an FSDP wrapping policy based on the specified flags ``memory_efficient_fsdp_wrap`` and
-    ``modules_to_wrap``. Specifically, if ``memory_efficient_fsdp_wrap`` is set to ``True``, the returned
-    policy will wrap the model's token embedding and output projection in addition to the modules specified
-    to maximize memory savings.
-
-    Args:
-        memory_efficient_fsdp_wrap (bool): If ``True``, will also wrap embedding and output projection layers with FSDP.
-        modules_to_wrap (Set[Type]): Set of module types to wrap.
-
-    Note:
-        ``memory_efficient_fsdp_wrap`` memory improvements have currently only been verified on llama3 workloads
-        where they provide ~15% memory improvement (when used alongside AC memory efficient wrapping). Other workloads
-        have not been verified and may not see the same improvements.
-
-    Returns:
-        FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel`` as the ``auto_wrap_policy``
-            argument. Please see documentation for :const:`~torchtune.utils.FSDPPolicyType` for additional details.
-    """
-    if memory_efficient_fsdp_wrap:
-        return _memory_efficient_wrap_policy(modules_to_wrap=modules_to_wrap)
-    else:
-        return ModuleWrapPolicy(modules_to_wrap)
-
-
-def _memory_efficient_wrap_policy(modules_to_wrap: Set[Type]) -> FSDPPolicyType:
-    """
-    A default policy for memory efficient wrapping for full finetuning using FSDP. Specifically,
-    this will wrap the model's token embedding and output projection into their own FSDP units to
-    maximize memory savings. This helps especially if these layers are particularly large,
-    such as due to a large embedding size.
-    After this is done, model will also be hierarchically wrapped
-    based on nn.Module types specified in ``modules_to_wrap``. This function assumes that the
-    input model has an attribute ``output`` that is a nn.Linear which is the model's output projection.
-    Args:
-        modules_to_wrap (Set[Type]): nn.Module types to recursively wrap
-    Returns:
-        FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel``.
-    """
-    modules_to_wrap.add(torch.nn.Embedding)
-
-    def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
-        # Label that output_proj should be wrapped individually.
-        if isinstance(module, TransformerDecoder):
-            module.output._wrap = True
-        if recurse:
-            return True
-
-        # Wrap output_proj individually.
-        if getattr(module, "_wrap", False):
-            return True
-
-        return isinstance(module, tuple(modules_to_wrap))
-
-    return llama3_wrap
-
-
 def get_shard_conditions(
     name: str,
     module: nn.Module,
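[Editor's sketch, not part of the patch] With FSDPPolicyType, lora_fsdp_wrap_policy, get_full_finetune_fsdp_wrap_policy, and prepare_model_for_fsdp_with_meta_device removed, per-layer sharding is expressed with the FSDP2 ``fully_shard`` API (already imported by the test file above); recipes reach the same result through the retained ``training.shard_model`` / ``training.get_shard_conditions`` helpers. A minimal illustration, assuming an initialized process group and using a hypothetical ``shard_transformer`` helper name:

    import torch.nn as nn
    from torch.distributed._composable.fsdp import fully_shard

    from torchtune.modules import TransformerSelfAttentionLayer


    def shard_transformer(model: nn.Module) -> nn.Module:
        # Each transformer layer becomes its own FSDP unit (the role the old
        # ModuleWrapPolicy({TransformerSelfAttentionLayer}) auto-wrap policy played)...
        for module in model.modules():
            if isinstance(module, TransformerSelfAttentionLayer):
                fully_shard(module)
        # ...then shard the root so the remaining parameters (token embeddings,
        # output projection) are grouped into a final unit.
        fully_shard(model)
        return model

Because FSDP2 shards individual parameters rather than allocating flat gradients per wrapped module, the special per-adapter wrapping that lora_fsdp_wrap_policy implemented is no longer needed.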