Update QAT: add grad clipping, torch.compile, collate fn #1854
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1854
✅ No failures as of commit 258cd8b with merge base 506e099.
Codecov Report (coverage diff and impacted files):
@@            Coverage Diff             @@
##             main    #1854      +/-   ##
==========================================
+ Coverage   67.30%   69.19%    +1.89%
==========================================
  Files         304      305        +1
  Lines       16000    16031       +31
==========================================
+ Hits        10768    11092      +324
+ Misses       5232     4939      -293
View full report in Codecov by Sentry.
Can you attach some output from your run? Then this looks good to me.
**Summary:** Update the qat_distributed recipe to match the full_finetune_distributed recipe. This commit adds features to QAT like gradient clipping, torch.compile, and a user-configurable collate function for data pre-processing. It mirrors all changes in full_finetune_distributed as of 506e099.

Helpful commands for quick review:

```
diff --color recipes/full_finetune_distributed.py recipes/qat_distributed.py
diff --color recipes/configs/llama2/7B_full.yaml recipes/configs/llama2/7B_qat_full.yaml
diff --color recipes/configs/llama3/8B_full.yaml recipes/configs/llama3/8B_qat_full.yaml
```

**Test Plan:** Fine-tune on the alpaca dataset for 1 epoch with and without QAT:

```
CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 tune run --nnodes 1 --nproc_per_node 6 qat_distributed --config llama3/8B_qat_full \
  epochs=1 \
  checkpointer.output_dir="$LOG_DIR" \
  metric_logger.output_dir="${LOG_DIR}/metrics" \
  quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer

CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/quantization.yaml \
  model._component_=torchtune.models.llama3.llama3_8b \
  checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
  checkpointer.checkpoint_dir="$LOG_DIR" \
  checkpointer.output_dir="$LOG_DIR" \
  checkpointer.checkpoint_files=[meta_model_0.pt] \
  checkpointer.model_type=LLAMA3 \
  quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
  tasks=[wikitext] \
  model._component_=torchtune.models.llama3.llama3_8b \
  checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
  checkpointer.checkpoint_dir="$LOG_DIR" \
  checkpointer.output_dir="$LOG_DIR" \
  checkpointer.checkpoint_files=[meta_model_0-8da4w.pt] \
  checkpointer.model_type=LLAMA3 \
  tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
  tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
  quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer
```

With QAT:

```
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.9821|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.9754|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |38.1039|±  |   N/A|
```

Without QAT:

```
| Tasks  |Version|Filter|n-shot|    Metric     |   |  Value  |   |Stderr|
|--------|------:|------|------|---------------|---|--------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  |   2.2017|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  |   4.6003|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |3501.1122|±  |   N/A|
```
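To make the review a bit more self-contained, here is a minimal, hedged sketch of where the three new features sit in a generic PyTorch training step: gradient clipping, torch.compile, and a user-supplied collate function. This is not the recipe code; the toy model, in-memory dataset, pad_collate helper, and pad_id default are hypothetical stand-ins, and the actual recipe wires all three through its YAML config.

```python
# Illustrative only: a stripped-down training step showing the three features
# added to the QAT recipe (grad clipping, torch.compile, custom collate_fn).
# The model, dataset, and pad_id below are placeholders, not torchtune internals.
import torch
from torch.utils.data import DataLoader

def pad_collate(batch, pad_id=0):
    # User-configurable collate: right-pad variable-length token sequences.
    max_len = max(len(x["tokens"]) for x in batch)
    tokens = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    for i, x in enumerate(batch):
        seq = torch.tensor(x["tokens"], dtype=torch.long)
        tokens[i, : len(seq)] = seq
    return {"tokens": tokens}

model = torch.nn.Sequential(torch.nn.Embedding(128, 64), torch.nn.Linear(64, 128))
model = torch.compile(model)  # optional compilation of the model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

dataset = [{"tokens": [1, 2, 3]}, {"tokens": [4, 5]}]  # stand-in for an SFT dataset
loader = DataLoader(dataset, batch_size=2, collate_fn=pad_collate)

clip_grad_norm = 1.0  # optional max-norm gradient clipping
for batch in loader:
    logits = model(batch["tokens"])   # (bsz, seq_len, vocab)
    loss = logits.float().mean()      # dummy loss for illustration
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
```

The point of the sketch is only the placement: compilation wraps the model once, the collate function plugs into the data loader, and clipping runs between backward() and the optimizer step.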
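For reviewers unfamiliar with what the QAT quantizer in the test plan does, below is a hedged sketch of the prepare / fine-tune / convert flow that the qat_distributed and quantize recipes automate. It assumes torchao's Int8DynActInt4WeightQATQuantizer exposes prepare() and convert() as described in torchao's QAT documentation; the import path is an assumption (it has moved between torchao releases), and the toy model and commented-out training loop are hypothetical placeholders rather than torchtune code.

```python
# Hedged sketch of the QAT flow the recipes drive (not torchtune source).
# Assumption: torchao's QAT quantizer API (prepare/convert); the import path
# differs across torchao versions, so treat this as illustrative.
import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Placeholder model; in the recipe this would be llama3_8b. Dimensions are
# chosen as multiples of a typical quantization group size.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 256))

# 1. Prepare: swap linear layers for fake-quantized versions so fine-tuning
#    sees int8 dynamic-activation / int4 weight quantization error.
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
model = qat_quantizer.prepare(model)

# 2. Fine-tune as usual; gradient clipping, torch.compile, and the custom
#    collate_fn from this PR all apply here exactly as in
#    full_finetune_distributed.
# train_loop(model)  # hypothetical

# 3. Convert: swap the fake-quantized modules for actually quantized ones
#    for inference and evaluation.
model = qat_quantizer.convert(model)
```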