
Update QAT: add grad clipping, torch.compile, collate fn #1854

Merged: 1 commit merged into pytorch:main on Nov 8, 2024

Conversation

@andrewor14 (Contributor) commented on Oct 16, 2024

Summary:

Update the qat_distributed recipe to match the full_finetune_distributed recipe. This commit adds features to QAT such as gradient clipping, torch.compile, and a user-configurable collate function for data pre-processing. It mirrors all changes in full_finetune_distributed as of 506e099.
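
The three additions follow the same pattern as in full_finetune_distributed. As a rough illustration only (not the recipe's actual code; the config keys `compile`, `clip_grad_norm`, and `collate_fn` and the training-loop shape are simplified assumptions), the features look roughly like this in a generic training step:

```python
# Rough sketch of the three features added to the QAT recipe
# (illustrative only, not the recipe's actual implementation).
import torch
from torch.utils.data import DataLoader

def train(model, dataset, cfg):
    # 1. User-configurable collate function: the collate callable is taken
    #    from the config instead of being hard-coded in the recipe.
    loader = DataLoader(
        dataset,
        batch_size=cfg["batch_size"],
        collate_fn=cfg.get("collate_fn"),  # hypothetical config key
    )

    # 2. torch.compile: optionally compile the model for faster steps.
    if cfg.get("compile", False):  # hypothetical config key
        model = torch.compile(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg["lr"])
    loss_fn = torch.nn.CrossEntropyLoss()

    for batch in loader:
        logits = model(batch["tokens"])
        loss = loss_fn(logits.transpose(1, 2), batch["labels"])
        loss.backward()

        # 3. Gradient clipping: clip the global grad norm before the optimizer step.
        if cfg.get("clip_grad_norm") is not None:  # hypothetical config key
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=float(cfg["clip_grad_norm"])
            )

        optimizer.step()
        optimizer.zero_grad()
```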

Helpful commands for quick review:

```
diff --color recipes/full_finetune_distributed.py recipes/qat_distributed.py
diff --color recipes/configs/llama2/7B_full.yaml recipes/configs/llama2/7B_qat_full.yaml
diff --color recipes/configs/llama3/8B_full.yaml recipes/configs/llama3/8B_qat_full.yaml
```

Test Plan:

Fine-tune on the Alpaca dataset for 1 epoch with and without QAT:

```
CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 tune run --nnodes 1 --nproc_per_node 6 qat_distributed --config llama3/8B_qat_full \
    epochs=1 \
    checkpointer.output_dir="$LOG_DIR" \
    metric_logger.output_dir="${LOG_DIR}/metrics" \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer

CUDA_VISIBLE_DEVICES=1 tune run quantize --config recipes/configs/quantization.yaml \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=[meta_model_0.pt] \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer

CUDA_VISIBLE_DEVICES=1 tune run eleuther_eval --config eleuther_evaluation \
    tasks=[wikitext] \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=[meta_model_0-8da4w.pt] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer
```
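
For context, the three commands above exercise a prepare → train → convert → eval flow. Below is a minimal sketch of that flow, assuming torchao's prototype QAT API at the time (the exact import path and constructor defaults may differ across versions):

```python
# Sketch of the QAT flow the commands above exercise
# (assumes torchao's prototype QAT API; details may vary by version).
import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

def qat_flow(model: torch.nn.Module, train_one_epoch):
    quantizer = Int8DynActInt4WeightQATQuantizer()

    # Step 1 (qat_distributed): insert fake-quantize ops so training sees
    # int8 dynamic-activation / int4 weight quantization noise.
    model = quantizer.prepare(model)
    train_one_epoch(model)

    # Step 2 (tune run quantize): convert fake-quantized modules into
    # actually quantized ones (8da4w) for the saved checkpoint.
    model = quantizer.convert(model)

    # Step 3 (eleuther_eval): evaluate the quantized checkpoint on wikitext.
    return model
```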

With QAT:

```
| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.9821|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.9754|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |38.1039|±  |   N/A|
```

Without QAT:

```
| Tasks  |Version|Filter|n-shot|    Metric     |   |  Value  |   |Stderr|
|--------|------:|------|------|---------------|---|--------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  |   2.2017|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  |   4.6003|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |3501.1122|±  |   N/A|
```
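
As a quick sanity check on the two tables (assuming lm-eval-harness's usual definition, where byte_perplexity equals 2 raised to bits_per_byte), the reported numbers are self-consistent:

```python
# Consistency check of the reported wikitext metrics, assuming
# lm-eval-harness's standard relation byte_perplexity == 2 ** bits_per_byte.
qat = {"bits_per_byte": 0.9821, "byte_perplexity": 1.9754}
no_qat = {"bits_per_byte": 2.2017, "byte_perplexity": 4.6003}

for name, m in [("with QAT", qat), ("without QAT", no_qat)]:
    implied = 2 ** m["bits_per_byte"]
    print(f"{name}: 2**bits_per_byte = {implied:.4f} vs reported {m['byte_perplexity']:.4f}")
```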

@pytorch-bot commented on Oct 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1854

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 258cd8b with merge base 506e099:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 16, 2024
@codecov-commenter commented on Oct 16, 2024

Codecov Report

Attention: Patch coverage is 0% with 23 lines in your changes missing coverage. Please review.

Project coverage is 69.19%. Comparing base (c70ad29) to head (5aef800).
Report is 11 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| recipes/qat_distributed.py | 0.00% | 23 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #1854      +/-   ##
==========================================
+ Coverage   67.30%   69.19%   +1.89%     
==========================================
  Files         304      305       +1     
  Lines       16000    16031      +31     
==========================================
+ Hits        10768    11092     +324     
+ Misses       5232     4939     -293     
```

☔ View full report in Codecov by Sentry.

@joecummings (Contributor) left a comment:

Can you attach some output from your run? Then this looks good to me.

@joecummings merged commit 96dea61 into pytorch:main on Nov 8, 2024. All 17 checks passed.
@ebsmothers mentioned this pull request on Nov 26, 2024.