From b5ae0fce886118839c8e99f99785a3c7c4d38549 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Thu, 7 Nov 2024 10:02:01 +0100 Subject: [PATCH 1/2] Fix gradient checkpointing and write tests - oerwrite the gradient_checkpointing_enable to provide our ForwardContext during the recomputation of values during backpropagation - 2 bugs remaining: bottleneck adapter for models with the legacy implementation (BERT) & Parallel. Parallel has the problem that we manipulate the batch dimension and this currently leads to an error --- src/adapters/context.py | 5 ++- src/adapters/model_mixin.py | 64 ++++++++++++++++++++++++++ tests/methods/base.py | 69 ++++++++++++++++++++++++++++- tests/methods/test_ia3.py | 6 +++ tests/methods/test_lora.py | 6 +++ tests/methods/test_prefix_tuning.py | 6 +++ tests/methods/test_prompt_tuning.py | 6 +++ tests/methods/test_reft.py | 6 +++ tests/methods/test_unipelt.py | 6 +++ 9 files changed, 170 insertions(+), 4 deletions(-) diff --git a/src/adapters/context.py b/src/adapters/context.py index 70e685d037..db09b8918f 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -1,10 +1,11 @@ import functools import threading +from typing import ContextManager from .composition import parse_composition, parse_heads_from_composition -class AdapterSetup: +class AdapterSetup(ContextManager): """ Represents an adapter setup of a model including active adapters and active heads. This class is intended to be used as a context manager using the ``with`` statement. The setup defined by the ``AdapterSetup`` context will @@ -67,7 +68,7 @@ def get_context_head_setup(cls): return None -class ForwardContext: +class ForwardContext(ContextManager): """ Holds context information during a forward pass through a model. This class should be used via the ``ForwardContext.wrap()`` method. diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 659a6cfcff..342913c8d6 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1,14 +1,18 @@ +import contextlib +import functools import inspect import logging import os from abc import ABC, abstractmethod from collections import defaultdict from copy import deepcopy +from functools import partial from os.path import join from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn +from torch.utils.checkpoint import checkpoint from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig from transformers import GenerationConfig @@ -1447,6 +1451,66 @@ def save_pretrained( # Remove adapters config del self.config.adapters + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": False} + + # >>> START AH Changes <<< + if "use_reentrant" not in gradient_checkpointing_kwargs: + # use_reentrant must be set. + gradient_checkpointing_kwargs["use_reentrant"] = False + else: + if gradient_checkpointing_kwargs["use_reentrant"]: + raise ValueError( + "Gradient checkpointing with use_reentrant=True is not supported. For gradient checkpointing, we need to set context_fn, which is only supported by PyTorch when use_reentrant is set to False." + ) + + def gradient_checkpointing_function(function, *args, **kwargs): + context = ForwardContext(self, *args, **kwargs) + context_fn = lambda: (contextlib.nullcontext(), context) + return checkpoint(function, *args, context_fn=context_fn, **kwargs) + + gradient_checkpointing_func = functools.partial( + gradient_checkpointing_function, **gradient_checkpointing_kwargs + ) + # >>> END AH Changes <<< + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` method + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warning( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + @inherit_doc class ModelBaseAdaptersMixin(ModelAdaptersMixin): diff --git a/tests/methods/base.py b/tests/methods/base.py index 0d20f32fef..86eb3e08ca 100644 --- a/tests/methods/base.py +++ b/tests/methods/base.py @@ -1,10 +1,12 @@ import copy import os import tempfile +from typing import Callable import torch import adapters +import adapters.composition as ac from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer, AutoAdapterModel from adapters.heads import CausalLMHead from adapters.utils import WEIGHTS_NAME @@ -247,7 +249,7 @@ def run_full_model_load_test(self, adapter_config): self.assertEqual(len(output1), len(output2)) self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4)) - def trainings_run(self, model, lr=1.0, steps=8): + def trainings_run(self, model, lr=1.0, steps=8, batch_size=2, gradient_accumulation_steps=1): # setup dataset train_dataset = self.dataset() @@ -257,7 +259,8 @@ def trainings_run(self, model, lr=1.0, steps=8): learning_rate=lr, max_steps=steps, no_cuda=True, - per_device_train_batch_size=2, + per_device_train_batch_size=batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, remove_unused_columns=False, ) @@ -370,3 +373,65 @@ def run_reset_test(self, adapter_config): # check forward pass self.assertEqual(len(output_1), len(output_2)) self.assertTrue(torch.allclose(output_1[0], output_2[0], atol=1e-3)) + + def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[adapters.ModelAdaptersMixin], None]): + """ + Test that gradient checkpointing produces the same results as normal training + Args: + adapter_setup_fn: Function that takes a model and sets up the adapter training. Must also add a head (usually via self.add_head(...)). We have this in a separate function to allow complex setups (like training a normal adapter or training parallel setups) + """ + + if not self.do_run_train_tests: + self.skipTest("Skipping training tests. Set `do_run_train_tests=True` to run them.") + if self.config_class not in ADAPTER_MODEL_MAPPING: + self.skipTest("Does not support flex heads.") + + config = self.config() + state_dict_after_training = {} + + for train_with_checkpointing in [True, False]: + # Set random seed + torch.manual_seed(42) + + # Initialize model + model = adapters.AutoAdapterModel.from_config(config) + model.to(torch_device) + adapter_setup_fn(model) + + # Enable gradient checkpointing + if train_with_checkpointing: + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + + # Train & store state dict + self.trainings_run(model, batch_size=1, gradient_accumulation_steps=2) + state_dict_after_training[train_with_checkpointing] = copy.deepcopy(model.state_dict()) + + # Check that the state dicts are the same (we know that normal training works as expected, so we only need to check that gradient checkpointing produces the same results.) + for (k1, v1), (k2, v2) in zip( + state_dict_after_training[True].items(), state_dict_after_training[False].items() + ): + v1 = v1.to(v2.device) + self.assertTrue(torch.equal(v1, v2), msg=f"Key {k1} is not equal:\nv1: {v1}\nv2: {v2}") + + def run_gradient_checkpointing_single_adapter_test(self, adapter_config): + def adapter_setup_fn(model): + model.add_adapter("adapter1", config=adapter_config) + self.add_head(model, "adapter1") + model.train_adapter("adapter1") + model.adapter_to("adapter1", torch_device) + + self._run_gradient_checkpointing_test_helper(adapter_setup_fn) + + def run_gradient_checkpointing_test_parallel_adapters(self, adapter_config): + def adapter_setup_fn(model): + model.add_adapter("adapter1", config=adapter_config) + model.add_adapter("adapter2", config=adapter_config) + self.add_head(model, "adapter1") + self.add_head(model, "adapter2") + model.active_adapters = ac.Parallel("adapter1", "adapter2") + model.train_adapter(ac.Parallel("adapter1", "adapter2")) + model.adapter_to("adapter1", torch_device) + model.adapter_to("adapter2", torch_device) + + self._run_gradient_checkpointing_test_helper(adapter_setup_fn) diff --git a/tests/methods/test_ia3.py b/tests/methods/test_ia3.py index 3a30e2448d..ced2dbb003 100644 --- a/tests/methods/test_ia3.py +++ b/tests/methods/test_ia3.py @@ -45,3 +45,9 @@ def test_merge_ia3(self): def test_reset_ia3(self): self.run_reset_test(IA3Config(init_weights="bert")) + + def test_ia3_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(IA3Config()) + + def test_ia3_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(IA3Config()) diff --git a/tests/methods/test_lora.py b/tests/methods/test_lora.py index 067f78c8b8..0fbd2f6808 100644 --- a/tests/methods/test_lora.py +++ b/tests/methods/test_lora.py @@ -313,3 +313,9 @@ def test_merge_lora(self): def test_reset_lora(self): self.run_reset_test(LoRAConfig(init_weights="bert")) + + def test_lora_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(LoRAConfig()) + + def test_lora_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(LoRAConfig()) diff --git a/tests/methods/test_prefix_tuning.py b/tests/methods/test_prefix_tuning.py index dd443c0d0b..c6f5ade445 100644 --- a/tests/methods/test_prefix_tuning.py +++ b/tests/methods/test_prefix_tuning.py @@ -101,3 +101,9 @@ def test_prefix_tuning_generate(self): input_ids = input_ids.to(torch_device) generated = model1.generate(input_ids, max_length=seq_output_length) self.assertLessEqual(generated.shape, (1, seq_output_length)) + + def test_prefix_tuning_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(PrefixTuningConfig()) + + def test_prefix_tuning_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(PrefixTuningConfig()) diff --git a/tests/methods/test_prompt_tuning.py b/tests/methods/test_prompt_tuning.py index 97015d1319..f3c4b5b657 100644 --- a/tests/methods/test_prompt_tuning.py +++ b/tests/methods/test_prompt_tuning.py @@ -36,3 +36,9 @@ def test_load_full_model_prompt_tuning(self): def test_train_prompt_tuning(self): self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."]) + + def test_prompt_tuning_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(PromptTuningConfig(prompt_length=10)) + + def test_prompt_tuning_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(PromptTuningConfig(prompt_length=10)) diff --git a/tests/methods/test_reft.py b/tests/methods/test_reft.py index 8849221808..8e5cab0f27 100644 --- a/tests/methods/test_reft.py +++ b/tests/methods/test_reft.py @@ -77,3 +77,9 @@ def test_load_full_model_reft(self): def test_train_loreft(self): self.run_train_test(LoReftConfig(), ["refts.{name}."]) + + def test_reft_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(LoReftConfig()) + + def test_reft_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(LoReftConfig()) diff --git a/tests/methods/test_unipelt.py b/tests/methods/test_unipelt.py index d29fa5f18d..2191a31161 100644 --- a/tests/methods/test_unipelt.py +++ b/tests/methods/test_unipelt.py @@ -64,3 +64,9 @@ def test_output_adapter_gating_scores_unipelt(self): self.assertGreaterEqual(len(per_layer_scores), 3) for k, v in per_layer_scores.items(): self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + + def test_unipelt_gradient_checkpointing_single_adapter(self): + self.run_gradient_checkpointing_single_adapter_test(UniPELTConfig()) + + def test_unipelt_gradient_checkpointing_parallel_adapters(self): + self.run_gradient_checkpointing_test_parallel_adapters(UniPELTConfig()) From cb07dd44c1040eb9a85194f85425ed130eca5304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leon=20Engl=C3=A4nder?= Date: Mon, 11 Nov 2024 22:51:08 +0100 Subject: [PATCH 2/2] minor fix but doesn't resolve the remaining issues --- src/adapters/model_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 342913c8d6..dfc7022bd1 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1482,7 +1482,7 @@ def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): ) def gradient_checkpointing_function(function, *args, **kwargs): - context = ForwardContext(self, *args, **kwargs) + context = ForwardContext.get_context() context_fn = lambda: (contextlib.nullcontext(), context) return checkpoint(function, *args, context_fn=context_fn, **kwargs)