
[low-bit optim] Add COAT optimizer #1190

Open
gau-nernst opened this issue Oct 29, 2024 · 5 comments

@gau-nernst
Collaborator

gau-nernst commented Oct 29, 2024

Paper: https://arxiv.org/abs/2410.19313
Code: https://github.com/NVlabs/COAT (not available yet)

It seems like we already have most of the building blocks. The only new piece of logic is "dynamic range expansion".

We can start implementing it first, then wait for the official code release for numeric checks.
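
For whoever picks this up, here is my rough reading of what "dynamic range expansion" does (a sketch only; the exact formula for the per-group exponent `k` should be checked against the official code once it is released):

```python
import math
import torch

# Sketch of dynamic range expansion as I read the paper: remap each
# quantization group with f(x) = sign(x) * |x|^k so that the group's
# dynamic range better fills E4M3's representable range.
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448
E4M3_MIN_NORMAL = 2.0 ** -6                      # smallest normal E4M3 value
E4M3_DYN_RANGE = E4M3_MAX / E4M3_MIN_NORMAL

def expand(x: torch.Tensor, eps: float = 1e-12):
    absx = x.abs().clamp_min(eps)
    group_range = max((absx.max() / absx.min()).item(), 1.0 + 1e-6)
    # exponent that maps the group's range onto E4M3's range (assumed >= 1)
    k = max(math.log(E4M3_DYN_RANGE) / math.log(group_range), 1.0)
    return x.sign() * absx.pow(k), k

def contract(x_expanded: torch.Tensor, k: float):
    # inverse transform, applied after dequantizing the optimizer state
    return x_expanded.sign() * x_expanded.abs().pow(1.0 / k)
```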

@MirMustafaAli

I would like to work on this @gau-nernst!!

@gau-nernst
Collaborator Author

@MirMustafaAli Go ahead and submit a PR 😄. Let me know if you face any problems.

@MirMustafaAli

@gau-nernst Which part of the repo should I look at to implement "dynamic range expansion"? From my understanding of the repo, it should live in the float8 folder, since COAT targets the Hopper architecture and uses the float8 dtype. Any pointers, PRs, or reference methods I can follow would be very helpful.

@gau-nernst
Collaborator Author

You can park it under torchao/prototype/low_bit_optim. The float8/ folder is more for training stuff (fp8 matmul).

You can extend our current OptimStateFp8. See https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/prototype/low_bit_optim/subclass_fp8.py. I think we only need to change the quantize_fp8 function.
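
Roughly, the change would be to apply the expansion per block before the usual absmax scaling, and to store the per-block exponent alongside the scale for dequantization. An illustrative sketch (the actual `quantize_fp8` signature and block handling in `subclass_fp8.py` may differ, and the expansion formula is my reading of the paper):

```python
import math
import torch

DTYPE = torch.float8_e4m3fn

def quantize_fp8_with_expansion(input: torch.Tensor, block_size: int):
    # illustrative per-block FP8 quantizer: expand each block's dynamic range
    # (sign(x) * |x|^k) before the usual absmax scaling. Assumes numel is
    # divisible by block_size.
    shape = input.shape
    input = input.view(-1, block_size)

    absx = input.abs().clamp_min(1e-12)
    e4m3_range = torch.finfo(DTYPE).max / 2.0 ** -6  # max / smallest normal
    block_range = (absx.amax(-1) / absx.amin(-1)).clamp_min(1.0 + 1e-6)
    k = (math.log(e4m3_range) / torch.log(block_range)).clamp_min(1.0)
    expanded = input.sign() * absx.pow(k.view(-1, 1))

    scale = expanded.abs().amax(-1).clamp_min(1e-12) / torch.finfo(DTYPE).max
    codes = (expanded / scale.view(-1, 1)).to(DTYPE)
    # k (or 1/k) has to be stored next to scale so dequantize can invert the expansion
    return codes.view(shape), scale, k
```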

Another option is to create a separate optimizer. See https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/prototype/low_bit_optim/adam.py. You can wrap all of the logic in a functional way (see single_param_adam() in the link above), and add the boilerplate code for the optimizer (e.g. init optim states, call the functional optim step with torch.compile...)
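
A very rough skeleton of what the separate-optimizer route could look like (names like `CoatAdam` / `single_param_coat_adam` are placeholders, not existing torchao APIs, and the real boilerplate in `adam.py` is more involved, e.g. it keeps `step` as a tensor to avoid recompiles):

```python
import torch
from torch.optim import Optimizer

def single_param_coat_adam(p, grad, step, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    # functional update for one parameter; in the real thing exp_avg / exp_avg_sq
    # would be FP8 subclasses that requantize (with dynamic range expansion) on write
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.lerp_(grad.square(), 1 - beta2)
    denom = (exp_avg_sq / (1 - beta2 ** step)).sqrt().add_(eps)
    p.add_(exp_avg / (1 - beta1 ** step) / denom, alpha=-lr)

_compiled_step = torch.compile(single_param_coat_adam, fullgraph=True)

class CoatAdam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        super().__init__(params, dict(lr=lr, betas=betas, eps=eps))

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if not state:  # lazy init of the (quantized) optim states
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)     # would be the FP8 subclass
                    state["exp_avg_sq"] = torch.zeros_like(p)  # with range expansion
                state["step"] += 1
                _compiled_step(
                    p, p.grad, state["step"],
                    state["exp_avg"], state["exp_avg_sq"],
                    group["lr"], *group["betas"], group["eps"],
                )
```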

@MirMustafaAli

MirMustafaAli commented Nov 5, 2024

Thanks!! I'll work on it following your advice.
