From 95961d456d9fc5a07dd969be4dbbddd3a86fb1c1 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 28 Oct 2024 13:11:13 -0700 Subject: [PATCH] Add support for QAT + LoRA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Summary:** This commit adds a recipe that combines QAT + LoRA, with the main goal of improving final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe `qat_lora_finetune_distributed` mirrors the existing `lora_finetune_distributed` recipe, which performs only LoRA, and is analogous to the existing `qat_distributed` recipe, which performs only QAT. Helpful code review commands: ``` diff --color recipes/lora_finetune_distributed.py recipes/qat_lora_finetune_distributed.py diff --color recipes/configs/llama3/8B_lora.yaml recipes/configs/llama3/8B_qat_lora.yaml diff --color recipes/configs/llama3_1/8B_lora.yaml recipes/configs/llama3_1/8B_qat_lora.yaml diff --color recipes/configs/llama3_2/1B_lora.yaml recipes/configs/llama3_2/1B_qat_lora.yaml diff --color recipes/configs/llama3_2/3B_lora.yaml recipes/configs/llama3_2/3B_qat_lora.yaml ``` For more context on QAT, please visit https://github.com/pytorch/torchtune/pull/980 and https://pytorch.org/blog/quantization-aware-training/. **Test Plan** Unit tests: ``` pytest -m integration_test tests/recipes/test_qat_lora_finetune_distributed.py ``` Manual tests: ``` export CUDA_VISIBLE_DEVICES=4,5,6,7 export NCCL_SHM_DISABLE=0 LOG_DIR=/home/andrewor/local/logs/tune/qat_lora tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora \ batch_size=4 \ quantizer.groupsize=32 \ checkpointer.output_dir="$LOG_DIR" \ metric_logger.output_dir="${LOG_DIR}/metrics" tune run quantize --config quantization \ model._component_=torchtune.models.llama3.llama3_8b \ checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \ checkpointer.checkpoint_dir="$LOG_DIR" \ checkpointer.output_dir="$LOG_DIR" \ checkpointer.checkpoint_files=["meta_model_0.pt"] \ checkpointer.model_type=LLAMA3 \ quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \ quantizer.groupsize=32 tune run eleuther_eval --config eleuther_evaluation \ batch_size=1 \ model._component_=torchtune.models.llama3.llama3_8b \ checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir="$LOG_DIR" \ checkpointer.output_dir="$LOG_DIR" \ checkpointer.checkpoint_files=["meta_model_0.pt-8da4w"] \ checkpointer.model_type=LLAMA3 \ tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \ tasks=[wikitext] \ quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \ quantizer.groupsize=32 ``` Results: ``` | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr| |--------|------:|------|------|---------------|---|------:|---|------| |wikitext| 2|none |None |bits_per_byte |↓ | 0.6284|± | N/A| | | |none |None |byte_perplexity|↓ | 1.5458|± | N/A| | | |none |None |word_perplexity|↓ |10.2694|± | N/A| | Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr| |--------|------:|------|------|---------------|---|------:|---|------| |wikitext| 2|none |None |bits_per_byte |↓ | 0.6245|± | N/A| | | |none |None |byte_perplexity|↓ | 1.5416|± | N/A| | | |none |None |word_perplexity|↓ |10.1208|± | N/A| ``` --- README.md | 3 +- recipes/configs/llama3/8B_qat_lora.yaml | 113 ++ 
recipes/configs/llama3_1/8B_qat_lora.yaml | 116 +++ recipes/configs/llama3_2/1B_qat_lora.yaml | 112 ++ recipes/configs/llama3_2/3B_qat_lora.yaml | 113 ++ recipes/qat_lora_finetune_distributed.py | 972 ++++++++++++++++++ .../test_qat_lora_finetune_distributed.py | 266 +++++ tests/torchtune/modules/peft/test_lora.py | 14 +- torchtune/_recipe_registry.py | 11 + torchtune/modules/peft/__init__.py | 5 +- torchtune/modules/peft/lora.py | 161 ++- torchtune/training/quantization.py | 52 + 12 files changed, 1933 insertions(+), 5 deletions(-) create mode 100644 recipes/configs/llama3/8B_qat_lora.yaml create mode 100644 recipes/configs/llama3_1/8B_qat_lora.yaml create mode 100644 recipes/configs/llama3_2/1B_qat_lora.yaml create mode 100644 recipes/configs/llama3_2/3B_qat_lora.yaml create mode 100644 recipes/qat_lora_finetune_distributed.py create mode 100644 tests/recipes/test_qat_lora_finetune_distributed.py diff --git a/README.md b/README.md index 8caa38c890..3e494dd757 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,8 @@ torchtune provides the following finetuning recipes for training on one or more | LoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py)
[lora_finetune_distributed](recipes/lora_finetune_distributed.py) | [Qwen2 0.5B single-device](recipes/configs/qwen2/0.5B_lora_single_device.yaml)
[Gemma 7B distributed](recipes/configs/gemma/7B_lora.yaml) | QLoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py)
[lora_finetune_distributed](recipes/lora_finetune_distributed.py)| [Phi3 Mini single-device](recipes/configs/phi3/mini_qlora_single_device.yaml)
[Llama 3.1 405B distributed](recipes/configs/llama3_1/405B_qlora.yaml) | DoRA/QDoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py)
[lora_finetune_distributed](recipes/lora_finetune_distributed.py)| [Llama3 8B QDoRA single-device](recipes/configs/llama3/8B_qdora_single_device.yaml)
[Llama3 8B DoRA distributed](recipes/configs/llama3/8B_dora.yaml) -| Quantization-Aware Training | 4-8 | [qat_distributed](recipes/qat_distributed.py)| [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml) +| Quantization-Aware Training | 2-8 | [qat_distributed](recipes/qat_distributed.py)| [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml) +| Quantization-Aware Training and LoRA Finetuning | 2-8 | [qat_lora_finetune_distributed](recipes/qat_lora_finetune_distributed.py)| [Llama3 8B QAT + LoRA](recipes/configs/llama3/8B_qat_lora.yaml) | Direct Preference Optimization |1-8 | [lora_dpo_single_device](recipes/lora_dpo_single_device.py)
[lora_dpo_distributed](recipes/lora_dpo_distributed.py) | [Llama2 7B single-device](recipes/configs/llama2/7B_lora_dpo_single_device.yaml)
[Llama2 7B distributed](recipes/configs/llama2/7B_lora_dpo.yaml) | Proximal Policy Optimization | 1 | [ppo_full_finetune_single_device](recipes/ppo_full_finetune_single_device.py) | [Mistral 7B](recipes/configs/mistral/7B_full_ppo_low_memory.yaml) | Knowledge Distillation | 1 | [knowledge_distillation_single_device](recipes/knowledge_distillation_single_device.py) | [Qwen2 1.5B -> 0.5B](recipes/configs/qwen2/knowledge_distillation_single_device.yaml) diff --git a/recipes/configs/llama3/8B_qat_lora.yaml b/recipes/configs/llama3/8B_qat_lora.yaml new file mode 100644 index 0000000000..2104e6268e --- /dev/null +++ b/recipes/configs/llama3/8B_qat_lora.yaml @@ -0,0 +1,113 @@ +# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora checkpointer.checkpoint_dir= + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + max_seq_len: null + +# Model Arguments +model: + _component_: torchtune.models.llama3.lora_llama3_8b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory + +# Logging +output_dir: /tmp/qat_lora_finetune_output +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: 
True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 + +# QAT arguments +quantizer: + _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 diff --git a/recipes/configs/llama3_1/8B_qat_lora.yaml b/recipes/configs/llama3_1/8B_qat_lora.yaml new file mode 100644 index 0000000000..531d31fee9 --- /dev/null +++ b/recipes/configs/llama3_1/8B_qat_lora.yaml @@ -0,0 +1,116 @@ +# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora checkpointer.checkpoint_dir= + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 8 # higher increases accuracy and memory + lora_alpha: 16 # usually alpha=2*rank + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed +seed: null +shuffle: True +batch_size: 2 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory + +# Logging +output_dir: /tmp/qat_lora_finetune_output +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: 
torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 + +# QAT arguments +quantizer: + _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 diff --git a/recipes/configs/llama3_2/1B_qat_lora.yaml b/recipes/configs/llama3_2/1B_qat_lora.yaml new file mode 100644 index 0000000000..8d68ef632c --- /dev/null +++ b/recipes/configs/llama3_2/1B_qat_lora.yaml @@ -0,0 +1,112 @@ +# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py +# using a Llama3.2 1B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora checkpointer.checkpoint_dir= + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model + max_seq_len: null + +# Model Arguments +model: + _component_: torchtune.models.llama3_2.lora_llama3_2_1b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/ + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-3.2-1B-Instruct/ + model_type: LLAMA3_2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory + +# Logging +output_dir: /tmp/qat_lora_finetune_output +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False # True reduces memory +enable_activation_offloading: False # True reduces memory + +# Profiler 
(disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 + +# QAT arguments +quantizer: + _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 diff --git a/recipes/configs/llama3_2/3B_qat_lora.yaml b/recipes/configs/llama3_2/3B_qat_lora.yaml new file mode 100644 index 0000000000..2fac4e0fa1 --- /dev/null +++ b/recipes/configs/llama3_2/3B_qat_lora.yaml @@ -0,0 +1,113 @@ +# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py +# using a Llama3.2 3B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Llama-3.2-3B-Instruct --output-dir /tmp/Llama-3.2-3B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on 2 devices, run the following command from root: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora checkpointer.checkpoint_dir= + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model + max_seq_len: null + +# Model Arguments +model: + _component_: torchtune.models.llama3_2.lora_llama3_2_3b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + lora_rank: 64 # higher increases accuracy and memory + lora_alpha: 128 # usually alpha=2*rank + lora_dropout: 0.0 + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.2-3B-Instruct/ + checkpoint_files: [ + model-00001-of-00002.safetensors, + model-00002-of-00002.safetensors, + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-3.2-3B-Instruct/ + model_type: LLAMA3_2 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset + packed: False # True increases speed +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 8 # Use to increase virtual batch size +compile: False # pytorch compile, set to true for better perf/memory + +# Logging +output_dir: /tmp/qat_lora_finetune_output +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: True + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False # True 
reduces memory +enable_activation_offloading: False # True reduces memory + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 3 + active_steps: 2 + num_cycles: 1 + +# QAT arguments +quantizer: + _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer + groupsize: 256 diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py new file mode 100644 index 0000000000..f9b1fc991f --- /dev/null +++ b/recipes/qat_lora_finetune_distributed.py @@ -0,0 +1,972 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group + +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, training, utils +from torchtune.config._utils import _get_component_from_path +from torchtune.data import padded_collate_packed +from torchtune.datasets import ConcatDataset +from torchtune.modules.peft import ( + DoRALinear, + get_adapter_params, + get_adapter_state_dict, + get_lora_module_names, + get_merged_lora_ckpt, + LoRALinear, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training.quantization import swap_lora_linear_with_qat + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class QATLoRAFinetuneRecipeDistributed(FTRecipeInterface): + """ + Distributed quantization-aware training (QAT) and LoRA finetuning recipe for dense transformer-based + LLMs such as Llama2. This recipe supports distributed training and can be run on a single node (1 to + 8 GPUs). Only compatible with torchao 0.7+. + + Features: + - Quantization-aware training (QAT). Perform fake quantization on weights and/or activations + during finetuning, with the goal of ultimately producing a quantized model with minimal + accuracy degradation. This recipe produces an unquantized model in the original dtype, + which can then be quantized separately. + + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + + - Activation Checkpointing. 
This can be controlled using the ``enable_activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch 2.5.0 or later and will be + enabled by default if an acceptable torch version is found. Activation offloading can be used in + conjunction with activation checkpointing. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. 
Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + ValueError: If world_size is 1 + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``left_pad_sequence`` is set as the data collator. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. + RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False. + """ + + def __init__(self, cfg: DictConfig) -> None: + try: + from torchao.quantization import qat # noqa: F401 + except ImportError as err: + raise ValueError( + "qat_lora_finetune_distributed is only compatible with torchao 0.7+" + ) from err + + self._device = utils.get_device(device=cfg.device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + _, rank = training.get_world_size_and_rank() + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + self._is_rank_zero = rank == 0 + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + if self._log_peak_memory_stats and self._device.type != "cuda": + log.info( + "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False." + ) + self._log_peak_memory_stats = False + + # These attributes constitute the recipe state and are updated by ``load_checkpoint`` + # when ``resume_from_checkpoint`` is ``True`` + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + # activation checkpointing/offloading + self._enable_activation_checkpointing = cfg.get( + "enable_activation_checkpointing", False + ) + self._enable_activation_offloading = cfg.get( + "enable_activation_offloading", False + ) + if self._enable_activation_offloading: + if self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be True when training on CUDA" + ) + if not self._enable_activation_checkpointing: + raise RuntimeError( + "enable_activation_offloading should only be True when enable_activation_checkpointing is True" + ) + elif ( + self._enable_activation_checkpointing + and cfg.checkpointer.model_type != "LLAMA3_VISION" + ): + utils.log_rank_zero( + log, + "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. 
" + "Enabling activation offloading should reduce memory further.", + ) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + # When resuming from checkpoint for LoRA, the recipe expects the adapter weights + # and recipe state to be present. The keys should match up with what ``save_checkpoint`` + # used to create these intermediate checkpoints + if self._resume_from_checkpoint: + if training.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. 
+ """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + self._compile = cfg.get("compile", False) + + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=self._enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + base_model_state_dict=checkpoint_dict[training.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[training.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + quantizer_cfg=cfg.get("quantizer", None), + ) + self._tokenizer = config.instantiate(cfg.tokenizer) + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._is_rank_zero: + log.info("Loss is initialized.") + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after all of these are setup + collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft") + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + collate_fn=collate_name, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. 
This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _convert_model_to_qat(self, model: nn.Module, quantizer_cfg: DictConfig): + """ + Convert the model to support quantization-aware training during fine-tuning. 
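+        All ``LoRALinear`` layers in the model are swapped with QAT-enabled
+        equivalents via ``swap_lora_linear_with_qat``, using the activation and
+        weight fake quantization configs exposed by the instantiated quantizer.
+        ``DoRALinear`` layers are not supported and will raise a ``ValueError``.
+
+        Args:
+            model (nn.Module): the LoRA model to convert in place
+            quantizer_cfg (DictConfig): config for the QAT quantizer, e.g. pointing to
+                ``torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer``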
+ """ + for name, child in model.named_modules(): + if isinstance(child, DoRALinear): + raise ValueError("QAT is currently not compatible with DoRA") + quantizer = config.instantiate(quantizer_cfg) + quantizer.precision = self._dtype + quantizer_mode = training.quantization.get_quantizer_mode(quantizer) + if "qat" not in quantizer_mode: + raise ValueError( + "Quantizer mode '%s' is not supported for finetuning" % quantizer_mode + ) + activation_config = quantizer.get_activation_fake_quantize_config() + weight_config = quantizer.get_weight_fake_quantize_config() + swap_lora_linear_with_qat(model, activation_config, weight_config) + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + base_model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + quantizer_cfg: Optional[DictConfig] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module` + """ + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + + if self._is_rank_zero: + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." 
+ ) + init_start = time.perf_counter() + + if quantizer_cfg is None: + raise ValueError("Quantizer must be specified for QAT + LoRA finetuning") + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + self._convert_model_to_qat(model, quantizer_cfg) + + set_trainable_params(model, get_adapter_params(model)) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + if lora_weights_state_dict: + lora_missing, lora_unexpected = training.load_from_full_model_state_dict( + model, + lora_weights_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + else: + lora_missing, lora_unexpected = None, None + + # Initialize LoRA params and RoPE buffers + with training.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if fsdp_cpu_offload else self._device + for m in model.modules(): + if ( + isinstance(m, LoRALinear) or isinstance(m, DoRALinear) + ) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.lora_a.to_empty(device=lora_device) + m.lora_b.to_empty(device=lora_device) + m.initialize_parameters() + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + base_missing, base_unexpected = training.load_from_full_model_state_dict( + model, + base_model_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + # activation offloading + self.activations_handling_ctx = training.get_act_offloading_ctx_manager( + model, enable_activation_offloading + ) + + # log + if self._is_rank_zero: + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + if self._is_rank_zero: + 
log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + collate_fn: str, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = training.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + # Instantiate collate_fn + if "left_pad_sequence" in collate_fn: + raise RuntimeError("left_pad_sequence collator is only for inference.") + collate_fn = _get_component_from_path(collate_fn) + + sampler = DistributedSampler( + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ) + + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + # dropping last avoids shape issues with compile + flex attention + drop_last=True, + collate_fn=( + partial( + collate_fn, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else padded_collate_packed + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint( + self, + epoch: int, + ) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + Checkpointer will save the merged weights, adapter weights and recipe state in + different checkpoint files. To correctly resume from training, the adapter weights + and recipe state must be provided along with the base model weights. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + if self._is_rank_zero: + log.info( + "Saving checkpoint. This may take some time. Retrieving full model state dict..." 
+ ) + start = time.perf_counter() + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + state_dict = self._model.state_dict() + if self._save_adapter_weights_only: + state_dict = get_adapter_state_dict(state_dict, device=None) + + cpu_state_dict = training.gather_cpu_state_dict( + state_dict, + self._is_rank_zero, + device=self._device, + ) + if self._is_rank_zero: + log.info( + f"Getting full model state dict took {time.perf_counter() - start:.2f} secs" + ) + + if intermediate_checkpoint: + if self._is_rank_zero: + log.info("Retrieving optimizer state dict...") + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + if self._is_rank_zero: + log.info( + f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs" + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + start = time.perf_counter() + + if self._save_adapter_weights_only: + adapter_state_dict = cpu_state_dict + else: + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_state_dict = get_adapter_state_dict(cpu_state_dict) + + # merge the adapter weights and base weights to create the model checkpoint + merged_state_dict = get_merged_lora_ckpt( + cpu_state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config}) + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs") + + torch.distributed.barrier() + + def train(self) -> None: + """ + The core training loop. 
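+        For each batch, the loss is scaled by the number of unmasked tokens in that
+        batch; token counts and losses are then all-reduced across ranks and the
+        gradients are rescaled by the global token count, so the logged loss matches
+        the quantity being optimized. The optimizer steps once every
+        ``gradient_accumulation_steps`` batches.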
+ """ + # clean up before training begins + training.cleanup_before_training() + + world_size, rank = training.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + + with self.activations_handling_ctx: + logits = self._model(**batch) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. 
+ labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_fn(logits, labels) * current_num_tokens + + # free logits otherwise it peaks backward memory + del logits + + running_loss += current_loss + current_loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + # Get total number of tokens across all ranks to normalize gradients + torch.distributed.all_reduce(num_tokens) + # This will ensure that the logged loss matches what we're optimizing + torch.distributed.all_reduce(running_loss) + # Manually scale the gradients from unnormalized loss by total # of tokens + training.scale_grads(self._model, 1 / num_tokens) + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() / num_tokens + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens + / (time_per_step * world_size), + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + self._profiler.stop() + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. 
This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + config.log_config(recipe_name="QATLoRAFinetuneRecipeDistributed", cfg=cfg) + + recipe = QATLoRAFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/tests/recipes/test_qat_lora_finetune_distributed.py b/tests/recipes/test_qat_lora_finetune_distributed.py new file mode 100644 index 0000000000..5be3a2379a --- /dev/null +++ b/tests/recipes/test_qat_lora_finetune_distributed.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from omegaconf import OmegaConf +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) +from torchtune import config +from torchtune.training.quantization import _torchao_0_7_supported + + +class TestQATLoRAFinetuneDistributedRecipe: + def _get_test_config_overrides(self): + return [ + "dataset.train_on_input=False", + "seed=9", + "epochs=2", + "dtype=fp32", + "max_steps_per_epoch=2", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "compile=False", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama3": [11.9325, 11.9325, 11.9325, 11.9369], + } + return loss_values_map[model_type] + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + @pytest.mark.parametrize( + "micro_batch_size, gradient_accumulation_steps, should_compile", + [(4, 1, True), (1, 4, False)], + ) + @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+") + def test_loss( + self, + micro_batch_size, + gradient_accumulation_steps, + should_compile, + tmpdir, + monkeypatch, + ): + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed + --config llama3/8B_qat_lora \ + batch_size={micro_batch_size} \ + gradient_accumulation_steps={gradient_accumulation_steps} \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + metric_logger.filename={log_file} \ + tokenizer.path=/tmp/test-artifacts/tokenizer.model \ + tokenizer.prompt_template=null \ + compile={should_compile} \ + enable_activation_checkpointing=False \ + enable_activation_offloading=False \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + + cmd = cmd + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + loss_values = get_loss_values_from_metric_logger(log_file) + expected_loss_values = self._fetch_expected_loss_values("llama3") + torch.testing.assert_close( + loss_values, 
expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + @pytest.mark.parametrize( + "config, model_type, ckpt_type, save_adapter_weights_only", + [ + ("llama3/8B_qat_lora", "llama3", "tune", False), + ], + ) + @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+") + def test_training_state_on_resume( + self, + config, + model_type, + ckpt_type, + tmpdir, + monkeypatch, + save_adapter_weights_only, + ): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + expected_loss_values = self._fetch_expected_loss_values(model_type) + + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \ + --config {config} \ + batch_size=4 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + save_adapter_weights_only={save_adapter_weights_only} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + + cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \ + --config {config} \ + batch_size=4 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + resume_from_checkpoint=True \ + metric_logger.filename={log_file} \ + enable_activation_checkpointing=True \ + enable_activation_offloading=True \ + """.split() + + cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + expected_loss_values = self._fetch_expected_loss_values(model_type)[2:] + + loss_values = get_loss_values_from_metric_logger(log_file) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @pytest.mark.parametrize( + "recipe_config, model_type, ckpt_type", + [ + ("llama3/8B_qat_lora", "llama3", "tune"), + ], + ) + 
@gpu_test(gpu_count=2) + @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+") + def test_save_and_load_merged_weights( + self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch + ): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \ + --config {recipe_config} \ + batch_size=4 \ + gradient_accumulation_steps=1 \ + output_dir={tmpdir} \ + model=torchtune.models.lora_small_test_model \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + enable_activation_checkpointing=True \ + enable_activation_offloading=True \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + + cmd = cmd + self._get_test_config_overrides() + model_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Next load both the merged weights in a base model + # and the base model weights + trained adapter weights in the LoRA model + # The results of calling forward on dummy inputs should be the same. + inputs = torch.randint(low=0, high=32_000, size=(2, 100)) + + # Build LoRA model for loading base + adapter weights separately + lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) + + # Build base model for loading merged weights + base_config = MODEL_TEST_CONFIGS[model_type] + model = config.instantiate(OmegaConf.from_dotlist(base_config).model) + + # Load base model and trained adapter weights into LoRA model and call fwd + with open(f"{tmpdir}/adapter_1.pt", "rb") as f: + lora_sd = torch.load(f, weights_only=True) + with open(ckpt_path, "rb") as f: + base_model_sd = torch.load(f, weights_only=True) + lora_model.load_state_dict(lora_sd, strict=False) + lora_model.load_state_dict(base_model_sd, strict=False) + baseline_out = lora_model(inputs) + + # Load merged final ckpt directly into model and call fwd + with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: + sd = torch.load(f, weights_only=True) + model.load_state_dict(sd) + merged_ckpt_out = model(inputs) + + torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py index 80d2b2d767..ff03b1d3c4 100644 --- a/tests/torchtune/modules/peft/test_lora.py +++ b/tests/torchtune/modules/peft/test_lora.py @@ -14,14 +14,17 @@ from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4 from torchtune import training from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook -from torchtune.modules.peft import LoRALinear +from torchtune.modules.peft import LoRALinear, QATLoRALinear +from torchtune.training.quantization import _torchao_0_7_supported from torchtune.training.seed import set_seed + RANK = 4 ALPHA = 1.0 BSZ = 2 SEQ_LEN = 32 EXPECTED_VAL = 1.1252 +QAT_EXPECTED_VAL = 0.6291 @pytest.fixture(autouse=True) @@ -232,3 +235,12 @@ def test_quantized_state_dict(self, dtype): assert torch.allclose( lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data ) + + @pytest.mark.skipif(not _torchao_0_7_supported, 
reason="needs torchao 0.7+") + def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None: + lora_linear = lora_linear(use_bias=True, dtype=torch.float32) + qat_lora_linear = QATLoRALinear.from_lora_linear(lora_linear) + expected = torch.tensor(QAT_EXPECTED_VAL) + actual = qat_lora_linear(inputs) + assert actual.shape == (BSZ, SEQ_LEN, out_dim) + torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 5bbf860482..daf84be1b7 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -456,6 +456,17 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="qat_lora_finetune_distributed", + file_path="qat_lora_finetune_distributed.py", + configs=[ + Config(name="llama3/8B_qat_lora", file_path="llama3/8B_qat_lora.yaml"), + Config(name="llama3_1/8B_qat_lora", file_path="llama3_1/8B_qat_lora.yaml"), + Config(name="llama3_2/1B_qat_lora", file_path="llama3_2/1B_qat_lora.yaml"), + Config(name="llama3_2/3B_qat_lora", file_path="llama3_2/3B_qat_lora.yaml"), + ], + supports_distributed=True, + ), Recipe( name="knowledge_distillation_single_device", file_path="knowledge_distillation_single_device.py", diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py index 165559df9c..2959bc3bb6 100644 --- a/torchtune/modules/peft/__init__.py +++ b/torchtune/modules/peft/__init__.py @@ -17,13 +17,14 @@ validate_missing_and_unexpected_for_lora, ) from .dora import DoRALinear -from .lora import LoRALinear +from .lora import LoRALinear, QATLoRALinear __all__ = [ + "AdapterModule", "DoRALinear", "LoRALinear", - "AdapterModule", + "QATLoRALinear", "get_adapter_params", "set_trainable_params", "validate_missing_and_unexpected_for_lora", diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py index 138dd0c5ee..e7397d0f07 100644 --- a/torchtune/modules/peft/lora.py +++ b/torchtune/modules/peft/lora.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import math -from typing import List +from typing import List, Optional import torch import torch.nn.functional as F @@ -131,6 +131,165 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return out + lora_out +class QATLoRALinear(LoRALinear): + """ + LoRA linear layer with quantization-aware training (QAT) applied to the + activations and/or weights before the low rank adapters. + + QAT leverages fake quantization to simulate the quantization numerics during + training without actually casting the data to lower precision. This class + combines LoRA with QAT to improve the final quantized accuracy during inference + while reducing the memory required during training. + + Args: + in_dim (int): input dimension + out_dim (int): output dimension + rank (int): rank of the low-rank approximation + alpha (float): scaling factor for the low-rank approximation + dropout (float): dropout probability. 
Default: 0.0
+        activation_qat_config (Optional[FakeQuantizeConfig]): config for specifying
+            how input activations will be fake quantized, defaults to None
+        weight_qat_config (Optional[FakeQuantizeConfig]): config for specifying
+            how weights will be fake quantized, defaults to None
+
+    Raises:
+        ValueError: If `in_dim` is not divisible by weight `group_size`
+
+    Example usage::
+
+        activation_qat_config = FakeQuantizeConfig(
+            dtype=torch.int8,
+            granularity="per_token",
+            is_symmetric=False,
+        )
+        weight_qat_config = FakeQuantizeConfig(
+            dtype=torch.int4,
+            group_size=8,
+            is_symmetric=True,
+        )
+        qat_lora_linear = QATLoRALinear(
+            in_dim=512,
+            out_dim=1024,
+            rank=8,
+            alpha=16,
+            dropout=0.0,
+            activation_qat_config=activation_qat_config,
+            weight_qat_config=weight_qat_config,
+        )
+        qat_lora_linear(torch.randn(512))
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        rank: int,
+        alpha: float,
+        dropout: float = 0.0,
+        # fake quantize configs
+        # TODO: make the types Optional[FakeQuantizeConfig] once we
+        # support torchao 0.7+ by default
+        activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+        weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+    ):
+        super().__init__(
+            in_dim,
+            out_dim,
+            rank,
+            alpha,
+            dropout,
+            use_bias=False,
+            quantize_base=False,
+        )
+
+        try:
+            from torchao.quantization.qat.api import FakeQuantizeConfig
+            from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+        except ImportError as err:
+            raise ValueError(
+                "QATLoRALinear is only compatible with torchao 0.7+"
+            ) from err
+
+        # initialize activation fake quantizer
+        if activation_qat_config is not None:
+            assert isinstance(activation_qat_config, FakeQuantizeConfig)
+            self.activation_fake_quantizer = FakeQuantizer(activation_qat_config)
+        else:
+            self.activation_fake_quantizer = nn.Identity()
+
+        # initialize weight fake quantizer
+        if weight_qat_config is not None:
+            assert isinstance(weight_qat_config, FakeQuantizeConfig)
+            group_size = weight_qat_config.group_size
+            if group_size is not None and in_dim % group_size != 0:
+                raise ValueError(
+                    "in_dim (%s) must be divisible by group_size (%s)"
+                    % (in_dim, group_size)
+                )
+            self.weight_fake_quantizer = FakeQuantizer(weight_qat_config)
+        else:
+            self.weight_fake_quantizer = nn.Identity()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): input tensor with shape ``(..., in_dim)``
+
+        Returns:
+            torch.Tensor: output tensor with shape ``(..., out_dim)``
+
+        """
+        _x = self.activation_fake_quantizer(x)
+        w = self.weight_fake_quantizer(self.weight)
+        out = F.linear(_x, w)
+        if self.disabled:
+            return out
+        lora_out = self.lora_a(self.dropout(x))
+        lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+        return out + lora_out
+
+    @classmethod
+    def from_lora_linear(
+        cls,
+        lora_linear: LoRALinear,
+        # TODO: make the types Optional[FakeQuantizeConfig] once we
+        # support torchao 0.7+ by default
+        activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+        weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+    ) -> "QATLoRALinear":
+        """
+        Create a `QATLoRALinear` from an existing `LoRALinear`,
+        preserving the weights and adapters.
+        """
+        if lora_linear.bias is not None:
+            raise ValueError("Bias is not supported in QAT + LoRA yet")
+        if lora_linear._quantize_base:
+            raise ValueError("quantize_base is not compatible with QAT + LoRA")
+        if isinstance(lora_linear.dropout, nn.Dropout):
+            dropout = lora_linear.dropout.p
+        else:
+            dropout = 0.0
+        new_linear = cls(
+            lora_linear.in_dim,
+            lora_linear.out_dim,
+            lora_linear.rank,
+            lora_linear.alpha,
+            dropout,
+            activation_qat_config,
+            weight_qat_config,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if lora_linear.weight.device != torch.device("meta"):
+            new_linear.weight = lora_linear.weight
+        if lora_linear.lora_a.weight.device != torch.device("meta"):
+            new_linear.lora_a.weight = lora_linear.lora_a.weight
+        if lora_linear.lora_b.weight.device != torch.device("meta"):
+            new_linear.lora_b.weight = lora_linear.lora_b.weight
+        return new_linear
+
+
 def _lora_a_init_params(x: nn.Linear) -> None:
     """
     Initialize LoRA A weight to Kaiming uniform.
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 7ff9315f41..4e21cb4936 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -7,6 +7,10 @@
 from typing import Callable, Optional
 from warnings import warn
 
+from torch import nn
+from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
+
+
 try:
     # torchao 0.7+
     from torchao.dtypes import TensorCoreTiledLayout
@@ -55,6 +59,12 @@
 ]
 
 
+_torchao_0_7_supported = True
+try:
+    from torchao.quantization import qat  # noqa: F401
+except ImportError:
+    _torchao_0_7_supported = False
+
 _quantizer_to_mode = {}
 _quantizer_mode_to_disable_fake_quant = {}
 _quantizer_mode_to_enable_fake_quant = {}
@@ -185,3 +195,45 @@ def _get_enable_fake_quant(quantizer_mode: str) -> Callable:
     If the quantizer is not recognized as a known QAT quantizer, return None.
     """
     return _quantizer_mode_to_enable_fake_quant.get(quantizer_mode, None)
+
+
+def swap_lora_linear_with_qat(
+    module: nn.Module,
+    # TODO: make the types Optional[FakeQuantizeConfig] once we
+    # support torchao 0.7+ by default
+    activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+    weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+) -> None:
+    """
+    Swap all `LoRALinear` in the model with `QATLoRALinear`.
+
+    This is used for combining QAT + LoRA during finetuning. The resulting linear layers
+    will apply the following transformation instead:
+
+        x -> fake_quantize(W_frozen) @ fake_quantize(x) + BAx
+
+    Fake quantization here refers to simulating the quantization numerics without actual
+    dtype casting, with the goal of providing improved accuracies when the model is
+    ultimately quantized after finetuning.
+
+    Args:
+        module (nn.Module): The model to swap linear layers on
+        activation_qat_config (Optional[FakeQuantizeConfig]): The config for specifying
+            how to fake quantize input activations in the base linear layer
+        weight_qat_config (Optional[FakeQuantizeConfig]): The config for specifying
+            how to fake quantize base linear weights
+    """
+    for name, child in module.named_children():
+        if isinstance(child, LoRALinear):
+            new_linear = QATLoRALinear.from_lora_linear(
+                child,
+                activation_qat_config,
+                weight_qat_config,
+            )
+            setattr(module, name, new_linear)
+        else:
+            swap_lora_linear_with_qat(
+                child,
+                activation_qat_config,
+                weight_qat_config,
+            )
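
Editor's note (illustrative, not part of the patch): the sketch below shows how `swap_lora_linear_with_qat` can be exercised together with torchao's `FakeQuantizeConfig`, using the same int8 per-token activation / int4 group-wise weight ("8da4w") scheme referenced elsewhere in this PR. The toy `nn.Sequential` wrapper, the layer dimensions, and the variable names are assumptions for demonstration only; it requires torchao 0.7+.

```
# Minimal sketch under assumed setup: wrap one LoRALinear and swap it for QATLoRALinear.
import torch
from torch import nn

from torchao.quantization.qat.api import FakeQuantizeConfig
from torchtune.modules.peft import LoRALinear
from torchtune.training.quantization import swap_lora_linear_with_qat

# Stand-in for a full torchtune LoRA model: a single LoRA-adapted linear layer
model = nn.Sequential(LoRALinear(in_dim=64, out_dim=128, rank=8, alpha=16))

# int8 per-token asymmetric fake quantization for activations,
# int4 symmetric group-wise (group_size=32) fake quantization for weights
activation_config = FakeQuantizeConfig(
    dtype=torch.int8, granularity="per_token", is_symmetric=False
)
weight_config = FakeQuantizeConfig(dtype=torch.int4, group_size=32, is_symmetric=True)

swap_lora_linear_with_qat(model, activation_config, weight_config)

x = torch.randn(2, 16, 64)
out = model(x)  # frozen base weight and activations are fake-quantized; LoRA path is untouched
print(out.shape)  # torch.Size([2, 16, 128])
```

In the distributed recipe the swapped modules may still live on the meta device, which is why `from_lora_linear` above only copies weights when they are materialized.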