Integrate INT8 mixed-precision from torchao 0.7 #1552
Conversation
It's very simple and it looks great! Thanks for the PR!
My two cents:
- We need some tests to make sure it works with compile, AC, offloading (not landed yet), optimizer in backward (I guess?), etc.
- The configs can be updated in bulk using something like this dummy script: https://gist.github.com/felipemello1/5f2002433c6da3a21f33d6cdf82e702a
Let me know if you want me to help with any of these
n00b question: would it be nice to add some sort of import guard to tell the user they need torchao > N for this? Torchao is not a requirement anymore
cc: @ebsmothers
does this work with older GPUs? Does it work on CPU? Maybe we need something like this:
_SUPPORTS_INT8_MIXED_PRECISION = (
    torch_version_ge("2.5.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (7, 5)
)
Good question! I use Triton for this, so it will probably run on any GPU Triton supports (same as torch.compile). Though I think only Ampere (sm80) and above have INT8 tensor cores. To be safe, I think we just guard for sm80 and above.
CPU is not supported. Technically it is possible, but I didn't add it in torchao since I can't reliably test it / don't see it as useful.
This works with PyTorch 2.4 (though FlexAttention requires PyTorch 2.5 🤣).
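A minimal sketch of what such a guard could look like with the sm80 cutoff and the PyTorch 2.4 floor mentioned above (the torch_version_ge import path is an assumption; adjust to wherever the helper lives in torchtune):

```python
import torch

from torchtune.utils import torch_version_ge  # import path is an assumption

# Gate on PyTorch 2.4+ (per the comment above) and Ampere (sm80) or newer,
# since INT8 tensor cores are only available on sm80 and above.
_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
    torch_version_ge("2.4.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (8, 0)
)
```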
Btw doesn't QAT also need some kind of guards like this? 🤔
Sorry, I completely missed this convo previously. Actually, now that we've asked users to install torchao manually, I am kinda taking a similar stance there to what we have with PyTorch: people should always be running on the latest stable version. So actually I claim that the first two lines of the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING check are not strictly necessary (fine to keep them in though, I don't have a strong preference). But anyway, that should hopefully answer the question about why we don't have similar guards for QAT.
Memory is not expected to change. In this scheme, weights and activations are quantized dynamically to INT8 (with scaling, so that we can do INT8 matmul), then the result is scaled back to BF16. Activations and weights are still stored in BF16 throughout the model. This is the same strategy as the current one.

To reduce memory, either the weights or the activations must stay in low precision. Using only low-bit weights for training is a bit challenging (not impossible, but there will be convergence/accuracy issues). There are some research works on using low-bit activations (FP8/INT8), but we don't have that in torchao yet.
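To make the memory argument concrete, here is a rough reference implementation (plain PyTorch ops, not torchao's Triton kernel) of the dynamic INT8 matmul described above; the inputs and output stay BF16, and the INT8 tensors exist only inside the call:

```python
import torch

def int8_mm_dynamic(a_bf16: torch.Tensor, b_bf16: torch.Tensor) -> torch.Tensor:
    # per-row scales for A (M, 1) and per-column scales for B (1, N)
    a_scale = a_bf16.abs().amax(dim=1, keepdim=True).clamp_min(1e-6) / 127
    b_scale = b_bf16.abs().amax(dim=0, keepdim=True).clamp_min(1e-6) / 127

    # dynamic quantization: the INT8 values are transient, never stored in the model
    a_i8 = (a_bf16 / a_scale).round().clamp(-128, 127).to(torch.int8)
    b_i8 = (b_bf16 / b_scale).round().clamp(-128, 127).to(torch.int8)

    # INT8 matmul accumulated in INT32 (torchao codegens a fast kernel for this via Triton)
    acc = a_i8.to(torch.int32) @ b_i8.to(torch.int32)

    # rescale the result back to BF16
    return (acc.to(torch.float32) * a_scale.float() * b_scale.float()).to(torch.bfloat16)

# usage: both operands and the output are BF16, so memory matches a plain BF16 matmul
out = int8_mm_dynamic(
    torch.randn(16, 64, dtype=torch.bfloat16),
    torch.randn(64, 32, dtype=torch.bfloat16),
)
```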
I tested it on multiple other models. Main issues:
Let's address these and I am good to merge. I can edit all configs in a follow-up PR :) Thanks again @gau-nernst :)
Sorry, I realize that you mentioned it about LoRA, but I didn't connect the dots. Let me talk to Evan and get back to you. Maybe we can land it for full finetuning only, since this is working well, and we can follow up about LoRA. I will reply here on Monday.
Sorry for the delay here @gau-nernst and @felipemello1. After looking through the code again I think this is clean enough to move forward with (subject to Felipe's comment), even if we only support full finetune. I need to do a bit more research to have a more informed opinion on the options laid out for handling LoRA.

Regarding option (a): I took a look at the implementation in main...gau-nernst:qlora and while I am still averse to state dict hooks, it is quite clean. Separately, I am still trying to understand the pros and cons of module-swap vs tensor subclass -- I know tensor subclass needs custom handling for state dict, but are there drawbacks to module-swap (I previously thought there were potentially challenges with FSDP, but actually it seems like that's not the case)? Also, the point raised in pytorch/ao#1179 about composability is not clear to me: isn't this a function of one quantization method (NF4) using tensor subclass and another (INT8) using module swap? If I look at pytorch/ao#987 it claims that composability of tensor subclass is better (at least with other tensor subclasses). Lmk if I'm missing the point here though. Again, I don't want to block on any of this because unfortunately even the pretty trivial change to

So let's move forward with this for full finetune only (I will give a proper review shortly) and separately we can create an RFC to discuss pros and cons of a possible LoRA redesign.
@ebsmothers Thanks for your feedback. I will add "raise an error if mixed_precision.enabled=True but compile=False or packed=False", as commented by Felipe, soon. Regarding module-swap vs tensor subclass: after all, they are just different ways to inject custom logic into an existing model. Currently, we have 2 general use cases:
In terms of "composability", I don't think there is a clear answer about which one is better, because it largely depends on what kind of functionalities are being composed together. The point about "composability" that I mentioned in pytorch/ao#1179 is specifically about custom weight-only quantization + low-bit compute, e.g.
This is possible since for weight-only quant, we don't have custom

Of course it's also possible to use INT8-matmul-subclass + NF4-subclass: the INT8-matmul-subclass can hold the NF4-subclass internally (similar to how dynamic activation quant is implemented in torchao). However, this can be a bit messy: the INT8-matmul-subclass may need to be aware of the NF4-subclass implementation, multiple levels of indirection are hard to debug, and possibly many other issues. So at least for this use case of "weight-only quant + low-bit compute", delegating one function to a subclass and another to module-swap seems clearer. Hope this somewhat clarifies the options presented here (and hope that it didn't confuse you further 😅)
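As a toy illustration of the module-swap side of this trade-off (not torchao's actual implementation), swapping the nn.Linear module changes the outer op while leaving the weight tensor, which could itself be a weight-only-quant subclass, untouched:

```python
import torch
from torch import nn
import torch.nn.functional as F

class Int8MixedPrecisionLinear(nn.Linear):
    """Toy swapped-in module: same parameters, different matmul path in forward()."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real implementation would route through a dynamic INT8 matmul here;
        # we call F.linear so the sketch stays runnable without custom kernels.
        return F.linear(x, self.weight, self.bias)

def swap_linears(module: nn.Module) -> None:
    """Replace every nn.Linear child in place, reusing the existing weight/bias."""
    for name, child in module.named_children():
        if type(child) is nn.Linear:
            new_child = Int8MixedPrecisionLinear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            # Reuse the original parameters; a weight-only-quant subclass would pass
            # through here unchanged, which is why the two approaches compose.
            new_child.weight = child.weight
            new_child.bias = child.bias
            setattr(module, name, new_child)
        else:
            swap_linears(child)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))
swap_linears(model)
```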
@gau-nernst @felipemello1 are we aiming to get this in for our next release? I left a few comments but I think (modulo a final round of testing, which I'm happy to help with) the remaining gap (at least for full finetune) to land should be pretty small.
I addressed all of the remaining comments. Please take another look in case I missed anything. Thank you. Also added
Hi @gau-nernst I took a look through the code and no concerns from me, everything looks good. I think @felipemello1 landed a bunch of changes today that will cause merge conflicts on this PR, if it's too much of a headache just let us know and he can do the merge then push to your PR. Last thing: I was trying to test out the perf gains one final time before stamping but can't seem to find a good PyTorch nightly to run on (seems like lots of issues with nightlies lately). Lmk if you have a good stable env that I can test with (until then I will blindly try out various nightlies 😅 )
update: ok I got a nightly version that "works". Unfortunately it doesn't look great:
Repro:
Baseline command:
tune run full_finetune_single_device --config llama3/8B_full_single_device dataset.packed=True compile=True \
tokenizer.max_seq_len=2048 metric_logger=torchtune.training.metric_logging.WandBLogger \
metric_logger.project=mp-tests metric_logger.name=no-mp-latest
Mixed precision command:
tune run full_finetune_single_device --config llama3/8B_full_single_device mixed_precision.enabled=True \
dataset.packed=True compile=True tokenizer.max_seq_len=2048 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=mp-tests \
metric_logger.name=with-mp-latest
My env. Like I said above I had a pretty hard time finding a recent working nightly (and I don't think ao 0.7 is quite in stable yet), so maybe it's just something messed up in my environment. But a sanity check here would be much appreciated
@ebsmothers PyTorch 2.5 should work; I have been testing with PyTorch 2.5, so nightly is not needed. I can confirm that INT8 mixed-precision does not play well with bnb 8-bit AdamW: loss will explode shortly after training starts, and I'm not sure why. Switching to the plain PyTorch AdamW seems to solve the issue. In terms of speed, indeed the batch size must be sufficiently large to see a speedup. I ran a few runs:
We probably should add these 2 notes somewhere.
Perhaps you can advise me on an appropriate place to put these notes?
@gau-nernst thanks for the detailed analysis. On a new env with PyTorch 2.5 and ao nightlies I see similar results to you for batch size 8, max seq len 2048. I have a few more comments -- the first one should be handled in this PR but the others are more for follow-ups.
Anyways, none of this is to push back too much on the PR. I think once (1) is addressed we should land it. It's an awesome feature that most other finetuning frameworks don't have, and I'm really excited to support it. But given that we are one of the first to have it, it's up to us to set the right expectations for the community on how to use it (hence the importance of (2) and (3)).
Also, regarding 1B, I tested it and indeed A100 does not see significant speedup, but on my personal GPU, the speedup is still observed at 4x2048.
On A100
On 4070Ti SUPER
My GPU OOMs with larger batch sizes (only 16GB 😢). A 4090 (24GB VRAM) should have the same speedup characteristics. This makes sense because INT8 tensor cores in consumer cards are rated at 4x faster than BF16/FP16 tensor cores, while enterprise GPUs are only rated at a 2x speedup. But it may indicate that there is room for improvement on A100. Anyway, this is to show that the speedup differs across GPUs. Perhaps once we work on (2), we can do more extensive end-to-end benchmarks across model sizes and GPUs.
@gau-nernst on (1): this sounds good to me. The one minor point I would change if possible is to pass each field explicitly to a utility like

Re (2): yes, I think the ao readme page is great here. If we wanted to expand upon it in our torchtune docs, there are two potential directions we could go: lower-level (i.e. getting into the underlying implementation in more detail) or more user-focused (i.e. discussing different feature interactions, giving perf numbers, running experiments, etc.). If we combine what you already have with one or both of these, it'd make a nice standalone page on our live docs. But again, we can figure out the details here in a follow-up. So yeah, the ao documentation is already really clear in terms of usage; it's just a matter of seeing if there's a nice narrative we can provide around the feature to get people even more aware and interested to try it.

Thanks for the benchmarking results! This is interesting and not something I would have expected. And totally agree -- this is exactly the kind of thing that would be great to discuss in more detail in (2).
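A concrete (hypothetical) shape for such a utility with explicit fields; every name below is made up for illustration, and the error conditions follow the check Felipe asked for earlier (mixed_precision.enabled=True with compile=False or packed=False):

```python
def validate_int8_mixed_precision_flags(
    mixed_precision_enabled: bool,
    compile_enabled: bool,
    dataset_packed: bool,
) -> None:
    """Hypothetical helper: fail fast when INT8 mixed-precision cannot give a speedup."""
    if not mixed_precision_enabled:
        return
    if not compile_enabled:
        raise ValueError(
            "INT8 mixed-precision training requires compile=True so that the "
            "dynamic quantization code can be codegen'd."
        )
    if not dataset_packed:
        raise ValueError(
            "INT8 mixed-precision training requires dataset.packed=True so that "
            "seq_len is static."
        )
```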
I like that :)
Context
What is the purpose of this PR? Is it to
Recent INT8 mixed-precision work in torchao shows very promising results.
Known major limitations:
- Requires torch.compile() to enjoy speedup (to codegen efficient dynamic quantization code)
- Requires that seq_len is static
- Does not work with training.load_from_full_model_state_dict() -> cannot integrate with distributed recipes atm. -> solved by using module-swap UX instead. Pending "Add module-swap UX for INT8 mixed-precision training" (pytorch/ao#1179)

See https://github.com/pytorch/ao/tree/v0.5.0/torchao/prototype/quantized_training#int8-mixed-precision for more details.
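For reference, the tensor-subclass UX described in that README looks roughly like the sketch below (based on the torchao 0.5 prototype; treat the exact import paths and config fields as assumptions and double-check them against the installed version):

```python
from torch import nn

from torchao.quantization import quantize_
from torchao.prototype.quantized_training import (
    Int8MixedPrecisionTrainingConfig,
    int8_mixed_precision_training,
)

# any BF16 model with nn.Linear layers, e.g. a torchtune TransformerDecoder
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).bfloat16()

# default: INT8 matmul for all three matmuls (output, grad_input, grad_weight)
quantize_(model, int8_mixed_precision_training())

# alternatively, keep some of the matmuls in BF16 via the config, e.g.:
# config = Int8MixedPrecisionTrainingConfig(output=True, grad_input=True, grad_weight=False)
# quantize_(model, int8_mixed_precision_training(config))
```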
For now, I only added the code to show the necessary changes. I'm open to suggestions on how to expose this in torchtune. One idea from mine:
- Add an int8_mixed_precision flag (similar to the compile flag). This will be a boolean handled in _setup_model() -> repeated code for each recipe
-> UPDATE: from previous feedback, add a new flag mixed_precision

Some concerns:
- Int8MixedPrecisionTrainingConfig (see doc): should we expose it to torchtune's users? From my testing, the default config works well. There might be more knobs to customize in the future too.
- Int8MixedPrecisionTrainingQuantizer does not go through the quantize_() API (though they can be re-implemented to do so).
- LoRALinear will always call F.linear() on the NF4 weight. If we make the base weight in LoRALinear a separate nn.Linear module (instead of a plain nn.Parameter()), then we can swap the linear module to change the outer op (see the sketch below).

These concerns can be addressed in the future I think, when torchao's training subclasses become more mature/stable.
Note: I can't test with 4090 since FlexAttention errors out on 4090
It's pretty strange since it works fine for another repo of mine 🤔.
Changelog
What are the changes made in this PR?
- Integrate INT8 mixed-precision from torchao 0.7 (previously 0.5)

Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- pre-commit hooks and linters (installed via pre-commit install)
- unit tests via pytest tests
- recipe tests via pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models
- Llama3.1-8B single device A100: 40% speedup. torch=2.5.0.dev20240911, torchao=0.5.0
- Llama3.1-8B FSDP2 2x A100: 24% speedup. torch=2.5.1, pytorch/ao#1179
- Llama3.1-8B single device A100 LoRA: 50% speedup. torch==2.6.0.dev20240914, torchao=0.5.0
- Llama3.2-1B single device 4070Ti SUPER QLoRA: 60% speedup. torch==2.6.0.dev20241102+cu124, pytorch/ao#1179. Proof-of-concept only since it requires quite significant changes to the LoRALinear class. See main...gau-nernst:qlora