Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add Support for Gradient Checkpointing #759

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/adapters/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import functools
import threading
from typing import ContextManager

from .composition import parse_composition, parse_heads_from_composition


class AdapterSetup:
class AdapterSetup(ContextManager):
"""
Represents an adapter setup of a model including active adapters and active heads. This class is intended to be
used as a context manager using the ``with`` statement. The setup defined by the ``AdapterSetup`` context will
Expand Down Expand Up @@ -67,7 +68,7 @@ def get_context_head_setup(cls):
return None


class ForwardContext:
class ForwardContext(ContextManager):
"""
Holds context information during a forward pass through a model. This class should be used via the
``ForwardContext.wrap()`` method.
Expand Down
64 changes: 64 additions & 0 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import contextlib
import functools
import inspect
import logging
import os
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
from functools import partial
from os.path import join
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from adapters.configuration.adapter_config import ConfigUnion, LoRAConfig
from transformers import GenerationConfig
Expand Down Expand Up @@ -1447,6 +1451,66 @@ def save_pretrained(
# Remove adapters config
del self.config.adapters

def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""
Activates gradient checkpointing for the current model.

Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".

We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2

Args:
gradient_checkpointing_kwargs (dict, *optional*):
Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")

if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {"use_reentrant": False}

# >>> START AH Changes <<<
if "use_reentrant" not in gradient_checkpointing_kwargs:
# use_reentrant must be set.
gradient_checkpointing_kwargs["use_reentrant"] = False
else:
if gradient_checkpointing_kwargs["use_reentrant"]:
raise ValueError(
"Gradient checkpointing with use_reentrant=True is not supported. For gradient checkpointing, we need to set context_fn, which is only supported by PyTorch when use_reentrant is set to False."
)

def gradient_checkpointing_function(function, *args, **kwargs):
context = ForwardContext.get_context()
context_fn = lambda: (contextlib.nullcontext(), context)
return checkpoint(function, *args, context_fn=context_fn, **kwargs)

gradient_checkpointing_func = functools.partial(
gradient_checkpointing_function, **gradient_checkpointing_kwargs
)
# >>> END AH Changes <<<

# For old GC format (transformers < 4.35.0) for models that live on the Hub
# we will fall back to the overwritten `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters

if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warning(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)

if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()


@inherit_doc
class ModelBaseAdaptersMixin(ModelAdaptersMixin):
Expand Down
69 changes: 67 additions & 2 deletions tests/methods/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import copy
import os
import tempfile
from typing import Callable

import torch

import adapters
import adapters.composition as ac
from adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer, AutoAdapterModel
from adapters.heads import CausalLMHead
from adapters.utils import WEIGHTS_NAME
Expand Down Expand Up @@ -247,7 +249,7 @@ def run_full_model_load_test(self, adapter_config):
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.allclose(output1[0], output2[0], atol=1e-4))

def trainings_run(self, model, lr=1.0, steps=8):
def trainings_run(self, model, lr=1.0, steps=8, batch_size=2, gradient_accumulation_steps=1):
# setup dataset
train_dataset = self.dataset()

Expand All @@ -257,7 +259,8 @@ def trainings_run(self, model, lr=1.0, steps=8):
learning_rate=lr,
max_steps=steps,
no_cuda=True,
per_device_train_batch_size=2,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
remove_unused_columns=False,
)

Expand Down Expand Up @@ -370,3 +373,65 @@ def run_reset_test(self, adapter_config):
# check forward pass
self.assertEqual(len(output_1), len(output_2))
self.assertTrue(torch.allclose(output_1[0], output_2[0], atol=1e-3))

def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[adapters.ModelAdaptersMixin], None]):
"""
Test that gradient checkpointing produces the same results as normal training
Args:
adapter_setup_fn: Function that takes a model and sets up the adapter training. Must also add a head (usually via self.add_head(...)). We have this in a separate function to allow complex setups (like training a normal adapter or training parallel setups)
"""

if not self.do_run_train_tests:
self.skipTest("Skipping training tests. Set `do_run_train_tests=True` to run them.")
if self.config_class not in ADAPTER_MODEL_MAPPING:
self.skipTest("Does not support flex heads.")

config = self.config()
state_dict_after_training = {}

for train_with_checkpointing in [True, False]:
# Set random seed
torch.manual_seed(42)

# Initialize model
model = adapters.AutoAdapterModel.from_config(config)
model.to(torch_device)
adapter_setup_fn(model)

# Enable gradient checkpointing
if train_with_checkpointing:
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Train & store state dict
self.trainings_run(model, batch_size=1, gradient_accumulation_steps=2)
state_dict_after_training[train_with_checkpointing] = copy.deepcopy(model.state_dict())

# Check that the state dicts are the same (we know that normal training works as expected, so we only need to check that gradient checkpointing produces the same results.)
for (k1, v1), (k2, v2) in zip(
state_dict_after_training[True].items(), state_dict_after_training[False].items()
):
v1 = v1.to(v2.device)
self.assertTrue(torch.equal(v1, v2), msg=f"Key {k1} is not equal:\nv1: {v1}\nv2: {v2}")

def run_gradient_checkpointing_single_adapter_test(self, adapter_config):
def adapter_setup_fn(model):
model.add_adapter("adapter1", config=adapter_config)
self.add_head(model, "adapter1")
model.train_adapter("adapter1")
model.adapter_to("adapter1", torch_device)

self._run_gradient_checkpointing_test_helper(adapter_setup_fn)

def run_gradient_checkpointing_test_parallel_adapters(self, adapter_config):
def adapter_setup_fn(model):
model.add_adapter("adapter1", config=adapter_config)
model.add_adapter("adapter2", config=adapter_config)
self.add_head(model, "adapter1")
self.add_head(model, "adapter2")
model.active_adapters = ac.Parallel("adapter1", "adapter2")
model.train_adapter(ac.Parallel("adapter1", "adapter2"))
model.adapter_to("adapter1", torch_device)
model.adapter_to("adapter2", torch_device)

self._run_gradient_checkpointing_test_helper(adapter_setup_fn)
6 changes: 6 additions & 0 deletions tests/methods/test_ia3.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ def test_merge_ia3(self):

def test_reset_ia3(self):
self.run_reset_test(IA3Config(init_weights="bert"))

def test_ia3_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(IA3Config())

def test_ia3_gradient_checkpointing_parallel_adapters(self):
self.run_gradient_checkpointing_test_parallel_adapters(IA3Config())
6 changes: 6 additions & 0 deletions tests/methods/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,9 @@ def test_merge_lora(self):

def test_reset_lora(self):
self.run_reset_test(LoRAConfig(init_weights="bert"))

def test_lora_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(LoRAConfig())

def test_lora_gradient_checkpointing_parallel_adapters(self):
self.run_gradient_checkpointing_test_parallel_adapters(LoRAConfig())
6 changes: 6 additions & 0 deletions tests/methods/test_prefix_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,9 @@ def test_prefix_tuning_generate(self):
input_ids = input_ids.to(torch_device)
generated = model1.generate(input_ids, max_length=seq_output_length)
self.assertLessEqual(generated.shape, (1, seq_output_length))

def test_prefix_tuning_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(PrefixTuningConfig())

def test_prefix_tuning_gradient_checkpointing_parallel_adapters(self):
self.run_gradient_checkpointing_test_parallel_adapters(PrefixTuningConfig())
6 changes: 6 additions & 0 deletions tests/methods/test_prompt_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,9 @@ def test_load_full_model_prompt_tuning(self):

def test_train_prompt_tuning(self):
self.run_train_test(PromptTuningConfig(prompt_length=10), ["prompt_tunings.{name}."])

def test_prompt_tuning_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(PromptTuningConfig(prompt_length=10))

def test_prompt_tuning_gradient_checkpointing_parallel_adapters(self):
self.run_gradient_checkpointing_test_parallel_adapters(PromptTuningConfig(prompt_length=10))
6 changes: 6 additions & 0 deletions tests/methods/test_reft.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,9 @@ def test_load_full_model_reft(self):

def test_train_loreft(self):
self.run_train_test(LoReftConfig(), ["refts.{name}."])

def test_reft_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(LoReftConfig())

def test_reft_gradient_checkpointing_parallel_adapters(self):
self.run_gradient_checkpointing_test_parallel_adapters(LoReftConfig())
6 changes: 6 additions & 0 deletions tests/methods/test_unipelt.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,9 @@ def test_output_adapter_gating_scores_unipelt(self):
self.assertGreaterEqual(len(per_layer_scores), 3)
for k, v in per_layer_scores.items():
self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k)

def test_unipelt_gradient_checkpointing_single_adapter(self):
self.run_gradient_checkpointing_single_adapter_test(UniPELTConfig())

def test_unipelt_gradient_checkpointing_parallel_adapters(self):
self.run_gradient_checkpointing_test_parallel_adapters(UniPELTConfig())
Loading