Add support for QAT + LoRA #1931

Merged: 1 commit on Nov 26, 2024
README.md: 3 changes (2 additions, 1 deletion)
@@ -67,7 +67,8 @@ torchtune provides the following finetuning recipes for training on one or more
| LoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py) <br> [lora_finetune_distributed](recipes/lora_finetune_distributed.py) | [Qwen2 0.5B single-device](recipes/configs/qwen2/0.5B_lora_single_device.yaml) <br> [Gemma 7B distributed](recipes/configs/gemma/7B_lora.yaml)
| QLoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py) <br> [lora_finetune_distributed](recipes/lora_finetune_distributed.py)| [Phi3 Mini single-device](recipes/configs/phi3/mini_qlora_single_device.yaml) <br> [Llama 3.1 405B distributed](recipes/configs/llama3_1/405B_qlora.yaml)
| DoRA/QDoRA Finetuning | 1-8 | [lora_finetune_single_device](recipes/lora_finetune_single_device.py) <br> [lora_finetune_distributed](recipes/lora_finetune_distributed.py)| [Llama3 8B QDoRA single-device](recipes/configs/llama3/8B_qdora_single_device.yaml) <br> [Llama3 8B DoRA distributed](recipes/configs/llama3/8B_dora.yaml)
-| Quantization-Aware Training | 4-8 | [qat_distributed](recipes/qat_distributed.py)| [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml)
+| Quantization-Aware Training | 2-8 | [qat_distributed](recipes/qat_distributed.py)| [Llama3 8B QAT](recipes/configs/llama3/8B_qat_full.yaml)
+| Quantization-Aware Training and LoRA Finetuning | 2-8 | [qat_lora_finetune_distributed](recipes/qat_lora_finetune_distributed.py)| [Llama3 8B QAT+LoRA](recipes/configs/llama3/8B_qat_lora.yaml)
| Direct Preference Optimization |1-8 | [lora_dpo_single_device](recipes/lora_dpo_single_device.py) <br> [lora_dpo_distributed](recipes/lora_dpo_distributed.py) | [Llama2 7B single-device](recipes/configs/llama2/7B_lora_dpo_single_device.yaml) <br> [Llama2 7B distributed](recipes/configs/llama2/7B_lora_dpo.yaml)
| Proximal Policy Optimization | 1 | [ppo_full_finetune_single_device](recipes/ppo_full_finetune_single_device.py) | [Mistral 7B](recipes/configs/mistral/7B_full_ppo_low_memory.yaml)
| Knowledge Distillation | 1 | [knowledge_distillation_single_device](recipes/knowledge_distillation_single_device.py) | [Qwen2 1.5B -> 0.5B](recipes/configs/qwen2/knowledge_distillation_single_device.yaml)
recipes/configs/llama3/8B_qat_lora.yaml: 113 additions (new file)
@@ -0,0 +1,113 @@
# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
# using a Llama3 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token <HF_TOKEN>
#
# To launch on 2 devices, run the following command from root:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model
  max_seq_len: null

# Model Arguments
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8 # higher increases accuracy and memory
  lora_alpha: 16 # usually alpha=2*rank
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False # True increases speed
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8 # Use to increase virtual batch size
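# effective batch size = batch_size * number of devices * gradient_accumulation_steps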
compile: False # pytorch compile, set to true for better perf/memory

# Logging
output_dir: /tmp/qat_lora_finetune_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256
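As a rough sketch of the lifecycle the recipe drives with this quantizer: the import path below is the `_component_` string from the config, `prepare`/`convert` follow the torchao QAT API that torchtune re-exports, and the tiny two-layer stand-in model is hypothetical.

```python
import torch
import torch.nn as nn
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

# Hypothetical stand-in for the LoRA-wrapped Llama3 model the recipe builds.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)

# prepare() swaps each nn.Linear for a fake-quantized version: int8 dynamic
# quantization is simulated on activations and per-group int4 on weights
# during the forward pass, while gradients still flow in full precision.
model = quantizer.prepare(model)

# ... training happens here (in the recipe, only LoRA params would be trainable) ...
loss = model(torch.randn(2, 512)).sum()
loss.backward()

# convert() replaces the fake-quantized layers with actually-quantized ones
# for inference.
model = quantizer.convert(model)
```

The idea of combining this with LoRA is that the frozen base weights see quantization noise in every forward pass, so the adapters can learn to compensate for it before the model is ever actually quantized.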
recipes/configs/llama3_1/8B_qat_lora.yaml: 116 additions (new file)
@@ -0,0 +1,116 @@
# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
# using a Llama3.1 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 2 devices, run the following command from root:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_1/8B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: null

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 8 # higher increases accuracy and memory
  lora_alpha: 16 # usually alpha=2*rank
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False # True increases speed
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory

# Logging
output_dir: /tmp/qat_lora_finetune_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256
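The `lora_rank` / `lora_alpha` pair in these configs controls the adapter's capacity and its scaling. The following is a minimal sketch of the math those two knobs control, not torchtune's actual `LoRALinear`:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Sketch of a LoRA-adapted linear: y = W x + (alpha / rank) * B A x."""
    def __init__(self, in_dim: int, out_dim: int, rank: int = 8,
                 alpha: float = 16.0, dropout: float = 0.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.base.weight.requires_grad_(False)   # base weight stays frozen
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.zeros_(self.lora_b.weight)       # adapter starts as a no-op
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank              # alpha=2*rank gives a scaling of 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))

layer = LoRALinear(4096, 4096, rank=8, alpha=16.0)
y = layer(torch.randn(2, 4096))   # equals base(x) at init, since lora_b is zero
```

Raising `lora_rank` grows the A and B matrices (more trainable parameters and memory), while `lora_alpha` rescales their contribution, which is why these configs keep alpha at twice the rank.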
recipes/configs/llama3_2/1B_qat_lora.yaml: 112 additions (new file)
@@ -0,0 +1,112 @@
# Config for multi-device QAT + LoRA finetuning in qat_lora_finetune_distributed.py
# using a Llama3.2 1B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 2 devices, run the following command from root:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3_2/1B_qat_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  max_seq_len: null

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 64 # higher increases accuracy and memory
  lora_alpha: 128 # usually alpha=2*rank
  lora_dropout: 0.0

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-3.2-1B-Instruct/
  model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False # True increases speed
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 8 # Use to increase virtual batch size
compile: False # pytorch compile, set to true for better perf/memory

# Logging
output_dir: /tmp/qat_lora_finetune_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256
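All three configs share `groupsize: 256`: each group of 256 consecutive weights along the input dimension shares one int4 scale. Below is a rough, self-contained illustration of symmetric per-group int4 fake quantization; the helper name is made up and this is not torchao's exact kernel.

```python
import torch

def fake_quant_int4_per_group(w: torch.Tensor, groupsize: int = 256) -> torch.Tensor:
    """Illustrative symmetric per-group int4 fake quantization of a weight matrix."""
    out_dim, in_dim = w.shape
    assert in_dim % groupsize == 0
    g = w.reshape(out_dim, in_dim // groupsize, groupsize)
    # One scale per group of `groupsize` weights; int4 range is [-8, 7].
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(g / scale), -8, 7)
    # Dequantize: same shape and dtype as w, with quantization error baked in.
    return (q * scale).reshape(out_dim, in_dim)

w = torch.randn(64, 512)
print((w - fake_quant_int4_per_group(w)).abs().max())
```

Smaller groups mean more scales and better fidelity at a higher memory cost; 256 is the default these QAT configs ship with.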