**Summary:** This commit adds a recipe that combines QAT + LoRA, with the main goal of improving final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe `qat_lora_finetune_distributed` mirrors the existing `lora_finetune_distributed` recipe, which performs only LoRA, and is analogous to the existing `qat_distributed` recipe, which performs only QAT. A minimal sketch of how the two techniques compose is included after the results below.

Helpful code review commands:
```
diff --color recipes/lora_finetune_distributed.py recipes/qat_lora_finetune_distributed.py
diff --color recipes/configs/llama3/8B_lora.yaml recipes/configs/llama3/8B_qat_lora.yaml
diff --color recipes/configs/llama3_1/8B_lora.yaml recipes/configs/llama3_1/8B_qat_lora.yaml
```

For more context on QAT, please visit #980 and https://pytorch.org/blog/quantization-aware-training/.

**Test Plan**

Unit tests:
```
pytest -m integration_test tests/recipes/test_qat_lora_finetune_distributed.py
```

Manual tests:
```
export CUDA_VISIBLE_DEVICES=4,5,6,7
export NCCL_SHM_DISABLE=0
LOG_DIR=/home/andrewor/local/logs/tune/qat_lora

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora \
    batch_size=4 \
    quantizer.groupsize=32 \
    checkpointer.output_dir="$LOG_DIR" \
    metric_logger.output_dir="${LOG_DIR}/metrics"

tune run quantize --config quantization \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt"] \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

tune run eleuther_eval --config eleuther_evaluation \
    batch_size=1 \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt-8da4w"] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    tasks=[wikitext] \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32
```

Results:
```
|  Tasks |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6284|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5458|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.2694|±  |   N/A|

|  Tasks |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6245|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5416|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.1208|±  |   N/A|
```
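For reviewers who want the core idea without reading the full recipe: the sketch below is a rough, single-process illustration of how QAT and LoRA compose, not the recipe code. LoRA keeps only the low-rank adapter weights trainable (which is where the memory savings come from), while the QAT quantizer inserts fake quantization so training already sees the int8-activation / int4-weight numerics used after quantization. The specific builder arguments, the choice of `Int8DynActInt4WeightQATQuantizer`, and the assumption that `prepare()` composes out of the box with torchtune's LoRA modules are illustrative assumptions of this sketch; the actual recipe additionally handles FSDP, checkpointing, and the LoRA/QAT interaction.

```python
# Rough, non-distributed sketch (NOT the recipe code).
import torch
from torchtune.models.llama3 import lora_llama3_8b
from torchtune.training.quantization import Int8DynActInt4WeightQATQuantizer

# 1. LoRA: only the low-rank adapter weights are trainable,
#    which reduces fine-tuning memory.
model = lora_llama3_8b(
    lora_attn_modules=["q_proj", "v_proj"],
    lora_rank=8,
    lora_alpha=16,
)

# 2. QAT: swap linear layers for fake-quantized versions
#    (int8 dynamic activations + int4 grouped weights), so the loss
#    reflects quantization error during training. Whether prepare()
#    handles the LoRA wrapper modules out of the box is an assumption
#    of this sketch.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = quantizer.prepare(model)

# 3. ... run the usual LoRA fine-tuning loop on `model` here ...

# 4. After training, quantize for real. The manual test above does this
#    on the merged checkpoint via `tune run quantize` with
#    Int8DynActInt4WeightQuantizer rather than calling convert() in-process.
model = quantizer.convert(model)
```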