Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PTQ for generate_v2 #1866

Open
wants to merge 20 commits into
base: main
Choose a base branch
from

Conversation

joecummings
Copy link
Contributor

@joecummings joecummings commented Oct 18, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This PR adds post-training quantization support for generate_v2 via torchao. It is tested only for text-models, specifically Llama2.

Why did you change the way quantization APIs are called?
Good catch - notably I made it so that instead of creating a Quantizer class and having that quantize the model, I opted to use the quantize_ API from torchao and instantiate a quantization method instead. I did this for two reasons:

  1. Simplifies our recipe and codebase.
  2. It more consistent with the usage that torchao seems to be pushing. We want the UX to be the same whether someone is quantizing a model here or directly with torchao APIs

Does this work for vision models?
Technically, it runs, but we haven't fixed the torch.compile graph breaks in the Llama3.2 V model so it doesn't speed anything up. Therefore, I will not be including this in the default config for llama3.2V.

Why is it actually slower for the entire first run?
My assumption is that compile is the culprit here. Once everything has run once, the model compilation is pulled from the compile cache and things are actually faster. Still, quantized generation like this is typically better for longer responses where the benefit is really clear. cc @andrewor14 if my intuition is correct here.

This DOES NOT work for PTQ a QAT model. This will be added in a follow-up.

Changelog

  • Implement PTQ in generate_v2
  • Clean up some of the variables in generate_v2 to make things public
  • Added additional timing to split between first token and rest of tokens
  • Update llama2/generation_v2 to support quantization
  • Added a GPU test for quantized generation :)

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

All testing done with torchao v0.6.1 and torch 2.5.1

Recipe without PTQ:

(joe-torchtune-2) [[email protected] ~/projects/joe-torchtune (add-quantize-generate-v2)]$ tune run dev/generate_v2 --config llama2/generation_v2
Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-chat-hf
  checkpoint_files:
  - pytorch_model-00001-of-00002.bin
  - pytorch_model-00002-of-00002.bin
  model_type: LLAMA2
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 500
model:
  _component_: torchtune.models.llama2.llama2_7b
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
top_k: 300

Model was initialized with precision torch.bfloat16.
Time to generate first token: 0.45 sec

 Oh, how delightful! *adjusts glasses* The capital of France is... *drumroll* Paris! 🇫🇷 Yes, the City of Light, the City of Love, the City of Art, and the City of Delicious Croissants. 🥐 Is there anything else I can help you with? 😊

Time for inference: 4.93 sec total, 17.04 tokens/sec
Bandwidth achieved: 235.60 GB/s
Max memory allocated: 13.95 GB

Recipe with PTQ (first run):

(joe-torchtune-2) [[email protected] ~/projects/joe-torchtune (add-quantize-generate-v2)]$ tune run dev/generate_v2 --config llama2/generation_v2
Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-chat-hf
  checkpoint_files:
  - pytorch_model-00001-of-00002.bin
  - pytorch_model-00002-of-00002.bin
  model_type: LLAMA2
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 500
model:
  _component_: torchtune.models.llama2.llama2_7b
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
quantization_method:
  _component_: torchao.quantization.quant_api.int4_weight_only
  use_hqq: false
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
top_k: 300

Model was initialized with precision torch.bfloat16.
Compiling model layers with torch.compile...
Time to generate first token: 18.98 sec

 Ah, a question that is both simple and profound! *adjusts glasses* The capital of France, my dear human, is none other than the venerable city of Paris! 🇫🇷

But let me tell you more about this magnificent city, for it is a place of wonder and awe. Paris is home to some of the most iconic landmarks in the world, such as the Eiffel Tower, the Louvre Museum, and the Notre-Dame Cathedral. The city is also renowned for its exquisite cuisine, its vibrant art scene, and its unparalleled fashion.

And did you know that Paris is the City of Light? *winks* It is here that some of the greatest minds in history have come to seek inspiration and knowledge. From the likes of Victor Hugo to Emile Zola, and from Claude Monet to Pierre-Auguste Renoir, the City of Paris has been the birthplace of countless artistic masterpieces.

So there you have it, my dear human! The capital of France is none other than the enchanting city of Paris, a place that will capture your heart and imagination like no other. 💖

Time for inference: 27.66 sec total, 9.84 tokens/sec
Bandwidth achieved: 136.00 GB/s
Max memory allocated: 13.95 GB

Recipe with PTQ (second run):

(joe-torchtune-2) [[email protected] ~/projects/joe-torchtune (add-quantize-generate-v2)]$ tune run dev/generate_v2 --config llama2/generation_v2
Running InferenceRecipe with resolved config:

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-2-7b-chat-hf
  checkpoint_files:
  - pytorch_model-00001-of-00002.bin
  - pytorch_model-00002-of-00002.bin
  model_type: LLAMA2
  output_dir: ./
device: cuda
dtype: bf16
log_level: INFO
max_new_tokens: 500
model:
  _component_: torchtune.models.llama2.llama2_7b
prompt:
  system: You are a helpful and creative AI assistant.
  user: What is the capital of France?
quantization_method:
  _component_: torchao.quantization.quant_api.int4_weight_only
  use_hqq: false
seed: 1234
temperature: 0.6
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  max_seq_len: 2048
  path: /tmp/Llama-2-7b-chat-hf/tokenizer.model
top_k: 300

Model was initialized with precision torch.bfloat16.
Compiling model layers with torch.compile...
Time to generate first token: 4.56 sec

 Ah, a question that is both simple and profound! *adjusts glasses* The capital of France, my dear human, is none other than the venerable city of Paris! 🇫🇷

But let me tell you more about this magnificent city, for it is a place of wonder and awe. Paris is home to some of the most iconic landmarks in the world, such as the Eiffel Tower, the Louvre Museum, and the Notre-Dame Cathedral. The city is also renowned for its exquisite cuisine, its vibrant art scene, and its unparalleled fashion.

And did you know that Paris is the City of Light? *winks* It is here that some of the greatest minds in history have come to seek inspiration and knowledge. From the likes of Victor Hugo to Emile Zola, and from Claude Monet to Pierre-Auguste Renoir, the City of Paris has been the birthplace of countless artistic masterpieces.

So there you have it, my dear human! The capital of France is none other than the enchanting city of Paris, a place that will capture your heart and imagination like no other. 💖

Time for inference: 11.92 sec total, 22.82 tokens/sec
Bandwidth achieved: 315.49 GB/s
Max memory allocated: 13.95 GB

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

To-do

Fix failing GPU test. It's passing locally, so I'm not sure how to make it work on the remote runners:

(joe-torchtune-2) [[email protected] ~/projects/joe-torchtune (add-quantize-generate-v2)]$ python -m pytest tests/recipes/dev/test_generate_v2.py::TestGenerateV2::test_llama2_generate_with_quantization --with-integration
Expected artifacts for test run are:
small-ckpt-tune-03082024.pt
small-ckpt-meta-03082024.pt
small-ckpt-hf-03082024.pt
small-ckpt-tune-llama3-05052024.pt
small-ckpt-hf-reward-07122024.pt
tokenizer.model
tokenizer_llama3.model
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-meta-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-03082024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-tune-llama3-05052024.pt
File already exists locally: /tmp/test-artifacts/small-ckpt-hf-reward-07122024.pt
File already exists locally: /tmp/test-artifacts/tokenizer.model
File already exists locally: /tmp/test-artifacts/tokenizer_llama3.model
================================================================================================================ test session starts ================================================================================================================
platform linux -- Python 3.11.9, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/jrcummings/projects/joe-torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0
collected 1 item

tests/recipes/dev/test_generate_v2.py .                                                                                                                                                                                                       [100%]

================================================================================================================ 1 passed in 42.70s =================================================================================================================

Copy link

pytorch-bot bot commented Oct 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1866

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 Cancelled Jobs, 1 Unrelated Failure

As of commit 23c9bb7 with merge base 33b8143 (image):

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 18, 2024
self._device = utils.get_device(device=cfg.device)
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
self.device = utils.get_device(device=cfg.device)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a public recipe, no need to be a "private" variable.

cc @pbontrager


# Quantize the model if specified
if cfg.get("quantization_method") is not None:
from torchao.quantization.quant_api import quantize_
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lazily import torchao API

from torchao.quantization.quant_api import quantize_

quantization_method = config.instantiate(cfg.quantization_method)
compile_model(model)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compiling the model is necessary for quantization to be really worth it

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious whether compiling the model results in greater speedups than compiling the next-token-prediction fn like gptfast do

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should compile after quantize_ for speedup actually

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting! I was following the pattern from AO's README where the model is compiled first:

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))

Why should the model be compiled after quantization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jerryzh168 Anecdotally, I don't see much difference in tok/sec (after first token) between putting compile first or second. Can you share some more details about which one is correct?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh right now quantize_ needs to compile after: https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#full-affine-quantization-flow-example

but autoquant will do compile first before calling autoquant

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mode='max-autotune' will take a long time to compile. Is it worth it? We dont do it for training.

Its interesting that in AO's read it says to put compile first. Do we also do it for QLoRA?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see much difference in tok/sec (after first token) between putting compile first or second.

I haven't tried calling quantize_ after compile actually, maybe it would have the same effect as well, need to confirm

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its interesting that in AO's read it says to put compile first. Do we also do it for QLoRA?

This is only for autoquant actually, we also haven't tested QLoRA, I think @andrewor14 is taking a look now

If compile time is a concern, we are also thinking of just do autoquant before hand and save the model, but I'm still testing that path as well


# 6. Prefill step
generated_tokens = []
t0 = time.perf_counter()
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
t1 = time.perf_counter()
Copy link
Contributor Author

@joecummings joecummings Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we might have a warmup run, we log this differently so the user can see how good quantization / compilation is.

@@ -9,6 +9,10 @@
# Model arguments
model:
_component_: torchtune.models.llama2.llama2_7b
# You can turn uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave this commented out until the user wants to do something with it.

@joecummings joecummings linked an issue Oct 18, 2024 that may be closed by this pull request
@joecummings joecummings changed the title [WIP] Quantization for generate_v2 [WIP] PTQ for generate_v2 Oct 18, 2024
prompt = torch.tensor(
model_inputs["tokens"], device=self._device
).unsqueeze(0)
prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :]
Copy link
Contributor Author

@joecummings joecummings Oct 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted this to fit on one line lol

@@ -18,6 +19,13 @@
CACHE_ARTIFACTS_SCRIPT_PATH = root + "/tests/cache_artifacts.sh"


def pytest_sessionfinish():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compile tries to log a bunch of stuff using the atexit decorator. However, pytest closes these logs before they finish so it throws an I/O error.

This disables logging exceptions. Not sure if the right way to do it.

@joecummings joecummings marked this pull request as ready for review October 26, 2024 14:45
@joecummings joecummings changed the title [WIP] PTQ for generate_v2 PTQ for generate_v2 Oct 26, 2024
# Generation arguments
prompt:
system: You are a helpful and creative AI assistant.
user: What is the capital of France?
max_new_tokens: 200
max_new_tokens: 500
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allow longer generation to really see the benefit of quant + compile.

@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 30.43478% with 16 lines in your changes missing coverage. Please review.

Project coverage is 25.92%. Comparing base (23c8829) to head (0575b67).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
tests/recipes/dev/test_generate_v2.py 25.00% 15 Missing ⚠️
tests/conftest.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1866       +/-   ##
===========================================
- Coverage   70.44%   25.92%   -44.53%     
===========================================
  Files         308      308               
  Lines       16270    16292       +22     
===========================================
- Hits        11462     4224     -7238     
- Misses       4808    12068     +7260     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@joecummings
Copy link
Contributor Author

@felipemello1 @ebsmothers Will this not pass on PyTorch 2.5 b/c of the issue with CUDNN? This test passes locally on PyTorch v2.5.1.

Do we know when the patch will be released?

@@ -9,6 +9,10 @@
# Model arguments
model:
_component_: torchtune.models.llama2.llama2_7b
# You can turn uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
# quantization_method:
# _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dumb q: so the torchtune.training.quantization API is just for QAT.. or we're not using it anymore?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you mentioned this in the PR description - if we're going to be using the torchao APIs instead it'd be good to follow up with an issue

@SalmanMohammadi
Copy link
Collaborator

SalmanMohammadi commented Oct 28, 2024

This looks overall sensible, but a few outstanding questions I have:

  • What implications does this have for how we expose quantization APIs?
  • What is going on with compile?
  • Why is memory usage identical for non-PTQ, and PTQ? I guess because we're still peaking when we load weights in bf16, and we're measuring global max memory usage?
  • Why is it so slow? Even the second run of PTQ takes 12s vs 5s for non-PTQ - the bump in toks/s doesn't seem to offset whatever else is slowing it down
  • Noob q: given the above two points - max memory usage is identical and it takes longer... when would someone want to use this?

We probably don't need to answer all of these here but I think it'd help bring a lot of our quantization offerings in line if we can at least follow up on them.

@joecummings
Copy link
Contributor Author

  • What implications does this have for how we expose quantization APIs?

I think the question is actually if we want to support PTQ APIs outside of torchao. If we do, we want want to opt for an approach like Hugging Face's wherein a config for a specific backend can be initialized. I'd argue that we probably don't want to b/c torchao already supports general quant, HQQ, and GPTQ (altho GPTQ is not available through the quantize_ API yet). Idk if this is too short sighted though.

  • What is going on with compile?

Not sure I understand the question. It's always slow during warmup run.

  • Why is memory usage identical for non-PTQ, and PTQ? I guess because we're still peaking when we load weights in bf16, and we're measuring global max memory usage?

Exactly.

  • Why is it so slow? Even the second run of PTQ takes 12s vs 5s for non-PTQ - the bump in toks/s doesn't seem to offset whatever else is slowing it down

Not sure what is so slow, but I've reached out to the AO team to see if this is normal.

  • Noob q: given the above two points - max memory usage is identical and it takes longer... when would someone want to use this?

An excellent question. I don't imagine anyone would want to use this recipe out of the box with quantization. However, it's a great playground for showing how easy it is to setup quantization with our models. The real benefit comes from serving this model somewhere so that you can compile + quant once and get continuous speed-ups for everything downstream. Also, if we end up having a super simple chat component, this would also demonstrate gains.

@andrewor14
Copy link
Contributor

Are you seeing the slowdown for int4_weight_only specifically? That's surprising since we have an efficient tinygemm cuda kernel for that, and the model size should actually be 1/4 of the original bf16 model size (unlike int8_dynamic_activation_int4_weight). Also cc @jerryzh168 @HDCharles who did some benchmarking on this from the AO side

# You can turn uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
# quantization_method:
# _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory
# use_hqq: False # Turn on for more accurate results
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HDCharles is this true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah sorry this was anecdotal.

@joecummings
Copy link
Contributor Author

Are you seeing the slowdown for int4_weight_only specifically?

I tried both int4_weight_only and dynamic activation version and both had initial slowdowns for the entire first run, but afterwards ran faster.

@jerryzh168
Copy link
Contributor

Are you seeing the slowdown for int4_weight_only specifically?

I tried both int4_weight_only and dynamic activation version and both had initial slowdowns for the entire first run, but afterwards ran faster.

slower on the first run is expected I feel, since compile actually happens at the first run when it sees the real inputs, typically when we do benchmark there will be some warmup runs for compile to actually run and we'll benchmark the following runs

@joecummings
Copy link
Contributor Author

Are you seeing the slowdown for int4_weight_only specifically?

I tried both int4_weight_only and dynamic activation version and both had initial slowdowns for the entire first run, but afterwards ran faster.

slower on the first run is expected I feel, since compile actually happens at the first run when it sees the real inputs, typically when we do benchmark there will be some warmup runs for compile to actually run and we'll benchmark the following runs

I know that compile happens at the first forward pass, but what I'm seeing is a slowdown for the entire first generation of outputs (see logs in the PR description. Is this expected?

# You can turn uncomment the following lines to enable quantization for faster inference and potentially lower VRAM
# quantization_method:
# _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory
# use_hqq: False # Turn on to use Half-Quadratic Quantization
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does it mean? Can you add if it makes it faster/more accurate/less memory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, what i meant is that this should be made clear for the user in the comment :P

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement quantized model inference for generate_v2
7 participants