WIP: Add Support for Gradient Checkpointing #759
Draft
Add Support for Gradient Checkpointing
This PR adds support for gradient checkpointing. Gradient checkpointing is a technique that trades computation for memory: intermediate activations are recomputed during the backward pass instead of being stored, which is particularly useful when training large models. Because values are recomputed during backpropagation, the original `ForwardContext` has to be preserved in that phase as well. I solved this by overriding the `gradient_checkpointing_enable` function so that the checkpoint function receives the current `ForwardContext` as the context manager for the backward-pass recomputation (see the sketch below).

I added tests that check whether this works for every model using 1) a single adapter or 2) parallel composition.
Not Yet Fixed Bugs
Some models fail in the backward pass with the error `one of the variables needed for gradient computation has been modified by an inplace operation`. I couldn't find out where this happens. For some models the single-adapter tests pass (e.g. Llama, ViT, Whisper); for others they fail (e.g. BERT, DistilBERT, GPT-2). Example test invocations:

- `pytest tests/test_llama.py::LlamaAdapterTest::test_prefix_tuning_gradient_checkpointing_single_adapter`
- `pytest tests/test_bert.py::BertAdapterTest::test_prefix_tuning_gradient_checkpointing_single_adapter`
- `pytest tests/test_llama.py::LlamaAdapterTest::test_lora_gradient_checkpointing_parallel_adapters`
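For reference, the failing single-adapter case can be exercised along these lines. This is a rough sketch of what such a test does, not the actual test code; the adapter name, checkpoint, and input shapes are made up.

```python
import torch
from transformers import BertForSequenceClassification
import adapters

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
adapters.init(model)                   # enable adapter support on the base model
model.add_adapter("test_adapter")      # add an adapter (default bottleneck config)
model.train_adapter("test_adapter")    # freeze base weights, train only the adapter
model.gradient_checkpointing_enable()  # activates the overridden checkpointing path
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
loss = model(input_ids=input_ids, labels=torch.tensor([0, 1])).loss
loss.backward()  # for BERT this currently raises the in-place modification error
```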