Skip to content

Commit

Permalink
Add support for QAT + LoRA
Browse files Browse the repository at this point in the history
**Summary:**

This commit adds a recipe that combines QAT + LoRA, with the main goal of improving final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe `qat_lora_finetune_distributed` mirrors the existing `lora_finetune_distributed` recipe, which performs only LoRA, and is analogous to the existing `qat_distributed` recipe, which performs only QAT.

Helpful code review commands:
```
diff --color recipes/lora_finetune_distributed.py recipes/qat_lora_finetune_distributed.py
diff --color recipes/configs/llama3/8B_lora.yaml recipes/configs/llama3/8B_qat_lora.yaml
diff --color recipes/configs/llama3_1/8B_lora.yaml recipes/configs/llama3_1/8B_qat_lora.yaml
```

For more context on QAT, please visit #980 and https://pytorch.org/blog/quantization-aware-training/.

**Test Plan**

Unit tests:
```
pytest -m integration_test tests/recipes/test_qat_lora_finetune_distributed.py
```

Manual tests:
```
export CUDA_VISIBLE_DEVICES=4,5,6,7
export NCCL_SHM_DISABLE=0
LOG_DIR=/home/andrewor/local/logs/tune/qat_lora

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora \
    batch_size=4 \
    quantizer.groupsize=32 \
    checkpointer.output_dir="$LOG_DIR" \
    metric_logger.output_dir="${LOG_DIR}/metrics"

tune run quantize --config quantization \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt"] \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

tune run eleuther_eval --config eleuther_evaluation \
    batch_size=1 \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt-8da4w"] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    tasks=[wikitext] \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32
```

Results:
```

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6284|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5458|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.2694|±  |   N/A|

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6245|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5416|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.1208|±  |   N/A|
```
  • Loading branch information
andrewor14 committed Nov 19, 2024
1 parent abdb5a4 commit 724fb11
Show file tree
Hide file tree
Showing 9 changed files with 1,690 additions and 4 deletions.
113 changes: 113 additions & 0 deletions recipes/configs/llama3/8B_qat_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
# using a Llama3 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token <HF_TOKEN>
#
# To launch on 2 devices, run the following command from root:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
max_seq_len: null

# Model Arguments
model:
_component_: torchtune.models.llama3.lora_llama3_8b
lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 8 # higher increases accuracy and memory
lora_alpha: 16 # usually alpha=2*rank
lora_dropout: 0.0

checkpointer:
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
checkpoint_files: [
consolidated.00.pth
]
recipe_checkpoint: null
output_dir: /tmp/Meta-Llama-3-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
packed: False # True increases speed
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
fused: True
weight_decay: 0.01
lr: 3e-4
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory

# Logging
output_dir: /tmp/qat_lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: False
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 5
warmup_steps: 3
active_steps: 2
num_cycles: 1

# QAT arguments
quantizer:
_component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
groupsize: 256
116 changes: 116 additions & 0 deletions recipes/configs/llama3_1/8B_qat_lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
# using a Llama3.1 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 2 devices, run the following command from root:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
max_seq_len: null

# Model Arguments
model:
_component_: torchtune.models.llama3_1.lora_llama3_1_8b
lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 8 # higher increases accuracy and memory
lora_alpha: 16 # usually alpha=2*rank
lora_dropout: 0.0

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files: [
model-00001-of-00004.safetensors,
model-00002-of-00004.safetensors,
model-00003-of-00004.safetensors,
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
packed: False # True increases speed
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
fused: True
weight_decay: 0.01
lr: 3e-4
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory

# Logging
output_dir: /tmp/qat_lora_finetune_output
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: False
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 5
warmup_steps: 3
active_steps: 2
num_cycles: 1

# QAT arguments
quantizer:
_component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
groupsize: 256
Loading

0 comments on commit 724fb11

Please sign in to comment.