diff --git a/README.md b/README.md
index 8caa38c890..3e494dd757 100644
--- a/README.md
+++ b/README.md
@@ -67,7 +67,8 @@ torchtune provides the following finetuning recipes for training on one or more
 | LoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py) <br> [lora_finetune_distributed](recipes/lora_finetune_distributed.py) | [Qwen2 0.5B single-device](recipes/configs/qwen2/0.5B_lora_single_device.yaml) <br> [Gemma 7B distributed](recipes/configs/gemma/7B_lora.yaml)
 | QLoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py) <br> [lora_finetune_distributed](recipes/lora_finetune_distributed.py) | [Phi3 Mini single-device](recipes/configs/phi3/mini_qlora_single_device.yaml) <br> [Llama 3.1 405B distributed](recipes/configs/llama3_1/405B_qlora.yaml)
 | DoRA/QDoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py) <br> [lora_finetune_distributed](recipes/lora_finetune_distributed.py) | [Llama3 8B QDoRA single-device](recipes/configs/llama3/8B_qdora_single_device.yaml) <br> [Llama3 8B DoRA distributed](recipes/configs/llama3/8B_dora.yaml)
-| Quantization-Aware Training | 4-8 | [qat_distributed](recipes/qat_distributed.py) | [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml)
+| Quantization-Aware Training | 2-8 | [qat_distributed](recipes/qat_distributed.py) | [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml)
+| Quantization-Aware Training and LoRA Finetuning | 2-8 | [qat_lora_finetune_distributed](recipes/qat_lora_finetune_distributed.py) | [Llama3 8B QAT+LoRA](recipes/configs/llama3/8B_qat_lora.yaml)
 | Direct Preference Optimization | 1-8 | [lora_dpo_single_device](recipes/lora_dpo_single_device.py) <br> [lora_dpo_distributed](recipes/lora_dpo_distributed.py) | [Llama2 7B single-device](recipes/configs/llama2/7B_lora_dpo_single_device.yaml) <br> [Llama2 7B distributed](recipes/configs/llama2/7B_lora_dpo.yaml)
 | Proximal Policy Optimization | 1 | [ppo_full_finetune_single_device](recipes/ppo_full_finetune_single_device.py) | [Mistral 7B](recipes/configs/mistral/7B_full_ppo_low_memory.yaml)
 | Knowledge Distillation | 1 | [knowledge_distillation_single_device](recipes/knowledge_distillation_single_device.py) | [Qwen2 1.5B -> 0.5B](recipes/configs/qwen2/knowledge_distillation_single_device.yaml)
diff --git a/recipes/configs/llama3/8B_qat_lora.yaml b/recipes/configs/llama3/8B_qat_lora.yaml
new file mode 100644
index 0000000000..2104e6268e
--- /dev/null
+++ b/recipes/configs/llama3/8B_qat_lora.yaml
@@ -0,0 +1,113 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3 8B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token <HF_TOKEN>
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3.lora_llama3_8b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelMetaCheckpointer
+ checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
+ checkpoint_files: [
+ consolidated.00.pth
+ ]
+ recipe_checkpoint: null
+ output_dir: /tmp/Meta-Llama-3-8B-Instruct/
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 2
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase virtual batch size
+compile: False # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/qat_lora_finetune_output
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
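Note that QAT only simulates quantization while finetuning; the merged checkpoint written with this config is still a bf16 model that has to be quantized in a separate step. Below is a minimal sketch of what a companion config for torchtune's `quantize` recipe could look like (launched with `tune run quantize --config <your_config>`). The checkpoint path and file name are assumptions, and the non-QAT `Int8DynActInt4WeightQuantizer` is assumed to be available under `torchtune.training.quantization`, as in the stock quantization config; `groupsize` must match the QAT section above.

```yaml
# Hypothetical post-QAT quantization config (paths and file names are assumptions)
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/
  checkpoint_files: [meta_model_0.pt]
  output_dir: /tmp/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3

device: cuda
dtype: bf16
seed: 1234

quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
  groupsize: 256  # must match the groupsize used for QAT during finetuning
```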
diff --git a/recipes/configs/llama3_1/8B_qat_lora.yaml b/recipes/configs/llama3_1/8B_qat_lora.yaml
new file mode 100644
index 0000000000..531d31fee9
--- /dev/null
+++ b/recipes/configs/llama3_1/8B_qat_lora.yaml
@@ -0,0 +1,116 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3.1 8B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_1.lora_llama3_1_8b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ apply_lora_to_output: False
+ lora_rank: 8 # higher increases accuracy and memory
+ lora_alpha: 16 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00004.safetensors,
+ model-00002-of-00004.safetensors,
+ model-00003-of-00004.safetensors,
+ model-00004-of-00004.safetensors
+ ]
+ recipe_checkpoint: null
+ output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
+ model_type: LLAMA3
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 2
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase virtual batch size
+compile: False # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/qat_lora_finetune_output
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
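As with the other distributed recipes, the memory-related options in this config (and the FSDP flags read by `qat_lora_finetune_distributed.py`) can be toggled from the command line instead of editing the file. A hypothetical invocation, using only flag names that appear in this config or in the recipe:

```bash
tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora \
  enable_activation_checkpointing=True \
  enable_activation_offloading=True \
  fsdp_cpu_offload=True
```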
diff --git a/recipes/configs/llama3_2/1B_qat_lora.yaml b/recipes/configs/llama3_2/1B_qat_lora.yaml
new file mode 100644
index 0000000000..8d68ef632c
--- /dev/null
+++ b/recipes/configs/llama3_2/1B_qat_lora.yaml
@@ -0,0 +1,112 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3.2 1B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_2.lora_llama3_2_1b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
+ checkpoint_files: [
+ model.safetensors
+ ]
+ recipe_checkpoint: null
+ output_dir: /tmp/Llama-3.2-1B-Instruct/
+ model_type: LLAMA3_2
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase virtual batch size
+compile: False # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/qat_lora_finetune_output
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
diff --git a/recipes/configs/llama3_2/3B_qat_lora.yaml b/recipes/configs/llama3_2/3B_qat_lora.yaml
new file mode 100644
index 0000000000..2fac4e0fa1
--- /dev/null
+++ b/recipes/configs/llama3_2/3B_qat_lora.yaml
@@ -0,0 +1,113 @@
+# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
+# using a Llama3.2 3B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+# tune download meta-llama/Llama-3.2-3B-Instruct --output-dir /tmp/Llama-3.2-3B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/3B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+
+# Tokenizer
+tokenizer:
+ _component_: torchtune.models.llama3.llama3_tokenizer
+ path: /tmp/Llama-3.2-3B-Instruct/original/tokenizer.model
+ max_seq_len: null
+
+# Model Arguments
+model:
+ _component_: torchtune.models.llama3_2.lora_llama3_2_3b
+ lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+ apply_lora_to_mlp: True
+ lora_rank: 64 # higher increases accuracy and memory
+ lora_alpha: 128 # usually alpha=2*rank
+ lora_dropout: 0.0
+
+checkpointer:
+ _component_: torchtune.training.FullModelHFCheckpointer
+ checkpoint_dir: /tmp/Llama-3.2-3B-Instruct/
+ checkpoint_files: [
+ model-00001-of-00002.safetensors,
+ model-00002-of-00002.safetensors,
+ ]
+ recipe_checkpoint: null
+ output_dir: /tmp/Llama-3.2-3B-Instruct/
+ model_type: LLAMA3_2
+resume_from_checkpoint: False
+save_adapter_weights_only: False
+
+# Dataset and Sampler
+dataset:
+ _component_: torchtune.datasets.alpaca_cleaned_dataset
+ packed: False # True increases speed
+seed: null
+shuffle: True
+batch_size: 4
+
+# Optimizer and Scheduler
+optimizer:
+ _component_: torch.optim.AdamW
+ fused: True
+ weight_decay: 0.01
+ lr: 3e-4
+lr_scheduler:
+ _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+ num_warmup_steps: 100
+
+loss:
+ _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+
+# Training
+epochs: 1
+max_steps_per_epoch: null
+gradient_accumulation_steps: 8 # Use to increase virtual batch size
+compile: False # pytorch compile, set to true for better perf/memory
+
+# Logging
+output_dir: /tmp/qat_lora_finetune_output
+metric_logger:
+ _component_: torchtune.training.metric_logging.DiskLogger
+ log_dir: ${output_dir}
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Environment
+device: cuda
+dtype: bf16
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+# Profiler (disabled)
+profiler:
+ _component_: torchtune.training.setup_torch_profiler
+ enabled: False
+
+ #Output directory of trace artifacts
+ output_dir: ${output_dir}/profiling_outputs
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: True
+ cuda: True
+
+ #trace options passed to `torch.profiler.profile`
+ profile_memory: False
+ with_stack: False
+ record_shapes: True
+ with_flops: False
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: 5
+ warmup_steps: 3
+ active_steps: 2
+ num_cycles: 1
+
+# QAT arguments
+quantizer:
+ _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+ groupsize: 256
diff --git a/recipes/qat_lora_finetune_distributed.py b/recipes/qat_lora_finetune_distributed.py
new file mode 100644
index 0000000000..f9b1fc991f
--- /dev/null
+++ b/recipes/qat_lora_finetune_distributed.py
@@ -0,0 +1,972 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import time
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from omegaconf import DictConfig, ListConfig
+
+from torch import nn
+from torch.distributed import destroy_process_group, init_process_group
+
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import config, modules, training, utils
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
+from torchtune.datasets import ConcatDataset
+from torchtune.modules.peft import (
+ DoRALinear,
+ get_adapter_params,
+ get_adapter_state_dict,
+ get_lora_module_names,
+ get_merged_lora_ckpt,
+ LoRALinear,
+ set_trainable_params,
+ validate_missing_and_unexpected_for_lora,
+)
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.quantization import swap_lora_linear_with_qat
+
+from tqdm import tqdm
+
+log = utils.get_logger("DEBUG")
+
+
+class QATLoRAFinetuneRecipeDistributed(FTRecipeInterface):
+ """
+ Distributed quantization-aware training (QAT) and LoRA finetuning recipe for dense transformer-based
+ LLMs such as Llama2. This recipe supports distributed training and can be run on a single node (2 to
+ 8 GPUs). Only compatible with torchao 0.7+.
+
+ Features:
+ - Quantization-aware training (QAT). Perform fake quantization on weights and/or activations
+ during finetuning, with the goal of ultimately producing a quantized model with minimal
+ accuracy degradation. This recipe produces an unquantized model in the original dtype,
+ which can then be quantized separately.
+
+ - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+ is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+ done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+ ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+ DDP is currently not supported. Training on CPU is not supported.
+
+ - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+ flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+ activations in memory and instead recompute them during the backward pass. This is especially
+ helpful for larger batch sizes when you're memory constrained. But these savings in memory
+ come at the cost of training performance. In most cases training can slow down quite a bit as
+ a result of this activation recomputation.
+
+ - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+ flag. Activation offloading is a technique similar to activations checkpointing that helps
+ reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
+ checkpointing drops the activation in the forward to recompute it later in the backward,
+ activations offloading will drop the activation in the forward to the CPU and bring it
+ back during the backward pass. As always, there is a tradeoff--these savings in memory can
+ come at the cost of training performance and CPU resources. To recover some runtime cost,
+ we've added an option to enable offloading on a different stream to permit overlapping with
+ the computation. This option is currently only available on PyTorch 2.5.0 or later and will be
+ enabled by default if an acceptable torch version is found. Activation offloading can be used in
+ conjunction with activation checkpointing.
+
+ - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
+ flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
+ most cases this should halve the memory footprint of full precision (fp32) training, without
+ loss in model quality (will depend on the model, training data and other settings). For
+ GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
+ precision are currently not supported.
+
+ - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
+ controlled using the ``gradient_accumulation_steps`` flag.
+
+ Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
+
+ For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
+ total batch size of 64.
+
+ Gradient accumulation is especially useful when you are memory constrained. In this case,
+ accumulating gradients might give you better training speed than enabling activation
+ checkpointing.
+
+ - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
+ training. Currently we checkpoint both the adapter weights (trainable params only) and the
+ complete merged weights (adapter weights added back to the base model). For more details
+ please take a look at our LoRA tutorial
+ (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html).
+
+ Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are
+ only saved at the end of a given epoch and used in case of resuming training. Resuming
+ training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
+ currently not supported.
+
+ For more details on the checkpointer, please take a look at
+ our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html).
+
+ - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
+
+ - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+ ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+ ``clip_grad_norm='inf'``.
+
+ For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
+ has example commands for how to kick-off training.
+
+ Args:
+ cfg (DictConfig): OmegaConf object parsed from yaml file
+
+ Raises:
+ ValueError: If ``dtype`` is set to fp16.
+ ValueError: If world_size is 1
+ RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+ RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+ RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+ RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
+ """
+
+ def __init__(self, cfg: DictConfig) -> None:
+ try:
+ from torchao.quantization import qat # noqa: F401
+ except ImportError as err:
+ raise ValueError(
+ "qat_lora_finetune_distributed is only compatible with torchao 0.7+"
+ ) from err
+
+ self._device = utils.get_device(device=cfg.device)
+ self._dtype = training.get_dtype(cfg.dtype, device=self._device)
+
+ if self._dtype == torch.float16:
+ raise ValueError(
+ "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+ )
+
+ _, rank = training.get_world_size_and_rank()
+
+ # _is_rank_zero is used primarily for logging. In the future, the logger
+ # should directly take care of this
+ self._is_rank_zero = rank == 0
+
+ # logging attributes
+ self._output_dir = cfg.output_dir
+ self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+ self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+
+ if self._log_peak_memory_stats and self._device.type != "cuda":
+ log.info(
+ "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+ )
+ self._log_peak_memory_stats = False
+
+ # These attributes constitute the recipe state and are updated by ``load_checkpoint``
+ # when ``resume_from_checkpoint`` is ``True``
+ self.seed = training.set_seed(seed=cfg.seed)
+ self.epochs_run = 0
+ self.total_epochs = cfg.epochs
+ self.max_steps_per_epoch = cfg.max_steps_per_epoch
+ self.global_step = 0
+ self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+ self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
+ self._resume_from_checkpoint = cfg.resume_from_checkpoint
+ self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+
+ # activation checkpointing/offloading
+ self._enable_activation_checkpointing = cfg.get(
+ "enable_activation_checkpointing", False
+ )
+ self._enable_activation_offloading = cfg.get(
+ "enable_activation_offloading", False
+ )
+ if self._enable_activation_offloading:
+ if self._device.type != "cuda":
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when training on CUDA"
+ )
+ if not self._enable_activation_checkpointing:
+ raise RuntimeError(
+ "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+ )
+ elif (
+ self._enable_activation_checkpointing
+ and cfg.checkpointer.model_type != "LLAMA3_VISION"
+ ):
+ utils.log_rank_zero(
+ log,
+ "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+ "Enabling activation offloading should reduce memory further.",
+ )
+
+ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
+ """
+ Extract the checkpoint state from file and validate. This includes the
+ base model weights. If resume_from_checkpoint is True, this also includes
+ the adapter weights and recipe state
+ """
+ self._checkpointer = config.instantiate(
+ cfg_checkpointer,
+ resume_from_checkpoint=self._resume_from_checkpoint,
+ )
+ checkpoint_dict = self._checkpointer.load_checkpoint()
+
+ # When resuming from checkpoint for LoRA, the recipe expects the adapter weights
+ # and recipe state to be present. The keys should match up with what ``save_checkpoint``
+ # used to create these intermediate checkpoints
+ if self._resume_from_checkpoint:
+ if training.ADAPTER_KEY not in checkpoint_dict:
+ raise ValueError(
+ "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
+ )
+ # _update_recipe_state will throw an exception if the recipe state is not correctly loaded
+ # no need to check here
+ self._update_recipe_state(checkpoint_dict)
+ return checkpoint_dict
+
+ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
+ """
+ Updates the recipe state from checkpoint.
+ """
+ try:
+ self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
+
+ # on mismatch, warn the user and prevent the override
+ if self.seed != ckpt_dict[training.SEED_KEY]:
+ warn(
+ message=(
+ "Config value for seed does not match the checkpoint value, "
+ f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
+ )
+ )
+ self.seed = ckpt_dict[training.SEED_KEY]
+ if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
+ warn(
+ message=(
+ "Config value for max_steps_per_epoch does not match the checkpoint value, "
+ f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
+ )
+ )
+ self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
+
+ # on mismatch, warn the user but allow the override
+ if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
+ warn(
+ message=(
+ "Config value for total_epochs does not match the checkpoint value, "
+ f"using the config value: {self.total_epochs}"
+ )
+ )
+
+ except KeyError as e:
+ raise KeyError(
+ "Checkpoint does not contain the required keys needed for updating recipe state. "
+ "Are you sure you passed in the right recipe checkpoint?"
+ ) from e
+
+ def setup(self, cfg: DictConfig) -> None:
+ """
+ Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True),
+ model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
+ """
+ if self._is_rank_zero:
+ self._metric_logger = config.instantiate(cfg.metric_logger)
+
+ # log config with parameter override
+ self._metric_logger.log_config(cfg)
+
+ checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+ self._compile = cfg.get("compile", False)
+
+ self._model = self._setup_model(
+ cfg_model=cfg.model,
+ enable_activation_checkpointing=self._enable_activation_checkpointing,
+ enable_activation_offloading=self._enable_activation_offloading,
+ fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
+ reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
+ base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
+ lora_weights_state_dict=(
+ checkpoint_dict[training.ADAPTER_KEY]
+ if self._resume_from_checkpoint
+ else None
+ ),
+ quantizer_cfg=cfg.get("quantizer", None),
+ )
+ self._tokenizer = config.instantiate(cfg.tokenizer)
+
+ self._optimizer = self._setup_optimizer(
+ cfg_optimizer=cfg.optimizer,
+ opt_state_dict=(
+ checkpoint_dict[training.OPT_KEY]
+ if self._resume_from_checkpoint
+ else None
+ ),
+ )
+
+ # initialize loss
+ self._loss_fn = config.instantiate(cfg.loss)
+
+ if self._compile:
+ training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
+ if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
+ # set num_output_chunks for model
+ self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+ if self._is_rank_zero:
+ log.info("Loss is initialized.")
+
+ # sampler and dataloader depend on the tokenizer and loss_fn and should be
+ # setup after all of these are setup
+ collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
+ self._sampler, self._dataloader = self._setup_data(
+ cfg_dataset=cfg.dataset,
+ shuffle=cfg.shuffle,
+ batch_size=cfg.batch_size,
+ collate_fn=collate_name,
+ )
+
+ # Finally update the recipe state which can only be correctly set after all of the
+ # other components have been initialized and updated.
+
+ # Number of training steps in each epoch depends on the number of batches produced
+ # by the dataloader and the max_steps_per_epoch param set by the user and is used
+ # for logging and tracking training state. This should be computed after the dataloader
+ # has been setup
+ self._steps_per_epoch = (
+ len(self._dataloader) // self._gradient_accumulation_steps
+ )
+ if (
+ self.max_steps_per_epoch is not None
+ and self.max_steps_per_epoch < self._steps_per_epoch
+ ):
+ self._steps_per_epoch = self.max_steps_per_epoch
+ self.global_step = self.epochs_run * self._steps_per_epoch
+
+ # Learning rate scheduler can only be set up after number of steps
+ # has been computed
+ self._lr_scheduler = self._setup_lr_scheduler(
+ cfg_lr_scheduler=cfg.lr_scheduler,
+ num_training_steps=self.total_epochs * self._steps_per_epoch,
+ last_epoch=self.global_step - 1,
+ )
+
+ # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+ # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+ self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+ # Used to ignore labels for loss computation
+ self.ignore_labels_cache = torch.full(
+ (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
+ )
+
+ def _setup_profiler(
+ self, cfg_profiler: Optional[DictConfig] = None
+ ) -> Union[torch.profiler.profile, DummyProfiler]:
+ """
+ Parses the `profiler` section of top-level `cfg` and sets up profiler
+
+ Args:
+ cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+ `recipe.main`). Default None.
+
+ Returns:
+ profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+ for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such
+ that the instrumented training loop does not need to be changed when profiling is disabled.
+
+ The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+ .. code-block:: yaml
+ profiler:
+ enabled: bool
+
+ #Output directory of trace artifacts
+ output_dir: str
+
+ #`torch.profiler.ProfilerActivity` types to trace
+ cpu: bool
+ cuda: bool
+
+ #Trace options
+ profile_memory: bool
+ with_stack: bool
+ record_shapes: bool
+ with_flops: bool
+
+ # `torch.profiler.schedule` options:
+ # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+ wait_steps: int
+ warmup_steps: int
+ active_steps: int
+ num_cycles: int
+ """
+ # Missing profiler section in config, assume disabled
+ if cfg_profiler is None:
+ cfg_profiler = DictConfig({"enabled": False})
+
+ # Check that component is included and set correctly
+ if cfg_profiler.get("_component_", None) is None:
+ cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+ else:
+ assert (
+ cfg_profiler.get("_component_")
+ == "torchtune.training.setup_torch_profiler"
+ ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+ profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+ if self._is_rank_zero:
+ log.info(f" Profiler config after instantiation: {profiler_cfg}")
+
+ self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+ if profiler_cfg["enabled"]:
+ self.profiler_wait_steps = profiler_cfg["wait_steps"]
+ self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+ self.profiler_active_steps = profiler_cfg["active_steps"]
+
+ return profiler
+
+ def _convert_model_to_qat(self, model: nn.Module, quantizer_cfg: DictConfig):
+ """
+ Convert the model to support quantization-aware training during fine-tuning.
+ """
+ for name, child in model.named_modules():
+ if isinstance(child, DoRALinear):
+ raise ValueError("QAT is currently not compatible with DoRA")
+ quantizer = config.instantiate(quantizer_cfg)
+ quantizer.precision = self._dtype
+ quantizer_mode = training.quantization.get_quantizer_mode(quantizer)
+ if "qat" not in quantizer_mode:
+ raise ValueError(
+ "Quantizer mode '%s' is not supported for finetuning" % quantizer_mode
+ )
+ activation_config = quantizer.get_activation_fake_quantize_config()
+ weight_config = quantizer.get_weight_fake_quantize_config()
+ swap_lora_linear_with_qat(model, activation_config, weight_config)
+
+ def _setup_model(
+ self,
+ cfg_model: DictConfig,
+ enable_activation_checkpointing: bool,
+ enable_activation_offloading: bool,
+ fsdp_cpu_offload: bool,
+ reshard_after_forward: bool,
+ base_model_state_dict: Dict[str, Any],
+ custom_sharded_layers: Optional[List[str]] = None,
+ lora_weights_state_dict: Optional[Dict[str, Any]] = None,
+ quantizer_cfg: Optional[DictConfig] = None,
+ ) -> nn.Module:
+ """
+ Model initialization has some important considerations:
+ a. To minimize GPU peak memory, we initialize the model on meta device with
+ the right dtype
+ b. All ranks call ``load_state_dict`` without spiking CPU RAM, since
+ full state dicts are loaded with ``torch.load(mmap=True)``
+ c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module`
+ """
+
+ self._lora_rank = cfg_model.lora_rank
+ self._lora_alpha = cfg_model.lora_alpha
+ self._lora_attn_modules = list(cfg_model.lora_attn_modules)
+ self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp
+ self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False)
+
+ if self._is_rank_zero:
+ log.info(
+ "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
+ )
+ init_start = time.perf_counter()
+
+ if quantizer_cfg is None:
+ raise ValueError("Quantizer must be specified for QAT + LoRA finetuning")
+
+ with training.set_default_dtype(self._dtype), torch.device("meta"):
+ model = config.instantiate(cfg_model)
+ self._convert_model_to_qat(model, quantizer_cfg)
+
+ set_trainable_params(model, get_adapter_params(model))
+
+ if self._compile:
+ training.compile_model(model, verbose=self._is_rank_zero)
+
+ if enable_activation_checkpointing:
+ training.set_activation_checkpointing(
+ model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+ )
+
+ # For FSDP sharding
+ fsdp_shard_conditions = [
+ partial(
+ training.get_shard_conditions,
+ names_to_match=custom_sharded_layers,
+ )
+ ]
+ training.shard_model(
+ model=model,
+ shard_conditions=fsdp_shard_conditions,
+ cpu_offload=fsdp_cpu_offload,
+ reshard_after_forward=reshard_after_forward,
+ )
+
+ if lora_weights_state_dict:
+ lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
+ model,
+ lora_weights_state_dict,
+ self._device,
+ self._is_rank_zero,
+ cpu_offload=fsdp_cpu_offload,
+ )
+ else:
+ lora_missing, lora_unexpected = None, None
+
+ # Initialize LoRA params and RoPE buffers
+ with training.set_default_dtype(self._dtype), self._device:
+ lora_device = "cpu" if fsdp_cpu_offload else self._device
+ for m in model.modules():
+ if (
+ isinstance(m, LoRALinear) or isinstance(m, DoRALinear)
+ ) and not lora_weights_state_dict:
+ # LoRA params may not be covered in the state dict
+ # when finetuning for the first time
+ m.lora_a.to_empty(device=lora_device)
+ m.lora_b.to_empty(device=lora_device)
+ m.initialize_parameters()
+ # RoPE is not covered in state dict
+ if hasattr(m, "rope_init"):
+ m.rope_init()
+
+ base_missing, base_unexpected = training.load_from_full_model_state_dict(
+ model,
+ base_model_state_dict,
+ self._device,
+ self._is_rank_zero,
+ cpu_offload=fsdp_cpu_offload,
+ )
+ validate_missing_and_unexpected_for_lora(
+ lora_attn_modules=self._lora_attn_modules,
+ apply_lora_to_mlp=self._apply_lora_to_mlp,
+ apply_lora_to_output=self._apply_lora_to_output,
+ base_missing=base_missing,
+ base_unexpected=base_unexpected,
+ lora_missing=lora_missing,
+ lora_unexpected=lora_unexpected,
+ )
+ # Ensure no params and buffers are on meta device
+ training.validate_no_params_on_meta_device(model)
+
+ # activation offloading
+ self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+ model, enable_activation_offloading
+ )
+
+ # log
+ if self._is_rank_zero:
+ log.info(
+ f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
+ )
+ memory_stats = training.get_memory_stats(device=self._device)
+ training.log_memory_stats(memory_stats)
+
+ # synchronize before training begins
+ torch.distributed.barrier()
+
+ return model
+
+ def _setup_optimizer(
+ self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
+ ) -> Optimizer:
+ optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+ if opt_state_dict:
+ training.load_from_full_optimizer_state_dict(
+ optimizer,
+ opt_state_dict,
+ self._device,
+ )
+
+ if self._is_rank_zero:
+ log.info("Optimizer is initialized.")
+ return optimizer
+
+ def _setup_lr_scheduler(
+ self,
+ cfg_lr_scheduler: DictConfig,
+ num_training_steps: int,
+ last_epoch: int,
+ ) -> Optimizer:
+ lr_scheduler = config.instantiate(
+ cfg_lr_scheduler,
+ self._optimizer,
+ num_training_steps=num_training_steps,
+ last_epoch=last_epoch,
+ )
+ if self._is_rank_zero:
+ log.info("Learning rate scheduler is initialized.")
+ return lr_scheduler
+
+ def _setup_data(
+ self,
+ cfg_dataset: DictConfig,
+ shuffle: bool,
+ batch_size: int,
+ collate_fn: str,
+ ) -> Tuple[DistributedSampler, DataLoader]:
+ """
+ All data related setup happens here. Currently this recipe only supports the
+ DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
+ iterable datasets and streaming datasets are not supported.
+ """
+ world_size, rank = training.get_world_size_and_rank()
+
+ if isinstance(cfg_dataset, ListConfig):
+ datasets = [
+ config.instantiate(single_cfg_dataset, self._tokenizer)
+ for single_cfg_dataset in cfg_dataset
+ ]
+ ds = ConcatDataset(datasets=datasets)
+ packed = False
+ else:
+ ds = config.instantiate(cfg_dataset, self._tokenizer)
+ packed = cfg_dataset.get("packed", False)
+
+ # Instantiate collate_fn
+ if "left_pad_sequence" in collate_fn:
+ raise RuntimeError("left_pad_sequence collator is only for inference.")
+ collate_fn = _get_component_from_path(collate_fn)
+
+ sampler = DistributedSampler(
+ ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+ )
+
+ dataloader = DataLoader(
+ dataset=ds,
+ batch_size=batch_size,
+ sampler=sampler,
+ # dropping last avoids shape issues with compile + flex attention
+ drop_last=True,
+ collate_fn=(
+ partial(
+ collate_fn,
+ padding_idx=self._tokenizer.pad_id,
+ ignore_idx=self._loss_fn.ignore_index,
+ )
+ if not packed
+ else padded_collate_packed
+ ),
+ )
+
+ if self._is_rank_zero:
+ log.info("Dataset and Sampler are initialized.")
+
+ return sampler, dataloader
+
+ def save_checkpoint(
+ self,
+ epoch: int,
+ ) -> None:
+ """
+ Checkpoint the state of the recipe. The constructed checkpoint state dict
+ contains the following information:
+ - Merged weights with key MODEL_KEY
+ - Adapter weights with key ADAPTER_KEY
+ - Relevant recipe state if training is not complete
+ - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
+
+ Checkpointer will save the merged weights, adapter weights and recipe state in
+ different checkpoint files. To correctly resume from training, the adapter weights
+ and recipe state must be provided along with the base model weights.
+ """
+ # final dict passed onto the checkpointer
+ checkpoint_dict = {}
+
+ intermediate_checkpoint = epoch + 1 < self.total_epochs
+
+ if self._is_rank_zero:
+ log.info(
+ "Saving checkpoint. This may take some time. Retrieving full model state dict..."
+ )
+ start = time.perf_counter()
+
+ # To prevent GPU memory from spiking during checkpoint save,
+ # we consolidate the full model and optim state dicts on CPU for rank 0
+ state_dict = self._model.state_dict()
+ if self._save_adapter_weights_only:
+ state_dict = get_adapter_state_dict(state_dict, device=None)
+
+ cpu_state_dict = training.gather_cpu_state_dict(
+ state_dict,
+ self._is_rank_zero,
+ device=self._device,
+ )
+ if self._is_rank_zero:
+ log.info(
+ f"Getting full model state dict took {time.perf_counter() - start:.2f} secs"
+ )
+
+ if intermediate_checkpoint:
+ if self._is_rank_zero:
+ log.info("Retrieving optimizer state dict...")
+ opt_state_dict = training.get_full_optimizer_state_dict(
+ self._optimizer,
+ self._is_rank_zero,
+ device=self._device,
+ )
+ if self._is_rank_zero:
+ log.info(
+ f"Getting optimizer state dict took {time.perf_counter() - start:.2f} secs"
+ )
+ else:
+ opt_state_dict = None
+
+ # Now that we have the model and opt state dict, create the actual checkpoint dict
+ # to be sent to the checkpointer and ultimately written to file
+ if self._is_rank_zero:
+ start = time.perf_counter()
+
+ if self._save_adapter_weights_only:
+ adapter_state_dict = cpu_state_dict
+ else:
+ # Filter out the adapter keys and weights from the model state dict. These will
+ # be saved separately
+ adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+ # merge the adapter weights and base weights to create the model checkpoint
+ merged_state_dict = get_merged_lora_ckpt(
+ cpu_state_dict,
+ rank=self._lora_rank,
+ alpha=self._lora_alpha,
+ )
+ checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+ checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
+
+ # if training is in-progress, checkpoint the optimizer state and recipe state
+ # as well.
+ if intermediate_checkpoint:
+ checkpoint_dict.update(
+ {
+ training.OPT_KEY: opt_state_dict,
+ training.SEED_KEY: self.seed,
+ training.EPOCHS_KEY: self.epochs_run,
+ training.TOTAL_EPOCHS_KEY: self.total_epochs,
+ training.MAX_STEPS_KEY: self.max_steps_per_epoch,
+ }
+ )
+
+ adapter_config = {
+ "r": self._lora_rank,
+ "lora_alpha": self._lora_alpha,
+ "target_modules": get_lora_module_names(
+ self._lora_attn_modules,
+ self._apply_lora_to_mlp,
+ self._apply_lora_to_output,
+ ),
+ "peft_type": "LORA",
+ }
+ checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config})
+ self._checkpointer.save_checkpoint(
+ checkpoint_dict,
+ epoch=epoch,
+ intermediate_checkpoint=intermediate_checkpoint,
+ adapter_only=self._save_adapter_weights_only,
+ )
+ log.info(f"Saving checkpoint took {time.perf_counter() - start:.2f} secs")
+
+ torch.distributed.barrier()
+
+ def train(self) -> None:
+ """
+ The core training loop.
+ """
+ # clean up before training begins
+ training.cleanup_before_training()
+
+ world_size, rank = training.get_world_size_and_rank()
+
+ # zero out the gradients before starting training
+ self._optimizer.zero_grad()
+
+ # Initialize tokens count and running loss (for grad accumulation)
+ t0 = time.perf_counter()
+ running_loss = 0
+ num_tokens = 0
+
+ self._profiler.start()
+ # self.epochs_run should be non-zero when we're resuming from a checkpoint
+ for curr_epoch in range(self.epochs_run, self.total_epochs):
+
+ # Update the sampler to ensure data is correctly shuffled across epochs
+ # in case shuffle is True
+ self._sampler.set_epoch(curr_epoch)
+
+ pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+ for idx, batch in enumerate(self._dataloader):
+ if (
+ self.max_steps_per_epoch is not None
+ and (idx // self._gradient_accumulation_steps)
+ == self.max_steps_per_epoch
+ ):
+ break
+
+ # Start tracking CUDA memory for active steps for just the first epoch
+ if (
+ self._is_rank_zero
+ and curr_epoch == 0
+ and self.profiler_profile_memory
+ and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+ ):
+ torch.cuda.memory._record_memory_history()
+
+ utils.batch_to_device(batch, self._device)
+
+ # Calculate the number of unmasked tokens in the current batch
+ # and increment the total number of tokens seen in the step
+ current_num_tokens = (
+ batch["labels"] != self._loss_fn.ignore_index
+ ).sum()
+ num_tokens += current_num_tokens
+
+ # Shape [b, s], needed for the loss not the model
+ labels = batch.pop("labels")
+
+ with self.activations_handling_ctx:
+ logits = self._model(**batch)
+
+ # Shift labels to compute loss
+ # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
+ # But this way we don't need to slice the logits. We just add an ignore index to labels.
+ labels = torch.hstack(
+ (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+ )
+ if not isinstance(logits, list):
+ labels = labels.reshape(-1)
+ logits = logits.reshape(-1, logits.size(-1))
+
+ # Compute loss
+ # Loss is normalized by default so we multiply by the number of tokens
+ # This way we can normalize by the total number of tokens if we're accumulating gradients
+ current_loss = self._loss_fn(logits, labels) * current_num_tokens
+
+ # free logits otherwise it peaks backward memory
+ del logits
+
+ running_loss += current_loss
+ current_loss.backward()
+
+ # Step with optimizer
+ if (idx + 1) % self._gradient_accumulation_steps == 0:
+ # Get total number of tokens across all ranks to normalize gradients
+ torch.distributed.all_reduce(num_tokens)
+ # This will ensure that the logged loss matches what we're optimizing
+ torch.distributed.all_reduce(running_loss)
+ # Manually scale the gradients from unnormalized loss by total # of tokens
+ training.scale_grads(self._model, 1 / num_tokens)
+ if self._clip_grad_norm is not None:
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ self._model.parameters(),
+ max_norm=float(self._clip_grad_norm),
+ )
+ self._optimizer.step()
+ self._optimizer.zero_grad(set_to_none=True)
+ self._lr_scheduler.step()
+
+ # Update the number of steps when the weights are updated
+ self.global_step += 1
+
+ loss_to_log = running_loss.item() / num_tokens
+ pbar.update(1)
+ pbar.set_description(
+ f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+ )
+
+ # Log per-step metrics
+ if (
+ self.global_step % self._log_every_n_steps == 0
+ and self._is_rank_zero
+ ):
+ time_per_step = time.perf_counter() - t0
+ log_dict = {
+ "loss": loss_to_log,
+ "lr": self._optimizer.param_groups[0]["lr"],
+ "tokens_per_second_per_gpu": num_tokens
+ / (time_per_step * world_size),
+ }
+ if self._log_peak_memory_stats:
+ log_dict.update(
+ training.get_memory_stats(device=self._device)
+ )
+
+ if self._clip_grad_norm is not None:
+ log_dict.update({"grad_norm": grad_norm})
+ self._metric_logger.log_dict(
+ log_dict,
+ step=self.global_step,
+ )
+
+ # Reset running stats for the next step
+ running_loss = 0
+ num_tokens = 0
+ t0 = time.perf_counter()
+
+ # Stop tracking CUDA memory now that active steps are complete
+ if (
+ self._is_rank_zero
+ and curr_epoch == 0
+ and self.profiler_profile_memory
+ and idx
+ == self.profiler_wait_steps
+ + self.profiler_warmup_steps
+ + self.profiler_active_steps
+ ):
+ torch.cuda.memory._record_memory_history(enabled=None)
+
+ # Step profiler
+ # Note that this is called within the gradient accumulation block, hence
+ # will include multiple forward / backward passes if gradient accumulation > 1
+ self._profiler.step()
+
+ self.epochs_run += 1
+ self.save_checkpoint(epoch=curr_epoch)
+
+ self._profiler.stop()
+
+ def cleanup(self) -> None:
+ if self._is_rank_zero:
+ self._metric_logger.close()
+ destroy_process_group()
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+ """
+ Entry point for the recipe.
+
+ Configurable parameters are read in the following order:
+ - Parameters specified in config (see available configs through ``tune ls``)
+ - Overwritten by arguments from the command-line
+ """
+ if not training.is_distributed():
+ raise RuntimeError(
+ "Distributed finetune recipe should be run via a distributed launcher."
+ "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
+ )
+ if cfg.get("fsdp_cpu_offload", False):
+ # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
+ # speed up when benchmarking fused AdamW on CPU
+ training.set_torch_num_threads()
+ init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
+
+ config.log_config(recipe_name="QATLoRAFinetuneRecipeDistributed", cfg=cfg)
+
+ recipe = QATLoRAFinetuneRecipeDistributed(cfg=cfg)
+ recipe.setup(cfg=cfg)
+ recipe.train()
+ recipe.cleanup()
+
+
+if __name__ == "__main__":
+ sys.exit(recipe_main())
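For readers who want to prepare a LoRA model for QAT outside of this recipe, here is a minimal standalone sketch of the step performed by ``_convert_model_to_qat`` above. It only uses builders and helpers that already appear in this PR; the specific LoRA hyperparameters are illustrative (taken from the 8B config) rather than the only valid choice.

```python
# Standalone sketch of the QAT preparation this recipe applies to a LoRA model.
# All names below are imported from modules already referenced in this PR.
from torchtune.models.llama3 import lora_llama3_8b
from torchtune.training.quantization import (
    Int8DynActInt4WeightQATQuantizer,
    swap_lora_linear_with_qat,
)

# Build a LoRA model; only the adapter params will be trained during finetuning.
model = lora_llama3_8b(
    lora_attn_modules=["q_proj", "v_proj", "output_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_rank=8,
    lora_alpha=16,
)

# Derive fake-quantization configs from the QAT quantizer and swap each
# LoRALinear for its QAT counterpart, so the frozen base weights (and
# activations) are fake-quantized during finetuning while the LoRA adapters
# stay in the original dtype.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
activation_config = quantizer.get_activation_fake_quantize_config()
weight_config = quantizer.get_weight_fake_quantize_config()
swap_lora_linear_with_qat(model, activation_config, weight_config)
```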
diff --git a/tests/recipes/test_qat_lora_finetune_distributed.py b/tests/recipes/test_qat_lora_finetune_distributed.py
new file mode 100644
index 0000000000..5be3a2379a
--- /dev/null
+++ b/tests/recipes/test_qat_lora_finetune_distributed.py
@@ -0,0 +1,266 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import runpy
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from omegaconf import OmegaConf
+from tests.common import TUNE_PATH
+from tests.recipes.utils import (
+ CKPT_COMPONENT_MAP,
+ dummy_alpaca_dataset_config,
+ MODEL_TEST_CONFIGS,
+ write_hf_ckpt_config,
+)
+from tests.test_utils import (
+ CKPT_MODEL_PATHS,
+ gen_log_file_name,
+ get_loss_values_from_metric_logger,
+ gpu_test,
+ TOKENIZER_PATHS,
+)
+from torchtune import config
+from torchtune.training.quantization import _torchao_0_7_supported
+
+
+class TestQATLoRAFinetuneDistributedRecipe:
+ def _get_test_config_overrides(self):
+ return [
+ "dataset.train_on_input=False",
+ "seed=9",
+ "epochs=2",
+ "dtype=fp32",
+ "max_steps_per_epoch=2",
+ "optimizer.lr=2e-5",
+ "log_every_n_steps=1",
+ "compile=False",
+ ] + dummy_alpaca_dataset_config()
+
+ def _fetch_expected_loss_values(self, model_type):
+ loss_values_map = {
+ "llama3": [11.9325, 11.9325, 11.9325, 11.9369],
+ }
+ return loss_values_map[model_type]
+
+ @pytest.mark.integration_test
+ @gpu_test(gpu_count=2)
+ @pytest.mark.parametrize(
+ "micro_batch_size, gradient_accumulation_steps, should_compile",
+ [(4, 1, True), (1, 4, False)],
+ )
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_loss(
+ self,
+ micro_batch_size,
+ gradient_accumulation_steps,
+ should_compile,
+ tmpdir,
+ monkeypatch,
+ ):
+ ckpt = "llama3_tune"
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ ckpt_dir = ckpt_path.parent
+ log_file = gen_log_file_name(tmpdir)
+ cmd = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed
+ --config llama3/8B_qat_lora \
+ batch_size={micro_batch_size} \
+ gradient_accumulation_steps={gradient_accumulation_steps} \
+ output_dir={tmpdir} \
+ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type=LLAMA3 \
+ metric_logger.filename={log_file} \
+ tokenizer.path=/tmp/test-artifacts/tokenizer.model \
+ tokenizer.prompt_template=null \
+ compile={should_compile} \
+ enable_activation_checkpointing=False \
+ enable_activation_offloading=False \
+ """.split()
+
+ model_config = MODEL_TEST_CONFIGS["llama3_lora"]
+
+ cmd = cmd + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+ loss_values = get_loss_values_from_metric_logger(log_file)
+ expected_loss_values = self._fetch_expected_loss_values("llama3")
+ torch.testing.assert_close(
+ loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
+ )
+
+ @pytest.mark.integration_test
+ @gpu_test(gpu_count=2)
+ @pytest.mark.parametrize(
+ "config, model_type, ckpt_type, save_adapter_weights_only",
+ [
+ ("llama3/8B_qat_lora", "llama3", "tune", False),
+ ],
+ )
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_training_state_on_resume(
+ self,
+ config,
+ model_type,
+ ckpt_type,
+ tmpdir,
+ monkeypatch,
+ save_adapter_weights_only,
+ ):
+ """Test whether the recipe state is correctly updated on resume. Since this
+ is model agnostic, we should run this on the small model only. The test
+ consists of three stages:
+ - Train a model for 2 epochs
+ - Resume training after epoch 1
+ - Make sure final loss matches the expected value of a model successfully resumed from a ckpt
+ """
+ ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+ ckpt = model_type + "_" + ckpt_type
+ expected_loss_values = self._fetch_expected_loss_values(model_type)
+
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+ ckpt_dir = ckpt_path.parent
+ log_file = gen_log_file_name(tmpdir)
+
+ # Config file needed for model conversion.
+ # Create a second copy for training resume
+ write_hf_ckpt_config(ckpt_dir)
+ write_hf_ckpt_config(tmpdir)
+
+ # Train for two epochs
+ cmd_1 = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
+ --config {config} \
+ batch_size=4 \
+ gradient_accumulation_steps=1 \
+ output_dir={tmpdir} \
+ checkpointer._component_={ckpt_component} \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type={model_type.upper()} \
+ tokenizer.path='{tokenizer_path}' \
+ tokenizer.prompt_template=null \
+ save_adapter_weights_only={save_adapter_weights_only} \
+ enable_activation_checkpointing=True \
+ enable_activation_offloading=True \
+ """.split()
+
+ model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
+
+ cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd_1)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ # Resume training
+ cmd_2 = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
+ --config {config} \
+ batch_size=4 \
+ gradient_accumulation_steps=1 \
+ output_dir={tmpdir} \
+ checkpointer._component_={ckpt_component} \
+ checkpointer.checkpoint_dir={tmpdir} \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} \
+ checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} \
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type={model_type.upper()} \
+ tokenizer.path='{tokenizer_path}' \
+ tokenizer.prompt_template=null \
+ resume_from_checkpoint=True \
+ metric_logger.filename={log_file} \
+ enable_activation_checkpointing=True \
+ enable_activation_offloading=True \
+ """.split()
+
+ cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd_2)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ expected_loss_values = self._fetch_expected_loss_values(model_type)[2:]
+
+ loss_values = get_loss_values_from_metric_logger(log_file)
+ torch.testing.assert_close(
+ loss_values, expected_loss_values, rtol=1e-5, atol=1e-5
+ )
+
+ @pytest.mark.integration_test
+ @pytest.mark.parametrize(
+ "recipe_config, model_type, ckpt_type",
+ [
+ ("llama3/8B_qat_lora", "llama3", "tune"),
+ ],
+ )
+ @gpu_test(gpu_count=2)
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_save_and_load_merged_weights(
+ self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch
+ ):
+ ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+ ckpt = model_type + "_" + ckpt_type
+ ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+ tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+ ckpt_dir = ckpt_path.parent
+ cmd = f"""
+ tune run --nnodes 1 --nproc_per_node 2 qat_lora_finetune_distributed \
+ --config {recipe_config} \
+ batch_size=4 \
+ gradient_accumulation_steps=1 \
+ output_dir={tmpdir} \
+ model=torchtune.models.lora_small_test_model \
+ checkpointer._component_={ckpt_component} \
+ checkpointer.checkpoint_dir='{ckpt_dir}' \
+ checkpointer.checkpoint_files=[{ckpt_path}]\
+ checkpointer.output_dir={tmpdir} \
+ checkpointer.model_type={model_type.upper()} \
+ tokenizer.path='{tokenizer_path}' \
+ tokenizer.prompt_template=null \
+ enable_activation_checkpointing=True \
+ enable_activation_offloading=True \
+ """.split()
+
+ model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
+
+ cmd = cmd + self._get_test_config_overrides() + model_config
+ monkeypatch.setattr(sys, "argv", cmd)
+ runpy.run_path(TUNE_PATH, run_name="__main__")
+
+ # Next load both the merged weights in a base model
+ # and the base model weights + trained adapter weights in the LoRA model
+ # The results of calling forward on dummy inputs should be the same.
+ inputs = torch.randint(low=0, high=32_000, size=(2, 100))
+
+ # Build LoRA model for loading base + adapter weights separately
+ lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model)
+
+ # Build base model for loading merged weights
+ base_config = MODEL_TEST_CONFIGS[model_type]
+ model = config.instantiate(OmegaConf.from_dotlist(base_config).model)
+
+ # Load base model and trained adapter weights into LoRA model and call fwd
+ with open(f"{tmpdir}/adapter_1.pt", "rb") as f:
+ lora_sd = torch.load(f, weights_only=True)
+ with open(ckpt_path, "rb") as f:
+ base_model_sd = torch.load(f, weights_only=True)
+ lora_model.load_state_dict(lora_sd, strict=False)
+ lora_model.load_state_dict(base_model_sd, strict=False)
+ baseline_out = lora_model(inputs)
+
+ # Load merged final ckpt directly into model and call fwd
+ with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f:
+ sd = torch.load(f, weights_only=True)
+ model.load_state_dict(sd)
+ merged_ckpt_out = model(inputs)
+
+ torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5)
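
For context on the merged-weights comparison in `test_save_and_load_merged_weights` above: loading the merged checkpoint into a base model should match loading base weights plus trained adapter weights into the LoRA model, because merging simply folds the adapter into the frozen weight. A minimal numeric sketch of that identity (toy shapes and the standard `alpha / rank` LoRA scaling; not torchtune internals):

```python
import torch

# Toy dimensions; any in_dim / out_dim / rank combination works
in_dim, out_dim, rank, alpha = 16, 8, 4, 1.0
scale = alpha / rank

x = torch.randn(2, in_dim)        # dummy inputs
w = torch.randn(out_dim, in_dim)  # frozen base weight
a = torch.randn(rank, in_dim)     # trained LoRA A
b = torch.randn(out_dim, rank)    # trained LoRA B

# Base layer + adapter applied separately (what the LoRA model computes)
separate = x @ w.T + scale * (x @ a.T) @ b.T

# Adapter folded into the base weight (what the merged checkpoint stores)
merged = x @ (w + scale * b @ a).T

torch.testing.assert_close(separate, merged, rtol=1e-5, atol=1e-5)
```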
diff --git a/tests/torchtune/modules/peft/test_lora.py b/tests/torchtune/modules/peft/test_lora.py
index 80d2b2d767..ff03b1d3c4 100644
--- a/tests/torchtune/modules/peft/test_lora.py
+++ b/tests/torchtune/modules/peft/test_lora.py
@@ -14,14 +14,17 @@
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4
from torchtune import training
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
-from torchtune.modules.peft import LoRALinear
+from torchtune.modules.peft import LoRALinear, QATLoRALinear
+from torchtune.training.quantization import _torchao_0_7_supported
from torchtune.training.seed import set_seed
+
RANK = 4
ALPHA = 1.0
BSZ = 2
SEQ_LEN = 32
EXPECTED_VAL = 1.1252
+QAT_EXPECTED_VAL = 0.6291
@pytest.fixture(autouse=True)
@@ -232,3 +235,12 @@ def test_quantized_state_dict(self, dtype):
assert torch.allclose(
lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data
)
+
+ @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
+ def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None:
+ lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
+ qat_lora_linear = QATLoRALinear.from_lora_linear(lora_linear)
+ expected = torch.tensor(QAT_EXPECTED_VAL)
+ actual = qat_lora_linear(inputs)
+ assert actual.shape == (BSZ, SEQ_LEN, out_dim)
+ torch.testing.assert_close(actual.mean(), expected, atol=1e-4, rtol=1e-6)
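
The QAT forward path exercised above fake-quantizes activations and weights before the matmul. As a toy illustration of what fake quantization means (a per-tensor symmetric scheme for clarity; torchao's `FakeQuantizer` additionally supports per-token and group-wise granularities), values snap to an integer grid while the tensor keeps its floating-point dtype:

```python
import torch

def fake_quantize_symmetric(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Quantize-then-dequantize in float: values snap to the int grid,
    but the tensor is never actually cast to a lower-precision dtype."""
    qmax = 2 ** (num_bits - 1) - 1                    # e.g. 127 for int8
    scale = x.abs().amax().clamp(min=1e-12) / qmax
    return (x / scale).round().clamp(-qmax - 1, qmax) * scale

x = torch.randn(4, 16)
x_fq = fake_quantize_symmetric(x)
assert x_fq.dtype == x.dtype      # still float, but quantization error is baked in
print((x - x_fq).abs().max())     # the error the model learns to tolerate during QAT
```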
diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py
index 5bbf860482..daf84be1b7 100644
--- a/torchtune/_recipe_registry.py
+++ b/torchtune/_recipe_registry.py
@@ -456,6 +456,17 @@ class Recipe:
],
supports_distributed=True,
),
+ Recipe(
+ name="qat_lora_finetune_distributed",
+ file_path="qat_lora_finetune_distributed.py",
+ configs=[
+ Config(name="llama3/8B_qat_lora", file_path="llama3/8B_qat_lora.yaml"),
+ Config(name="llama3_1/8B_qat_lora", file_path="llama3_1/8B_qat_lora.yaml"),
+ Config(name="llama3_2/1B_qat_lora", file_path="llama3_2/1B_qat_lora.yaml"),
+ Config(name="llama3_2/3B_qat_lora", file_path="llama3_2/3B_qat_lora.yaml"),
+ ],
+ supports_distributed=True,
+ ),
Recipe(
name="knowledge_distillation_single_device",
file_path="knowledge_distillation_single_device.py",
diff --git a/torchtune/modules/peft/__init__.py b/torchtune/modules/peft/__init__.py
index 165559df9c..2959bc3bb6 100644
--- a/torchtune/modules/peft/__init__.py
+++ b/torchtune/modules/peft/__init__.py
@@ -17,13 +17,14 @@
validate_missing_and_unexpected_for_lora,
)
from .dora import DoRALinear
-from .lora import LoRALinear
+from .lora import LoRALinear, QATLoRALinear
__all__ = [
+ "AdapterModule",
"DoRALinear",
"LoRALinear",
- "AdapterModule",
+ "QATLoRALinear",
"get_adapter_params",
"set_trainable_params",
"validate_missing_and_unexpected_for_lora",
diff --git a/torchtune/modules/peft/lora.py b/torchtune/modules/peft/lora.py
index e03d854f1f..f6303b798c 100644
--- a/torchtune/modules/peft/lora.py
+++ b/torchtune/modules/peft/lora.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
-from typing import List
+from typing import List, Optional
import torch
import torch.nn.functional as F
@@ -131,6 +131,165 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out + lora_out
+class QATLoRALinear(LoRALinear):
+ """
+ LoRA linear layer with quantization-aware training (QAT) applied to the
+ input activations and/or base weights before the low-rank adapters.
+
+ QAT leverages fake quantization to simulate the quantization numerics during
+ training without actually casting the data to a lower-precision dtype. This class
+ combines the two techniques: LoRA reduces the memory required during training,
+ while QAT improves the accuracy of the final quantized model at inference time.
+
+ Args:
+ in_dim (int): input dimension
+ out_dim (int): output dimension
+ rank (int): rank of the low-rank approximation
+ alpha (float): scaling factor for the low-rank approximation
+ dropout (float): dropout probability. Default: 0.0
+ activation_qat_config (Optional[FakeQuantizeConfig]): config for specifying
+ how input activations will be fake quantized, defaults to None
+ weight_qat_config (Optional[FakeQuantizeConfig]): config for specifying
+ how weights will be fake quantized, defaults to None
+
+ Raises:
+ ValueError: If `in_dim` is not divisible by weight `group_size`
+
+ Example usage::
+
+ activation_qat_config = FakeQuantizeConfig(
+ dtype=torch.int8,
+ granularity="per_token",
+ is_symmetric=False,
+ )
+ weight_qat_config = FakeQuantizeConfig(
+ dtype=torch.int4,
+ group_size=8,
+ is_symmetric=True,
+ )
+ qat_lora_linear = QATLoRALinear(
+ in_dim=512,
+ out_dim=1024,
+ rank=8,
+ alpha=16,
+ dropout=0.0,
+ activation_qat_config=activation_qat_config,
+ weight_qat_config=weight_qat_config,
+ )
+ qat_lora_linear(torch.randn(512))
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ rank: int,
+ alpha: float,
+ dropout: float = 0.0,
+ # fake quantize configs
+ # TODO: make the types Optional[FakeQuantizeConfig] once we
+ # support torchao 0.7+ by default
+ activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+ weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+ ):
+ super().__init__(
+ in_dim,
+ out_dim,
+ rank,
+ alpha,
+ dropout,
+ use_bias=False,
+ quantize_base=False,
+ )
+
+ try:
+ from torchao.quantization.qat.api import FakeQuantizeConfig
+ from torchao.quantization.qat.fake_quantizer import FakeQuantizer
+ except ImportError as err:
+ raise ValueError(
+ "QATLoRALinear is only compatible with torchao 0.7+"
+ ) from err
+
+ # initialize activation fake quantizer
+ if activation_qat_config is not None:
+ assert isinstance(activation_qat_config, FakeQuantizeConfig)
+ self.activation_fake_quantizer = FakeQuantizer(activation_qat_config)
+ else:
+ self.activation_fake_quantizer = nn.Identity()
+
+ # initialize weight fake quantizer
+ if weight_qat_config is not None:
+ assert isinstance(weight_qat_config, FakeQuantizeConfig)
+ group_size = weight_qat_config.group_size
+ if group_size is not None and in_dim % group_size != 0:
+ raise ValueError(
+ "in_dim (%s) must be divisible by group_size (%s)"
+ % (in_dim, group_size)
+ )
+ self.weight_fake_quantizer = FakeQuantizer(weight_qat_config)
+ else:
+ self.weight_fake_quantizer = nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape ``(..., in_dim)``
+
+ Returns:
+ torch.Tensor: output tensor with shape ``(..., out_dim)``
+
+ """
+ _x = self.activation_fake_quantizer(x)
+ w = self.weight_fake_quantizer(self.weight)
+ out = F.linear(_x, w)
+ if self.disabled:
+ return out
+ lora_out = self.lora_a(self.dropout(x))
+ lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)
+ return out + lora_out
+
+ @classmethod
+ def from_lora_linear(
+ cls,
+ lora_linear: LoRALinear,
+ # TODO: make the types Optional[FakeQuantizeConfig] once we
+ # support torchao 0.7+ by default
+ activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+ weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+ ) -> "QATLoRALinear":
+ """
+ Create a `QATLoRALinear` from an existing `LoRALinear`,
+ preserving the weights and adapters.
+ """
+ if lora_linear.bias is not None:
+ raise ValueError("Bias is not supported in QAT + LoRA yet")
+ if lora_linear._quantize_base:
+ raise ValueError("quantize_base is not compatible with QAT + LoRA")
+ if isinstance(lora_linear.dropout, nn.Dropout):
+ dropout = lora_linear.dropout.p
+ else:
+ dropout = 0.0
+ new_linear = cls(
+ lora_linear.in_dim,
+ lora_linear.out_dim,
+ lora_linear.rank,
+ lora_linear.alpha,
+ dropout,
+ activation_qat_config,
+ weight_qat_config,
+ )
+ # In distributed training, the model may be instantiated
+ # on the meta device, in which case there is no need to
+ # copy the weights, and doing so will result in an error
+ if lora_linear.weight.device != torch.device("meta"):
+ new_linear.weight = lora_linear.weight
+ if lora_linear.lora_a.weight.device != torch.device("meta"):
+ new_linear.lora_a.weight = lora_linear.lora_a.weight
+ if lora_linear.lora_b.weight.device != torch.device("meta"):
+ new_linear.lora_b.weight = lora_linear.lora_b.weight
+ return new_linear
+
+
def _lora_a_init_params(x: nn.Linear) -> None:
"""
Initialize LoRA A weight to Kaiming uniform.
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 7ff9315f41..4e21cb4936 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -7,6 +7,10 @@
from typing import Callable, Optional
from warnings import warn
+from torch import nn
+from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
+
+
try:
# torchao 0.7+
from torchao.dtypes import TensorCoreTiledLayout
@@ -55,6 +59,12 @@
]
+_torchao_0_7_supported = True
+try:
+ from torchao.quantization import qat # noqa: F401
+except ImportError:
+ _torchao_0_7_supported = False
+
_quantizer_to_mode = {}
_quantizer_mode_to_disable_fake_quant = {}
_quantizer_mode_to_enable_fake_quant = {}
@@ -185,3 +195,45 @@ def _get_enable_fake_quant(quantizer_mode: str) -> Callable:
If the quantizer is not recognized as a known QAT quantizer, return None.
"""
return _quantizer_mode_to_enable_fake_quant.get(quantizer_mode, None)
+
+
+def swap_lora_linear_with_qat(
+ module: nn.Module,
+ # TODO: make the types Optional[FakeQuantizeConfig] once we
+ # support torchao 0.7+ by default
+ activation_qat_config: Optional["FakeQuantizeConfig"] = None,
+ weight_qat_config: Optional["FakeQuantizeConfig"] = None,
+) -> None:
+ """
+ Swap all `LoRALinear` modules in the model with `QATLoRALinear`.
+
+ This is used to combine QAT + LoRA during finetuning. The resulting linear layers
+ will instead apply the following transformation:
+
+ x -> fake_quantize(W_frozen) @ fake_quantize(x) + BAx
+
+ Fake quantization here refers to simulating the quantization numerics without actual
+ dtype casting, with the goal of improving accuracy when the model is
+ ultimately quantized after finetuning.
+
+ Args:
+ module (nn.Module): The model to swap linear layers on
+ activation_qat_config (Optional[FakeQuantizeConfig]): The config for specifying
+ how to fake quantize input activations in the base linear layer
+ weight_qat_config (Optional[FakeQuantizeConfig]): The config for specifying
+ how to fake quantize base linear weights
+ """
+ for name, child in module.named_children():
+ if isinstance(child, LoRALinear):
+ new_linear = QATLoRALinear.from_lora_linear(
+ child,
+ activation_qat_config,
+ weight_qat_config,
+ )
+ setattr(module, name, new_linear)
+ else:
+ swap_lora_linear_with_qat(
+ child,
+ activation_qat_config,
+ weight_qat_config,
+ )
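
A possible usage sketch for `swap_lora_linear_with_qat` (the toy module, dimensions, and `FakeQuantizeConfig` values are illustrative only, mirroring the int8 per-token activation / int4 group-wise weight scheme from the `QATLoRALinear` docstring; requires torchao 0.7+):

```python
import torch
from torch import nn
from torchao.quantization.qat.api import FakeQuantizeConfig  # torchao 0.7+
from torchtune.modules.peft import LoRALinear, QATLoRALinear
from torchtune.training.quantization import swap_lora_linear_with_qat

# Toy module standing in for a full LoRA model
model = nn.Sequential(LoRALinear(in_dim=64, out_dim=64, rank=4, alpha=8))

# int8 per-token fake quantization for activations,
# int4 group-wise fake quantization for weights
activation_qat_config = FakeQuantizeConfig(
    dtype=torch.int8, granularity="per_token", is_symmetric=False
)
weight_qat_config = FakeQuantizeConfig(dtype=torch.int4, group_size=32, is_symmetric=True)

# Every LoRALinear in the module tree is replaced in place with QATLoRALinear
swap_lora_linear_with_qat(model, activation_qat_config, weight_qat_config)
assert isinstance(model[0], QATLoRALinear)
```

Because the swap is done in place and `from_lora_linear` skips copying weights that are still on the meta device, the same call should work both for fully materialized models and for models initialized on meta before sharding.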