Integrate INT8 mixed-precision from torchao 0.7 #1552

Open

gau-nernst wants to merge 49 commits into main
Conversation


@gau-nernst gau-nernst commented Sep 12, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Recent INT8 mixed-precision work in torchao shows very promising results.

  • Single device A100 -> ~40% speedup
  • Single device 4090 -> ~70% speedup (consumer 4000-series GPUs see unusually large speedups, which is nice)
  • Works with FSDP2

Known major limitations

  • Requires torch.compile() to see the speedup (so that efficient dynamic quantization code is codegen'd)
  • Input sizes should not vary too much, since each new shape triggers autotuning of the Triton INT8 matmul kernel -> this only works well with PackedDataset w/ FlexAttention, since seq_len is static.
  • Does not work with training.load_from_full_model_state_dict() -> cannot integrate with distributed recipes atm -> solved by using the module-swap UX instead. Pending Add module-swap UX for INT8 mixed-precision training ao#1179

See https://github.com/pytorch/ao/tree/v0.5.0/torchao/prototype/quantized_training#int8-mixed-precision for more details.
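For reference, applying it directly with torchao looks roughly like the sketch below (paraphrased from the prototype README linked above; exact import paths and config fields may differ between torchao versions):

```python
# Rough sketch based on the torchao prototype README linked above; treat the exact
# imports and config fields as illustrative rather than authoritative.
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import (
    Int8MixedPrecisionTrainingConfig,
    int8_mixed_precision_training,
)

# Toy stand-in for the real model; in practice this is the torchtune model in BF16 on CUDA.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).bfloat16().cuda()

# Default: run the output, grad_input, and grad_weight matmuls in INT8.
quantize_(model, int8_mixed_precision_training())

# Or keep selected matmuls in BF16 via the config object.
config = Int8MixedPrecisionTrainingConfig(output=True, grad_input=True, grad_weight=False)
# quantize_(model, int8_mixed_precision_training(config))
```

This PR wraps the same machinery behind a torchtune-side quantizer (see the UX section below) and, per the limitation above, ultimately relies on the module-swap UX from ao#1179 rather than the tensor-subclass path.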

For now, I only added the code to show the necessary changes. I'm open to suggestions on how to expose this in torchtune. One idea of mine:

  • Add a global config flag int8_mixed_precision (similar to compile flag). This will be a boolean
  • Handle it inside _setup_model() -> repeated code for each recipe
    -> UPDATE: from previous feedback, add a new flag mixed_precision

Some concerns:

  • It's possible to customize INT8 mixed-precision via Int8MixedPrecisionTrainingConfig (see doc). Should we expose it to torchtune's users? From my testing, the default config works well. There might be more knobs to customize in the future too.
    • UPDATE: expose all options via Int8MixedPrecisionTrainingQuantizer
  • Ability to extend to other torchao subclasses? e.g. Float8 and NF4 (right now they don't use the quantize_() API, though they could be re-implemented to do so).
    • UPDATE: the better question is how to compose this with QLoRA (i.e. NF4). LoRALinear will always call F.linear() on the NF4 weight. If we make the base weight in LoRALinear a separate nn.Linear module (instead of a plain nn.Parameter()), then we can swap the linear module to change the outer op.

These concerns can be addressed in the future I think, when torchao's training subclasses become more mature/stable.

Note: I can't test with 4090 since FlexAttention errors out on 4090

triton.runtime.errors.OutOfResources: out of resource: shared memory

It's pretty strange since it works fine for another repo of mine 🤔.

Changelog

What are the changes made in this PR?

Integrate INT8 mixed-precision from torchao 0.7

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

mixed_precision._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer mixed_precision.enabled=True
  • I did not change any public API;
  • I have added an example to docs or docstrings;
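To sketch what these overrides translate to inside a recipe, a hypothetical snippet (names mirror the config entry above; this is illustrative, not the exact code in this PR):

```python
# Hypothetical illustration of the recipe-side wiring; not the exact code in this PR.
from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer


def maybe_apply_mixed_precision(model, enabled: bool):
    if not enabled:
        return model
    quantizer = Int8MixedPrecisionTrainingQuantizer()
    # prepare() swaps the model's nn.Linear layers so their matmuls run through
    # torchao's dynamic INT8 path; the weights themselves stay in BF16.
    return quantizer.prepare(model)
```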

Llama3.1-8B single device A100 40% speedup. torch=2.5.0.dev20240911, torchao=0.5.0

tune run full_finetune_single_device --config llama3_1/8B_full_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True


Llama3.1-8B FSDP2 2x A100 24% speedup. torch=2.5.1, pytorch/ao#1179

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full tokenizer.max_seq_len=8192 dataset.packed=True optimizer.fused=True compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True


Llama3.1-8B single device A100 LoRA 50% speedup. torch==2.6.0.dev20240914, torchao=0.5.0

tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device tokenizer.max_seq_len=8192 dataset.packed=True compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True gradient_accumulation_steps=1


Llama3.2-1B single device 4070Ti SUPER QLoRA 60% speedup. torch==2.6.0.dev20241102+cu124, pytorch/ao#1179. Proof-of-concept only since it requires quite significant changes to the LoRALinear class. See main...gau-nernst:qlora

tune run lora_finetune_single_device --config llama3_2/1B_qlora_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True batch_size=1 enable_activation_checkpointing=True


@felipemello1 felipemello1 left a comment

It's very simple and it looks great! Thanks for the PR!

My two cents:

  1. We need some tests to make sure it works with compile, AC, offloading (not landed yet), optimizer in backward (I guess?), etc.
  2. The configs can be updated in bulk using something like this dummy script: https://gist.github.com/felipemello1/5f2002433c6da3a21f33d6cdf82e702a

Let me know if you want me to help with any of these

Contributor

n00b question: would it be nice to add some sort of import guard to tell the user they need torchao > N for this? Torchao is not a requirement anymore

cc: @ebsmothers

Contributor

does this work with older GPUs? Does it work on CPU? Maybe we need something like this:

import torch
from torchtune.utils import torch_version_ge  # assuming the existing version-check helper lives here

_SUPPORTS_INT8_MIXED_PRECISION = (
    torch_version_ge("2.5.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (7, 5)  # Turing (sm75) or newer
)

Contributor Author

Good question! I use Triton for this so it will probably run on any GPUs Triton supports (same as torch.compile). Though I think only Ampere (sm80) and above has INT8 tensor cores. To be safe, I think we just guard for sm80 and above.

CPU is not supported. Technically it is possible, but I didn't add it in torchao since I can't reliably test it / see it useful.

This works with PyTorch 2.4 (though FlexAttention requires PyTorch 2.5 🤣).

Contributor Author

Btw doesn't QAT also need some kind of guards like this? 🤔

Contributor

Sorry, I completely missed this convo previously. Actually, now that we've asked users to install torchao manually, I am kinda taking a similar stance there to what we have with PyTorch: people should always be running on the latest stable version. So I claim that the first two lines of the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING check are not strictly necessary (fine to keep them in though, I don't have a strong preference). But anyway, that should hopefully answer the question about why we don't have similar guards for QAT.

@gau-nernst

@felipemello1

> can you help me understand why memory is not changed?

Memory is not expected to change. In this scheme, weights and activations are quantized dynamically to INT8 (with scaling, so that we can do INT8 matmul), then the result is scaled back to BF16. Activations and weights are still in BF16 throughout the model. This is the same strategy as the current torchao.float8 (in fact, in some bad cases, torchao.float8 can consume more memory in FSDP2 due to some autograd shenanigans).
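To make that concrete, here is a toy sketch of a dynamically-quantized INT8 matmul (not torchao's Triton kernel, which fuses these steps; torch._int_mm is a private PyTorch op with its own shape restrictions, used here purely for illustration):

```python
import torch


def int8_mm_dynamic(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Toy dynamic INT8 matmul: a is (M, K) BF16, b is (K, N) BF16."""
    # Symmetric absmax scales: row-wise for a, column-wise for b.
    a_scale = a.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / 127
    b_scale = b.abs().amax(dim=0, keepdim=True).clamp(min=1e-12) / 127
    a_i8 = torch.round(a / a_scale).clamp(-127, 127).to(torch.int8)
    b_i8 = torch.round(b / b_scale).clamp(-127, 127).to(torch.int8)
    # INT8 x INT8 -> INT32 accumulation on the INT8 tensor cores, then rescale to BF16.
    out_i32 = torch._int_mm(a_i8, b_i8)
    return (out_i32.float() * a_scale * b_scale).to(torch.bfloat16)
```

The INT8 values exist only transiently inside the matmul (hence no memory saving), and the real implementation wraps this in a custom autograd.Function so the backward matmuls can also run in INT8.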

To reduce memory, either weight or activations must stay in low-precision. Using only low-bit weight for training is a bit challenging (not impossible, but there will be convergence/accuracy issues). There are some research works on using low-bit activations (FP8/INT8), but we don't have that in torchao yet. e.g.

@felipemello1 felipemello1 commented Nov 23, 2024

I tested it on multiple other models. Main issues:

  1. For LoRA/QLoRA, there is no WPS improvement, which is different from your results. Maybe there is a bug?
    torch 2.5.1
    torchao 0.7.0.dev20241121+cu124

  2. compile and packed are prerequisites. We should add a check to the recipe init and raise an error if mixed_precision.enabled=True but compile=False or packed=False.

  3. Let's also add the changes to lora_distributed, if you don't mind.

Let's address these and I am good to merge. I can edit all configs in a follow-up PR :)

Thanks again @gau-nernst :)

@gau-nernst

  1. This PR does not support LoRA/QLoRA. Because we now use module-swap to inject the INT8 matmul, we can't replace LoRALinear with the custom INT8 Linear module. Possible solutions:
    a. Make the base weight of LoRALinear a separate nn.Linear module (rough sketch after this list). I did it here to demonstrate: main...gau-nernst:qlora (you only need to look at torchtune/modules/peft/lora.py). The QLoRA result in this PR description was obtained using this branch (as stated). However, Evan has some reservations, as outlined in Add support for QAT + LoRA #1931 (review).
    b. Another option is to propagate the INT8 flag to LoRALinear, and have LoRALinear directly call torchao.prototype.quantized_training.int8_mixed_precision._Int8MixedPrecisionTrainingLinearFunction.apply(). This is also not exactly ideal. Open to other options.
  2. ok
  3. Since this PR doesn't support LoRA/QLoRA, this is not necessary.
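Rough sketch of option (a), purely to illustrate the shape of the change (names simplified; see main...gau-nernst:qlora for the actual diff):

```python
# Illustrative only: a LoRALinear whose base weight lives in a child nn.Linear,
# so the base linear module can be swapped out (e.g. for the INT8 training linear).
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)  # swappable; today this is a plain nn.Parameter
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))
```

In the fork this is paired with state dict hooks so checkpoint keys stay backward-compatible, which is the part Evan is averse to (see his comment below).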

@felipemello1 felipemello1 commented Nov 23, 2024

Sorry, I realize you mentioned this about LoRA, but I didn't connect the dots. Let me talk to Evan and get back to you. Maybe we can land it for full finetuning only, since that is working well, and follow up about LoRA. I will reply here on Monday.

@ebsmothers

Sorry for the delay here @gau-nernst and @felipemello1. After looking through the code again I think this is clean enough to move forward with (subject to Felipe's comment), even if we do only support full finetune.

I need to do a bit more research to have a more informed opinion on the options laid out for handling LoRA. Regarding option (a): I took a look at the implementation in main...gau-nernst:qlora and while I am still averse to state dict hooks it is quite clean.

Separately, I am still trying to understand the pros and cons of module-swap vs tensor subclass -- I know tensor subclass needs custom handling for the state dict, but are there drawbacks to module-swap (I previously thought there were potential challenges with FSDP, but actually it seems like that's not the case)? Also, the point raised in pytorch/ao#1179 about composability is not clear to me: isn't this a function of one quantization method (NF4) using a tensor subclass and another (INT8) using module swap? If I look at pytorch/ao#987, it claims that composability of tensor subclasses is better (at least with other tensor subclasses). Lmk if I'm missing the point here though.

Again, I don't want to block on any of this because unfortunately even the pretty trivial change to LoRALinear on your fork would break BC of checkpoints and may have other non-obvious implications for our various utilities and recipes that I am not yet aware of.

So let's move forward with this for full finetune only (I will give a proper review shortly) and separately we can create an RFC to discuss pros and cons of a possible LoRA redesign.

@gau-nernst

@ebsmothers Thanks for your feedback. I will add "raise an error if mixed_precision.enabled=True but compile=False or packed=False", as commented by Felipe, soon.

Regarding module-swap vs tensor subclass: in the end they are just different ways to inject custom logic into an existing model. Currently, we have 2 general use cases:

  1. Use quantized weights (e.g. NF4, INT4/INT8). Only a tensor subclass can do this (since we change the storage of the weight). More generally, this can be considered any kind of custom parametrization (e.g. torch.nn.utils.parametrize, though that uses a different approach).
  2. Modify F.linear(). Usually this goes hand-in-hand with (1) for inference of quantized models, since we want to call custom/specialized kernels for quantized weights. However, in cases where we don't use quantized weights (i.e. we only modify F.linear() and weights are plain tensors), it's possible to use module-swap, e.g. QAT, the INT8 training introduced here, torchao.float8.
    • Also want to repeat that INT8 and FP8 training keep weights in FP32/BF16 and quantize them dynamically during training. This is because we need high precision for the weight update.
| Use case          | Tensor subclass | Module swap |
|-------------------|-----------------|-------------|
| Quantized weight  | x               |             |
| Modify F.linear() | x               | x           |
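For readers less familiar with the module-swap column, a toy sketch (not torchao's actual implementation) of what swapping the module to modify F.linear() looks like:

```python
# Toy illustration only: the weight stays a plain BF16/FP32 Parameter; swapping the
# module class is what changes which matmul gets called in forward/backward.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyInt8TrainingLinear(nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real implementation dispatches to a dynamic INT8 matmul autograd function;
        # plain F.linear is used here as a stand-in to keep the sketch runnable.
        return F.linear(x, self.weight, self.bias)


def swap_linears(module: nn.Module) -> None:
    """Recursively replace nn.Linear children, reusing their existing parameters."""
    for name, child in module.named_children():
        if type(child) is nn.Linear:
            new = ToyInt8TrainingLinear(
                child.in_features, child.out_features, bias=child.bias is not None
            )
            new.weight, new.bias = child.weight, child.bias
            setattr(module, name, new)
        else:
            swap_linears(child)
```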

In terms of "composability", I don't think there is a clear answer about which one is better, because it largely depends on what kind of functionalities being composed together. The point about "composability" that I mentioned in pytorch/ao#1179 is specifically about custom weight-only quantization + low-bit compute. e.g.

  • Quantized weight e.g. NF4 -> delegate to tensor subclass
  • Low-bit compute e.g. INT8/FP8 matmul -> delegate to module-swap

This is possible since for weight-only quant we don't have a custom F.linear() kernel (just dequant then a normal matmul). Hence, we can replace the normal matmul with the fast INT8/FP8 matmul.

Of course it's also possible to use INT8-matmul-subclass + NF4-subclass: the INT8-matmul-subclass can hold the NF4-subclass internally (similar to how dynamic activation quant is implemented in torchao). However, this can be a bit messy: the INT8-matmul-subclass may need to be aware of the NF4-subclass implementation, multiple levels of indirection are hard to debug, and possibly many other issues. So at least for this use case of "weight-only quant + low-bit compute", delegating one function to a subclass and the other to module-swap seems clearer.

Hope it somewhat clarifies the options presented here (and hope that it didn't confuse you further 😅)

@ebsmothers ebsmothers left a comment

@gau-nernst @felipemello1 are we aiming to get this in for our next release? I left a few comments but I think (modulo a final round of testing, which I'm happy to help with) the remaining gap (at least for full finetune) to land should be pretty small.

@gau-nernst

I addressed all of the remaining comments. Please take another look in case I missed anything. Thank you.

Also added mixed_precision entry to all full-finetune and full-finetune-distributed configs. Did a few test runs and everything seems fine.

@ebsmothers ebsmothers left a comment

Hi @gau-nernst, I took a look through the code and have no concerns from me; everything looks good. I think @felipemello1 landed a bunch of changes today that will cause merge conflicts on this PR; if it's too much of a headache, just let us know and he can do the merge and then push to your PR. Last thing: I was trying to test out the perf gains one final time before stamping but can't seem to find a good PyTorch nightly to run on (seems like lots of issues with nightlies lately). Lmk if you have a good stable env that I can test with (until then I will blindly try out various nightlies 😅).

update: ok I got a nightly version that "works". Unfortunately it doesn't look great:

[screenshot: WandB comparison of the baseline vs mixed-precision runs]

Repro:

Baseline command:

tune run full_finetune_single_device --config llama3/8B_full_single_device dataset.packed=True compile=True \
tokenizer.max_seq_len=2048 metric_logger=torchtune.training.metric_logging.WandBLogger \
metric_logger.project=mp-tests metric_logger.name=no-mp-latest

Mixed precision command:

tune run full_finetune_single_device --config llama3/8B_full_single_device mixed_precision.enabled=True \
dataset.packed=True compile=True tokenizer.max_seq_len=2048 \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=mp-tests \
metric_logger.name=with-mp-latest

My env. Like I said above I had a pretty hard time finding a recent working nightly (and I don't think ao 0.7 is quite in stable yet), so maybe it's just something messed up in my environment. But a sanity check here would be much appreciated

@gau-nernst

@ebsmothers PyTorch 2.5 should work. I have been testing with PyTorch 2.5. Nightly is not needed.

I can confirm that INT8 mixed-precision does not play well with bnb 8-bit AdamW. Loss will explode shortly after training starts. I'm not sure why. Switching to the plain PyTorch AdamW seems to solve the issue.

In terms of speed, batch size must indeed be sufficiently large to see a speedup. I ran a few runs:

| bsize x seq_len = num_tokens | BF16 speed | INT8-mp speed | Speedup | Screenshot |
|---|---|---|---|---|
| 2 x 2048 = 4,096 | 3400 tok/s | 3500 tok/s | +3% | screenshot |
| 4 x 2048 = 8,192 | 3700 tok/s | 4400 tok/s | +19% | screenshot |
| 8 x 2048 = 16,384 | 3900 tok/s | 4900 tok/s | +26% | screenshot |
| 16 x 2048 = 32,768 | 4000 tok/s | 5100 tok/s | +28% | screenshot |
| 4 x 8192 = 32,768 | 4100 tok/s | 5300 tok/s | +29% | screenshot |

We probably should add these 2 notes somewhere

  1. Divergence when used together with 8-bit optimizers
  2. Require >=4x2048 tokens/batch to see speedup

Perhaps you can advise me on an appropriate place to put these notes?

@ebsmothers

@gau-nernst thanks for the detailed analysis. On a new env with PyTorch 2.5 and ao nightlies I see similar results to you for batch size 8, max seq len 2048. I have a few more comments -- the first one should be handled in this PR but the others are more for follow-ups.

  1. Regarding where to put the notes on usage with 8-bit optimizers and the need for a large number of tokens in the batch, what do you think about a utility in quantization.py? Then we can pass the different config fields we need and give warnings (or potentially even throw an error) as needed. This can also include compile and packed btw. I think validation of the optimizer is a bit tricky (e.g. I don't think there is a non-hacky way to check that an arbitrary optimizer path is for an 8-bit optimizer, and we need to validate before instantiating the optimizer). However, given the loss divergence I have half a mind to explicitly require this to be used with torch.optim.AdamW only, and to raise an error otherwise. I know this is fairly restrictive, so open to a discussion on whether you think this is the right choice. (Alternatively we can modify the optimizer and raise a warning, but that feels a bit too sneaky to me.) For tokens/batch I would just log a warning if it's below some specified threshold, but leave it up to the user whether they want to do anything. This should be easy to infer from config (assuming it holds independent of the # of devices).

  2. There are some other pieces that could require more detailed education too. Beyond the two that you mention, I think we should e.g. suggest people not to use this with Llama 1B models (at least when I ran it with those I saw slower training, probably due to the dominance of embedding and output projection layers). In terms of where to put something like this, we already have our memory optimizations overview docs page, and while this doesn't fit in there, I actually think it could be worthy of its own standalone docs page. There's a lot we can do here to educate folks, both in terms of the underlying implementation and on recommendations for usage. Depending on how far you want to take it, there's a world where we convert this into a blog (this is similar to what we did with e.g. our knowledge distillation recipe).

  3. Finally, one thing that makes me a bit nervous is that even with AdamW we see a bigger delta in loss curves than I would like (at least around 1e-1 on some early iterations). The good news is that in our test runs it does eventually converge, but this is a behavior I want us to be very explicit with people about. (Again this kinda ties into (2) around ensuring people know how to use the API correctly and what they should expect.)

Anyways, none of this is to push back too much on the PR. I think once (1) is addressed we should land it. It's an awesome feature that most other finetuning frameworks don't have, I'm really excited to support it. But given that we are one of the first to have it, it's up to us to set the right expectations for the community on how to use it (hence the importance of (2) and (3)).

@gau-nernst

@ebsmothers

  1. How about always adding a warning inside torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer, either in __init__() or in prepare()? This should warn about both the 8-bit optimizer (for convergence) and large batch size (for speedup). compile and packed are currently checked at the recipe level, though it may make sense to consolidate all the checks within the quantizer class itself (then we would possibly need to pass the whole config object to the quantizer, and the quantizer can check whether the config makes sense and raise a warning/error).
    1.1 Like you have pointed out, checking for the optimizer is hard. We can check for the common ones (e.g. from bnb and torchao, since those are used in torchtune), which might be fine for 90% of users, but it won't cover everything. Do you think that if we log a warning, users will read it and it won't give a bad experience, versus a straight-out error?
  2. In terms of education, I have this over at the ao docs: https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision-training, which is also referred to in the Int8MixedPrecisionTrainingQuantizer docstring. You all had the "first user experience" with this in torchtune, so perhaps you can give me feedback on what needs to be clearer / what kind of info is useful to users, and I'm happy to write a more detailed version catered to torchtune. You can decide where to put this to give the best visibility; I don't really mind.
  3. Yes, agree.

Also, regarding 1B, I tested it and indeed A100 does not see significant speedup, but on my personal GPU, the speedup is still observed at 4x2048.

tune run full_finetune_single_device --config llama3_2/1B_full_single_device optimizer=torchao.prototype.low_bit_optim.AdamW8bit compile=True tokenizer.max_seq_len=2048 dataset.packed=True enable_activation_checkpointing=True metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=int8mp-1B metric_logger.name=int8mp_8bit-adamw_8x2048 mixed_precision.enabled=True batch_size=8

On A100

| bsize x seq_len = num_tokens | BF16 speed | INT8-mp speed | Speedup | Screenshot |
|---|---|---|---|---|
| 4 x 2048 = 8,192 | 20,000 tok/s | 20,000 tok/s | +0.0% | screenshot |
| 8 x 2048 = 16,384 | 21,000 tok/s | 22,000 tok/s | +4.8% | screenshot |
| 16 x 2048 = 32,768 | 22,000 tok/s | 23,000 tok/s | +4.5% | screenshot |
| 32 x 2048 = 65,536 | 22,000 tok/s | 22,000 tok/s | +0.0% | screenshot |

On 4070Ti SUPER

| bsize x seq_len = num_tokens | BF16 speed | INT8-mp speed | Speedup | Screenshot |
|---|---|---|---|---|
| 4 x 2048 = 8,192 | 7,000 tok/s | 10,000 tok/s | +43% | screenshot |
| 8 x 2048 = 16,384 | 7,000 tok/s | 11,000 tok/s | +57% | screenshot |

My GPU OOMs with larger batch sizes (only 16GB 😢). The 4090 (24GB VRAM) should have the same speedup characteristics. This makes sense because INT8 tensor cores in consumer cards are rated at 4x the BF16/FP16 tensor core throughput, while enterprise GPUs are only rated at 2x. But it may indicate that there is room for improvement on A100.

Anyway, this is to show that the speedup is different across GPUs. Perhaps once we work on (2), we can do more extensive end2end benchmarks across model sizes and GPUs.

@ebsmothers

@gau-nernst on (1): this sounds good to me. The one minor point I would change, if possible, is to pass each field explicitly to a utility like validate_env_for_mp_training or something (instead of passing the entire config). The args will be a bit of a hodge-podge (packed, optimizer path, etc.), but that way we have an explicit contract on which fields we're validating against (and it can be seen directly from the callsite). Regarding the optimizer, I am fine to log a warning instead of raising an error. But yeah, in general it's hard to comprehensively know all 8-bit optimizers -- hence why I think it's safest to use an allowlist-type approach (since we then say explicitly which ones are supported). Ultimately either way is OK as long as the process for inferring that someone is using a low-precision optimizer isn't too hacky.
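A rough sketch of what such a utility could look like (the name comes from the comment above; the allowlist contents and threshold are placeholders based on the numbers reported earlier in this thread):

```python
# Rough sketch only; checks mirror the discussion above and are not final.
import warnings

_KNOWN_8BIT_OPTIMIZERS = {
    "bitsandbytes.optim.AdamW8bit",
    "torchao.prototype.low_bit_optim.AdamW8bit",
}


def validate_env_for_mp_training(
    *,
    compile: bool,
    packed: bool,
    optimizer_path: str,
    tokens_per_batch: int,
    min_tokens_per_batch: int = 8192,  # ~4 x 2048, where speedups started to show up
) -> None:
    if not (compile and packed):
        raise ValueError(
            "INT8 mixed-precision training requires compile=True and dataset.packed=True."
        )
    if optimizer_path in _KNOWN_8BIT_OPTIMIZERS:
        warnings.warn(
            f"{optimizer_path} has been observed to diverge with INT8 mixed-precision; "
            "consider torch.optim.AdamW instead."
        )
    if tokens_per_batch < min_tokens_per_batch:
        warnings.warn(
            f"Only {tokens_per_batch} tokens per batch; speedups were only observed at "
            f"around {min_tokens_per_batch} tokens per batch or more."
        )
```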

Re (2): yes, I think the ao readme page is great here. I think if we wanted to expand upon it in our torchtune docs, there are two potential directions we could go: lower-level (i.e. getting into the underlying implementation in more detail) or more user-focused (i.e. discussing different feature interactions, giving perf numbers, running experiments, etc.). If we combine what you already have with one or both of these, it'd make a nice standalone page on our live docs. But again we can figure out the details here in a follow-up. So yeah the ao documentation is already really clear in terms of usage, it's just a matter of seeing if there's a nice narrative we can provide around the feature to get people even more aware and interested to try it.

Thanks for the benchmarking results! This is interesting and not something I would have expected. And totally agree -- this is exactly the kind of thing that would be great to discuss in more detail in (2).

@felipemello1

> though it may make sense to consolidate all the checks within the quantizer class itself

I like that :)
