Integrate INT8 mixed-precision from torchao 0.7 #1552
n00b question: would it be nice to add some sort of import guard to tell the user they need torchao > N for this? Torchao is not a requirement anymore
cc: @ebsmothers
Does this work with older GPUs? Does it work on CPU? Maybe we need something like this:
torchtune/torchtune/modules/attention_utils.py
Line 22 in ee343e6
Good question! I use Triton for this, so it will probably run on any GPU Triton supports (same as torch.compile). Though I think only Ampere (sm80) and above have INT8 tensor cores, so to be safe I think we should just guard for sm80 and above.
CPU is not supported. Technically it is possible, but I didn't add it in torchao since I can't reliably test it or see it being useful.
This works with PyTorch 2.4 (though FlexAttention requires PyTorch 2.5 🤣).
Btw, doesn't QAT also need some kind of guard like this? 🤔
Sorry, I completely missed this convo previously. Actually, now that we've asked users to install torchao manually, I am kinda taking a similar stance there to what we have with PyTorch: people should always be running on the latest stable version. So actually I claim that the first two lines of the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING check are not strictly necessary (fine to keep them in though, I don't have a strong preference). But anyway, that should hopefully answer the question about why we don't have similar guards for QAT.
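For reference, a minimal sketch of what such a guard could look like, combining the torchao-version and sm80 checks discussed above (the exact contents of the PR's _SUPPORTS_INT8_MIXED_PRECISION_TRAINING flag may differ; the version bound here is an assumption):

```python
# Hypothetical sketch of the guard flag; the real check in the PR may differ.
import torch

try:
    import torchao
    # Assumption: torchao >= 0.7 is required for INT8 mixed-precision training.
    _TORCHAO_OK = tuple(int(x) for x in torchao.__version__.split(".")[:2]) >= (0, 7)
except ImportError:
    _TORCHAO_OK = False

_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
    _TORCHAO_OK
    and torch.cuda.is_available()
    # INT8 tensor cores require Ampere (sm80) or newer.
    and torch.cuda.get_device_capability() >= (8, 0)
)
```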
Here is my understanding:
torchao has a nice API that does quantize_(model, config).
This PR creates a wrapper specific to INT8. The wrapper is needed for two reasons:
i) ignore lora_a/lora_b;
ii) we don't do nested configs.
This means that if we want to support FP8, BitNet, or any other torchao technique, we have to create a custom wrapper instead of doing quantize_(model, torchao_config).
IMO, this makes a lot of sense if:
a) every torchao technique interacts differently with torchtune, e.g. one doesn't work with offloading, another needs some extra work for checkpointing, etc., and we can't solve it with a config parser;
b) there are realistically only a couple of quantization methods we will use from torchao (e.g. INT8 and FP8).
But if that's not the case, then we should probably avoid having a custom torchtune wrapper per torchao technique.
I guess we had a similar situation with OptimizerCPUOffload. We ended up giving up on a custom torchtune wrapper and just instantiated it directly from the config.
In summary, instead of "Int8MixedPrecisionTrainingQuantizer", should we have a generalist torchtune "PrecisionTrainingQuantizer", just to handle LoRA, etc.?
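To make the trade-off concrete, here is a rough sketch (not the PR's actual wrapper) of how a thin layer over torchao's quantize_ could skip the LoRA matrices via filter_fn; the int8_mixed_precision_training import reflects torchao's prototype API and the helper names are made up:

```python
# Rough sketch only; helper names are made up, and the prototype import may
# change between torchao versions.
import torch.nn as nn
from torchao import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training


def _skip_lora(module: nn.Module, fqn: str) -> bool:
    # Only quantize plain Linear layers; leave the LoRA adapters untouched.
    return isinstance(module, nn.Linear) and "lora_a" not in fqn and "lora_b" not in fqn


def prepare_int8_mixed_precision(model: nn.Module) -> nn.Module:
    # quantize_ swaps the selected Linear weights for INT8 mixed-precision training subclasses.
    quantize_(model, int8_mixed_precision_training(), filter_fn=_skip_lora)
    return model
```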
I think PrecisionTrainingQuantizer is not needed atm. We can revisit it once there are more things like this (like you said, realistically I can only think of INT8 and FP8 for now, but that may change).
I don't think there is a way around us always having our own wrapper, since we need it to deal with LoRA/other details. Therefore, I think that we should add an "enabled" flag here and expose it in the config. That way the configs can advertise this functionality.
I also think that we should rename quantizer to "mixed_precision" or "mixed_precision_training", so it's self-explanatory.
Let me know if you need help with making the changes or if you disagree with the ideas above.
Thanks again for the PR! @ebsmothers might want to leave a few comments too.
For me, after these changes and the tests pass, I am good to merge.
Yeah, personally I like the change to rename quantizer -> mixed_precision (unless the idea is to later support full low-precision training using the same API, in which case it wouldn't make sense).
Edit: after talking with @felipemello1 about this, what do you think about adding the enabled flag to the config and checking it in the recipe? It'd look like this in the recipe. I think this is a good tradeoff because (a) we show the feature in our configs, (b) we don't conditionally define a no-op version of the quantizer (a pattern I really don't like) and are instead explicit about it in the recipe, and (c) it can be easily extended to other mixed-precision strategies in the future.
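The recipe snippet referenced above isn't reproduced here; purely as an illustration of the idea, a recipe-side check could look roughly like this (the mixed_precision/enabled field names are the proposals from this thread, and the component path, config.instantiate usage, and prepare() method are assumptions):

```python
# Sketch of the proposed recipe-side check; field names, paths, and methods are assumptions.
# The config could advertise the feature roughly as:
#   mixed_precision:
#     _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
#     enabled: True
from torchtune import config


def maybe_apply_mixed_precision(cfg, model):
    mp_cfg = cfg.get("mixed_precision", None)
    if mp_cfg is None or not mp_cfg.get("enabled", False):
        return model  # advertised in the config, but switched off
    # Assumes the quantizer constructor tolerates the `enabled` field and exposes prepare().
    quantizer = config.instantiate(mp_cfg)
    return quantizer.prepare(model)
```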
Yea, I agree the quantizer name right now can sound confusing, since the weight is actually not quantized (well, QAT also does not actually quantize the weight 😄).
Regarding the mixed_precision name, I'm thinking about what other training recipes out there would fall under this category (because it would be weird if we have mixed_precision but only one option). There are FP8 and FP16/BF16 mixed-precision training, but they follow a totally different UX: torch.autocast(dtype=torch.bfloat16). For FSDP2, it would be MixedPrecisionPolicy(param_dtype=torch.bfloat16). This is a very different UX (and mechanism) from INT8 and FP8, and it would be hard to unify the torchtune user-facing API (and I'm not sure torchtune wants to add this option? Btw, there can be small gaps between full BF16 training (current torchtune) and BF16 mixed-precision training, especially when the LR is small. It would take some effort for someone to investigate this 😅).
There are also MX FP6/FP4 in the new NVIDIA GPUs, but that's too far in the future. Considering that at least we can have INT8 and FP8 mixed-precision training, using the mixed_precision name seems reasonable!
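For concreteness, a small sketch of the two BF16 mixed-precision UXs mentioned above (the FSDP2 import path varies across PyTorch versions, and the actual fully_shard call is omitted since it needs an initialized process group):

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)

# 1) Autocast-style BF16 mixed precision: weights stay FP32, compute runs in BF16
#    inside the context manager.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(torch.randn(8, 1024))

# 2) FSDP2-style BF16 mixed precision: a policy passed when sharding the model,
#    e.g. fully_shard(model, mp_policy=policy) (not called here; it requires an
#    initialized process group, and the import path differs in older releases).
from torch.distributed.fsdp import MixedPrecisionPolicy

policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
```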
Would it be too different from INT8? Or could we leverage the same API and replace the quantization module?
BTW, does FP8 work with FSDP? In that case, is there anything that makes it special, since INT8 seems not to be compatible?
@felipemello1 Currently the torchao API is different: https://github.com/pytorch/ao/tree/main/torchao/float8. But it can be hidden away by torchtune (e.g. a separate Quantizer class, or some other name like a MixedPrecisionTrainer class).
Both FP8 (by the FP8 folks) and INT8 (by me) support FSDP2. The issue is with how torchtune handles distributed weight loading, which is not trivial to handle. FP8 module swap might not have this problem though.
Neither 3B nor 8B full single-device is working for me on an A100.
torch 2.6.0.dev20241107+cu124
torchao 0.7.0.dev20241111+cu124
torchtune 0.0.0 /data/users/felipemello/torchtune
torchvision 0.20.0.dev20241107+cu124
I also tried with torch 2.5.1.
I think the issue is with recompiling. It takes several minutes on this step (10+ min). I wonder if it's recompiling non-stop until it runs out of memory and errors with RuntimeError: std::bad_alloc.
[8B and 3B outputs collapsed]
PS: for 3B we use the TiedLinear implementation:
torchtune/torchtune/modules/tied_linear.py
Line 28 in e1caa9f
I printed the shape and dtype, not that it matters.
I also checked for 3B, and filter_fn returns False for it.
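As a quick debugging aid (not part of the PR), something like this can show which FQNs a given filter_fn would actually select, e.g. to confirm the tied output projection is skipped:

```python
# Debugging helper (not from the PR): list the modules a filter_fn would quantize.
import torch.nn as nn


def report_quantized_modules(model: nn.Module, filter_fn) -> None:
    for fqn, module in model.named_modules():
        if filter_fn(module, fqn):
            print(f"will quantize: {fqn} ({type(module).__name__})")
```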
Regarding std::bad_alloc, I think it's an issue with recent PyTorch nightlies. The ao CI has been seeing it too: pytorch/ao#1229. Currently CI pins to 20241101 (https://github.com/pytorch/ao/blob/b2642fb33e360ffe478fe19665b1c4efd80537c6/.github/workflows/regression_test.yml#L64). Can you try with this version?
I don't exactly recall, but I think I usually had to kill the first run (it kinda hangs?), and then the subsequent run is fast. Didn't think much of it. I will try to reproduce it again in a fresh environment to see if my memory serves me right.
OK, I additionally set quantize_(set_inductor_config=False) to avoid exhaustive torch.compile tuning. Can you try again? It seemed fast to me (<2 min). It worked with both Llama3.1-8B and Llama3.2-3B, with torch==2.5.1+cu121 and torchao==0.7.0.dev20241112+cu121.
The command:
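(The exact command is omitted above; for reference, the relevant torchao call would look roughly like this, with the INT8 config import assumed from torchao's prototype API:)

```python
# Not the omitted command, just a sketch of the quantize_ call with inductor
# autotuning disabled; the prototype import is an assumption.
from torchao import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training


def apply_int8_mixed_precision(model):
    # set_inductor_config=False skips torchao's recommended inductor settings
    # (e.g. exhaustive autotuning), which caused very long first-run compile times here.
    quantize_(model, int8_mixed_precision_training(), set_inductor_config=False)
    return model
```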