Add support for QAT + LoRA #1931

Merged: 1 commit into pytorch:main on Nov 26, 2024

Conversation

@andrewor14 (Contributor) commented on Oct 31, 2024

Summary:

This commit adds a recipe that combines QAT + LoRA, with the main goal of improving final quantized accuracy after training while reducing the memory required for fine-tuning. The new recipe qat_lora_finetune_distributed mirrors the existing lora_finetune_distributed recipe, which performs only LoRA, and is analogous to the existing qat_distributed recipe, which performs only QAT.

Helpful code review commands:

```
diff --color recipes/lora_finetune_distributed.py recipes/qat_lora_finetune_distributed.py
diff --color recipes/configs/llama3/8B_lora.yaml recipes/configs/llama3/8B_qat_lora.yaml
diff --color recipes/configs/llama3_1/8B_lora.yaml recipes/configs/llama3_1/8B_qat_lora.yaml
diff --color recipes/configs/llama3_2/1B_lora.yaml recipes/configs/llama3_2/1B_qat_lora.yaml
diff --color recipes/configs/llama3_2/3B_lora.yaml recipes/configs/llama3_2/3B_qat_lora.yaml
```

For more context on QAT, please visit #980 and https://pytorch.org/blog/quantization-aware-training/.
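
For intuition, here is a minimal, self-contained sketch of how QAT composes with LoRA at the linear-layer level: the frozen base weight is fake-quantized on the fly so the forward pass (and hence the LoRA gradients) sees quantization error, while the LoRA adapters train in full precision. This is an illustrative sketch only, not the torchtune implementation; the class name, group size, and fake-quantization scheme are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FakeQuantLoRALinearSketch(nn.Module):
    """Illustrative sketch: the frozen base weight is fake-quantized (QAT)
    while the low-rank LoRA adapters are trained in full precision."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim), requires_grad=False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # standard LoRA init: adapters start as a no-op
        self.scaling = alpha / rank

    @staticmethod
    def _fake_quantize(w: torch.Tensor, bits: int = 4, group_size: int = 32) -> torch.Tensor:
        # Symmetric per-group fake quantization: quantize then dequantize, so
        # downstream compute (and the LoRA gradients) sees quantization error.
        out_dim, in_dim = w.shape
        wg = w.reshape(out_dim, in_dim // group_size, group_size)
        qmax = 2 ** (bits - 1) - 1
        scale = wg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / qmax
        wq = torch.clamp(torch.round(wg / scale), -qmax - 1, qmax)
        return (wq * scale).reshape(out_dim, in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base = F.linear(x, self._fake_quantize(self.weight))
        return base + self.scaling * self.lora_b(self.lora_a(x))


# Example: only the adapters receive gradients; the base weight stays frozen.
layer = FakeQuantLoRALinearSketch(in_dim=64, out_dim=64)
out = layer(torch.randn(2, 64))
out.sum().backward()
assert layer.weight.grad is None and layer.lora_b.weight.grad is not None
```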

Test Plan

Unit tests:

```
pytest -m integration_test tests/recipes/test_qat_lora_finetune_distributed.py
```

Manual tests:

```
export CUDA_VISIBLE_DEVICES=4,5,6,7
export NCCL_SHM_DISABLE=0
LOG_DIR=/home/andrewor/local/logs/tune/qat_lora

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora \
    batch_size=4 \
    quantizer.groupsize=32 \
    checkpointer.output_dir="$LOG_DIR" \
    metric_logger.output_dir="${LOG_DIR}/metrics"

tune run quantize --config quantization \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelMetaCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt"] \
    checkpointer.model_type=LLAMA3 \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32

tune run eleuther_eval --config eleuther_evaluation \
    batch_size=1 \
    model._component_=torchtune.models.llama3.llama3_8b \
    checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \
    checkpointer.checkpoint_dir="$LOG_DIR" \
    checkpointer.output_dir="$LOG_DIR" \
    checkpointer.checkpoint_files=["meta_model_0.pt-8da4w"] \
    checkpointer.model_type=LLAMA3 \
    tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \
    tokenizer.path=/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    tasks=[wikitext] \
    quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
    quantizer.groupsize=32
```

Results:

**Baseline (LoRA only, no QAT)**

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6284|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5458|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.2694|±  |   N/A|

**LoRA + QAT (new recipe)**

| Tasks  |Version|Filter|n-shot|    Metric     |   | Value |   |Stderr|
|--------|------:|------|------|---------------|---|------:|---|------|
|wikitext|      2|none  |None  |bits_per_byte  |↓  | 0.6245|±  |   N/A|
|        |       |none  |None  |byte_perplexity|↓  | 1.5416|±  |   N/A|
|        |       |none  |None  |word_perplexity|↓  |10.1208|±  |   N/A|

pytorch-bot commented on Oct 31, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1931

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 95961d4 with merge base abdb5a4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@andrewor14 andrewor14 marked this pull request as draft October 31, 2024 00:10
@facebook-github-bot added the CLA Signed label on Oct 31, 2024
@andrewor14 andrewor14 force-pushed the try-qat-lora branch 2 times, most recently from e20e891 to d09c71f Compare November 1, 2024 19:38
@gau-nernst (Contributor) commented:

Hey @andrewor14, I was hacking around with LoRA/QLoRA + INT8 mixed-precision and came across this PR of yours. I realized what we are trying to achieve is quite similar.

Components of LoRALinear:

  • base weight: F.linear(x, self.weight) -> we want to modify this op
  • LoRA adapters: self.lora_b(self.lora_a(x))

Since the base weight is a direct child of LoRALinear, and F.linear() is hard-coded, it's hard to extend the functionality of LoRALinear without rewriting the whole thing. So I had the idea of making the base weight its own nn.Linear() module, so that we can freely swap the base linear module to modify its op.

  • For my specific use case, I can quantize the base weight to NF4 AND swap the base linear module to INT8 mixed-precision, composing the NF4 weight with INT8 matmul for compute.

I have a POC here: main...gau-nernst:qlora (you can focus on the torchtune/modules/peft/lora.py file). With this, we can reuse the linear module swap in torchao, and you wouldn't need a separate qat_lora_finetune_distributed.py, since we can add a quantizer to the existing recipes (though I understand you might not want this; I didn't carefully check how the QAT recipe script differs from the other training scripts).
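
A rough sketch of the proposed restructuring, for readers following along (the class below is hypothetical and not the actual POC code; see the linked branch for the real changes):

```python
import torch.nn as nn


class LoRALinearWithBase(nn.Module):
    """Hypothetical variant of LoRALinear where the base weight lives in its
    own nn.Linear, so a module-swap quantization scheme can replace it."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)  # swappable base op
        self.base.weight.requires_grad_(False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        nn.init.zeros_(self.lora_b.weight)
        self.scaling = alpha / rank

    def forward(self, x):
        # Whatever module currently sits at self.base defines the base matmul
        # (plain fp, NF4, INT8 mixed-precision, fake-quantized for QAT, ...).
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
```

The base op could then be changed by assigning a different module to `self.base`, without rewriting the LoRA class itself.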

@andrewor14 (Author) commented:

Hi @gau-nernst, yeah, I agree we could make the base weight more flexible so we don't need to create a new class every time we want to extend LoRA functionality. cc @ebsmothers for your thoughts on extending LoRALinear this way: main...gau-nernst:qlora. For QAT in particular, though, the current flow uses a full module swap (no tensor subclass yet), so we'd need some other way to initialize the base module, like manually setting self.base, so it may not be as elegant there. Also, I think the existing QATLoRALinear doesn't add much boilerplate code, so it might be OK.

For the separate recipe, I discussed this with @ebsmothers recently and I think it's torchtune's recipe organization philosophy to keep them separate, so QAT functionality won't complicate the original lora recipe.

@andrewor14 andrewor14 force-pushed the try-qat-lora branch 3 times, most recently from 4c1b14f to 9faac01 Compare November 8, 2024 16:54
@andrewor14 andrewor14 marked this pull request as ready for review November 8, 2024 16:54
@andrewor14 andrewor14 force-pushed the try-qat-lora branch 4 times, most recently from b34da3b to 7a600cc Compare November 8, 2024 22:09
@codecov-commenter commented on Nov 8, 2024

Codecov Report

Attention: Patch coverage is 10.54852% with 424 lines in your changes missing coverage. Please review.

Project coverage is 24.40%. Comparing base (1814feb) to head (2404803).
Report is 7 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|--------------------------|--------:|-------|
| recipes/qat_lora_finetune_distributed.py | 0.00% | 311 Missing ⚠️ |
| ...ests/recipes/test_qat_lora_finetune_distributed.py | 32.60% | 62 Missing ⚠️ |
| torchtune/modules/peft/lora.py | 13.04% | 40 Missing ⚠️ |
| tests/torchtune/modules/peft/test_lora.py | 45.45% | 6 Missing ⚠️ |
| torchtune/training/quantization.py | 61.53% | 5 Missing ⚠️ |
Additional details and impacted files:
```
@@             Coverage Diff             @@
##             main    #1931       +/-   ##
===========================================
- Coverage   67.29%   24.40%   -42.89%     
===========================================
  Files         318      325        +7     
  Lines       17646    18498      +852     
===========================================
- Hits        11874     4515     -7359     
- Misses       5772    13983     +8211     
```


@andrewor14 (Author) commented:

@ebsmothers Any comments? Does this look good to you?

@ebsmothers (Contributor) commented:

Hey @andrewor14, sorry for the delay and thanks for your patience here. We are doing planning this week, so my available bandwidth for reviewing this has taken a hit. I promise to get to it by Friday at the latest.

@ebsmothers (Contributor) left a review:

Thanks @andrewor14 for the PR, and for your patience in the review process! I left a number of comments but no major concerns from my side.

Regarding the two discussion points raised by @gau-nernst previously:

(1) I am open to discussing whether we should change how we expose self.weight in LoRALinear from just nn.Parameter to nn.Linear. I agree that the latter would be more module-swap friendly, but (for better or for worse) we also do not really design things to have module-swap-based methods as a first-class citizen. Also I think it is not for free, as it means that the key names of the base linear weight will now have an extra module name in between for LoRALinear (e.g. q_proj.base_weight.weight instead of q_proj.weight). On the surface this is quite trivial, but from a checkpointing perspective it's really nice that we currently have an exact match between nn.Linear and LoRALinear and hence can load weights from one directly into the other (with strict=False of course). So making this simple change would actually have a pretty big blast radius across all our various checkpointing logic. Lmk if these points make sense, happy to discuss further with you.

(2) I think Andrew's separation of QAT full and QAT LoRA into separate recipes makes sense and aligns with what we do for our other recipes. If you are (as I am) a bit concerned about the proliferation of many recipes with similar functionality, I have two comments: (1) we will be making a dedicated effort to start simplifying and consolidating the recipe files somewhat so that they look more like their earlier versions. And (2) I think we now have enough recipes that we can actually consider e.g. recipes/qat, recipes/knowledge_distillation, etc. subdirectories. Neither of these will happen tomorrow, but I would like both to happen eventually.
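
To make the checkpoint-key concern in point (1) concrete, here is a small illustrative sketch (using a stand-in class rather than torchtune's real LoRALinear) of why the current flat `weight` parameter lets an nn.Linear checkpoint load directly with strict=False, and what nesting the base weight under a submodule would change:

```python
import torch.nn as nn


class LoRALinearLike(nn.Module):
    """Stand-in for the current LoRALinear layout: the frozen base weight is a
    plain nn.Parameter named "weight", matching nn.Linear's key exactly."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8):
        super().__init__()
        self.weight = nn.Parameter(nn.Linear(in_dim, out_dim, bias=False).weight.detach())
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)


base = nn.Linear(16, 32, bias=False)
lora = LoRALinearLike(16, 32)

# Keys line up ("weight" -> "weight"): only the adapter keys are missing.
result = lora.load_state_dict(base.state_dict(), strict=False)
print(result.missing_keys)     # ['lora_a.weight', 'lora_b.weight']
print(result.unexpected_keys)  # []

# With a nested base module, the base key would instead become something like
# "base.weight", so this direct load would report "weight" as unexpected and
# checkpointing code would need to remap keys.
```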

Comment on lines +100 to +103:
```
- Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
  training. Currently we checkpoint both the adapter weights (trainable params only) and the
  complete merged weights (adapter weights added back to the base model). For more details
```
@ebsmothers (Contributor):

Hmm, I think we may need to update our lora_finetune_distributed.py docstring as well. Really we should make it clear that this behavior can be disabled by setting save_adapter_weights_only=True.

@andrewor14 (Author):

Sounds good, let's fix these recipes in a separate PR.

```python
self._sampler.set_epoch(curr_epoch)

pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
for idx, batch in enumerate(self._dataloader):
```
@ebsmothers (Contributor):

So no option to wait N steps before enabling fake quant in this recipe? Any particular reason for that? (To clarify I'm not saying we should add it, mainly just curious)

@andrewor14 (Author):

It's not super straightforward to add right now because this recipe uses the new general FakeQuantizedLinear, as opposed to the specific Int8DynActInt4WeightLinear class used by the qat_distributed recipe. I think we can add it separately.
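
For reference, a possible shape for such a delay, assuming a hypothetical boolean `fake_quant_enabled` attribute on fake-quantized modules (this is not the actual torchao/torchtune API, just a sketch of the idea):

```python
import torch.nn as nn


def set_fake_quant_enabled(model: nn.Module, enabled: bool) -> None:
    # Toggle fake quantization on any module exposing a (hypothetical)
    # `fake_quant_enabled` flag; modules without the flag are left untouched.
    for module in model.modules():
        if hasattr(module, "fake_quant_enabled"):
            module.fake_quant_enabled = enabled


# In the training loop, fake quant could then stay off for the first N steps:
#
#   if global_step == fake_quant_after_n_steps:
#       set_fake_quant_enabled(model, True)
```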

@gau-nernst (Contributor) commented:

@ebsmothers Regarding

> key names of the base linear weight will now have an extra module name in between for LoRALinear

In my proof-of-concept above, I handle this by adding the following hooks:

```python
        def load_state_dict_pre_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
            if isinstance(module, LoRALinear):
                state_dict[f"{prefix}base.weight"] = state_dict.pop(f"{prefix}weight")

        self.register_load_state_dict_pre_hook(load_state_dict_pre_hook)

        def state_dict_post_hook(module, state_dict, prefix, local_metadata):
            if isinstance(module, LoRALinear):
                state_dict[f"{prefix}weight"] = state_dict.pop(f"{prefix}base.weight")

        self.register_state_dict_post_hook(state_dict_post_hook)
```

From my testing it seems sufficient, though I might not cover all edge cases (FSDP2?).

We can discuss more in a separate issue/PR if you are open to it, so as not to hijack this PR about QAT + LoRA 😄. The main benefit is ease of injecting custom logic, such as QAT for this PR, INT8 matmul for #1552, or even FP8 matmul in the future. You probably know better than me what the potential issues are, but I think we can try to see if they can be handled nicely.

@ebsmothers (Contributor) commented:

@gau-nernst personally I have a bit of an aversion to state dict hooks, as @pbontrager can attest 😅. Mainly I find that they make code really hard to debug. Correct usage generally requires that a module's state dict is called exactly once and that its submodules are not accessed or modified in any other way; if either of these constraints is not satisfied, the user gets a very non-obvious error about some missing attribute, and it won't be at all clear where to go to fix it.

But I agree with your point about consolidating the discussion elsewhere (sounds like this PR wouldn't benefit as much from modifying LoRALinear's self.weight anyways). Maybe some lightweight RFC discussing pros and cons would be helpful (I can add my comments there as well), and we can tag other folks to get their thoughts too.

@andrewor14 andrewor14 force-pushed the try-qat-lora branch 3 times, most recently from 131542a to 618fdce Compare November 19, 2024 22:15
@andrewor14 andrewor14 force-pushed the try-qat-lora branch 2 times, most recently from 724fb11 to 0256e97 Compare November 19, 2024 22:28
@andrewor14 (Author) commented:

@ebsmothers any other comments?

@andrewor14 andrewor14 force-pushed the try-qat-lora branch 4 times, most recently from c0e9778 to 1181e39 Compare November 21, 2024 20:19
@ebsmothers (Contributor) left a review:

OK a couple more small comments but after that I think this should be good to go. A couple other requests before landing:

  1. Can you make sure this works with all our usual features (e.g. activation checkpointing, activation offloading)? I already ran with compile myself so no need to worry about that one
  2. You should also add it to the recipes table in our readme! That way people will know to try it out

```python
@@ -232,3 +235,12 @@ def test_quantized_state_dict(self, dtype):
    assert torch.allclose(
        lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data
    )

@pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None:
```
@ebsmothers (Contributor):

Thanks for adding this!

@andrewor14 andrewor14 force-pushed the try-qat-lora branch 2 times, most recently from 9f154fe to c3c0d4a Compare November 26, 2024 16:23
@andrewor14 (Author) commented:

> OK a couple more small comments but after that I think this should be good to go. A couple other requests before landing:
>
>   1. Can you make sure this works with all our usual features (e.g. activation checkpointing, activation offloading)? I already ran with compile myself so no need to worry about that one
>   2. You should also add it to the recipes table in our readme! That way people will know to try it out

Sounds good. I think I addressed all of the comments and also tested it with the features you mentioned. Please take another look, thanks!

@ebsmothers (Contributor) left a review:

Two more really minor comments. After that it's good to merge. Thanks so much for adding this!

@ebsmothers merged commit 437a8ff into pytorch:main on Nov 26, 2024
17 checks passed