Enabling torch.compile for quantized model for speedups
Summary:
Enable torch.compile on generate_next_token so that generation with the int4 weight-only quantized model runs faster in the generate recipe.

Next:
* we can follow up on memory usage as well

Test Plan:
Follow the instructions in https://github.com/pytorch/torchtune/tree/main/recipes#architecture-optimization to quantize the model, modify the generate.yaml file so it runs the int4 weight-only quantized model (a conceptual sketch of the quantization scheme follows the log below), and run:
```
tune run generate --config generate
```

```
2024-04-08:17:06:31,706 INFO     [generate.py:68] Model is initialized with precision torch.bfloat16.
2024-04-08:17:07:33,793 INFO     [generate.py:113] Hello, my name is Elizabeth. Introverts don't talk much unless we know you really well, which makes it difficult to get to know us.
I am a nerd and a geek, and I like to read. My favorite books are fantasy and sci-fi. I like to write. My favorite genre is romance, (though I hope for this blog to be more than one genre). I have been writing for as long as I can remember. I have a lot to write about but not enough time to do it. I want to write a book one day and someday I will.
I work at my family's restaurant and I take care of my grandmother. She is my whole world.
I like animals and I am a vegetarian.
I am afraid of everything as well as being a little weird because of my diagnosis of autism and anxiety disorder.
Music is my life and I can't live without it. I play the piano and drums and I write my own music.
My dreams are far and away.
I hope this is something you enjoy watching unfold.
Sup, sup, sup, sup!
You're very cool.
"I am afraid of everything as well as being a little weird because of my diagnosis of autism and anxiety disorder."
I am so glad you found me. I am an introvert as well and also a Nerd
2024-04-08:17:07:33,794 INFO     [generate.py:117] Time for inference: 61.78 sec total, 4.86 tokens/sec
2024-04-08:17:07:33,794 INFO     [generate.py:120] Memory used: 17.85 GB
```
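
For reference, "int4 weight only" means each linear weight is stored as a 4-bit integer with a per-group scale and dequantized on the fly for the matmul. A conceptual sketch of the scheme (symmetric quantization, with int8 storage standing in for packed 4-bit values; not the actual torchtune/torchao implementation):

```python
import torch

def quantize_int4(w: torch.Tensor, group_size: int = 32):
    # Split the weights into groups and compute one scale per group so that
    # every group maps into the signed 4-bit range [-8, 7].
    wg = w.reshape(-1, group_size)
    scale = wg.abs().amax(dim=1, keepdim=True) / 7.0
    q = torch.clamp(torch.round(wg / scale), -8, 7).to(torch.int8)
    return q, scale

def dequantize_int4(q: torch.Tensor, scale: torch.Tensor, shape) -> torch.Tensor:
    # Recover an approximate float weight for use in the matmul.
    return (q.to(torch.float32) * scale).reshape(shape)

w = torch.randn(64, 64)
q, scale = quantize_int4(w)
w_hat = dequantize_int4(q, scale, w.shape)
print((w - w_hat).abs().max())  # worst-case quantization error
```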
Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 9, 2024
1 parent 2ab2721 commit d6279ed
Showing 3 changed files with 35 additions and 3 deletions.
recipes/generate.py (25 additions, 0 deletions)
@@ -78,6 +78,30 @@ def generate(self, cfg: DictConfig):
tokens = self._tokenizer.encode(cfg.prompt, add_bos=True, add_eos=False)
prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)

custom_generate_next_token = None

# Since the quantized model uses torch.compile to get speedups, it needs a warm-up / prefill run
# to get an accurate performance measurement
if self._quantization_mode is not None:
t0 = time.perf_counter()
custom_generate_next_token = torch.compile(
utils.generate_next_token, mode="max-autotune", fullgraph=True
)
t = time.perf_counter() - t0
logger.info(f"compilation for generate_next_token takes: {t:.02f} sec")
t0 = time.perf_counter()
_ = utils.generate(
model=self._model,
prompt=prompt,
max_generated_tokens=2,
temperature=cfg.temperature,
top_k=cfg.top_k,
eos_id=self._tokenizer.eos_id,
custom_generate_next_token=custom_generate_next_token,
)
t = time.perf_counter() - t0
logger.info(f"warmup run for quantized model takes: {t:.02f} sec")

t0 = time.perf_counter()
generated_tokens = utils.generate(
model=self._model,
@@ -86,6 +110,7 @@ def generate(self, cfg: DictConfig):
temperature=cfg.temperature,
top_k=cfg.top_k,
eos_id=self._tokenizer.eos_id,
custom_generate_next_token=custom_generate_next_token,
)
t = time.perf_counter() - t0

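In isolation, the pattern added above is: compile the per-token decode step once, pay the compilation cost during a small warm-up generation, then time the real run against the already-compiled function. A minimal runnable sketch with a stand-in model and decode step (not the torchtune APIs):

```python
import time

import torch

def decode_step(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # One decoding step: forward pass, then greedy token selection.
    return model(x).argmax(dim=-1)

model = torch.nn.Linear(16, 16)  # stand-in for a real transformer
compiled_step = torch.compile(decode_step, mode="max-autotune", fullgraph=True)

x = torch.randn(1, 16)
t0 = time.perf_counter()
compiled_step(model, x)  # first call triggers compilation
print(f"warm-up (incl. compilation): {time.perf_counter() - t0:.02f} sec")

t0 = time.perf_counter()
compiled_step(model, x)  # subsequent calls reuse the compiled graph
print(f"steady-state step: {time.perf_counter() - t0:.02f} sec")
```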
torchtune/utils/__init__.py (1 addition, 1 deletion)
@@ -22,7 +22,7 @@
validate_no_params_on_meta_device,
wrap_fsdp,
)
from ._generation import generate # noqa
from ._generation import generate, generate_next_token # noqa
from ._profiler import profiler
from .argparse import TuneRecipeArgumentParser
from .checkpointable_dataloader import CheckpointableDataLoader
torchtune/utils/_generation.py (9 additions, 2 deletions)
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
from typing import Callable, Optional

import torch

@@ -68,6 +68,7 @@ def generate(
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
custom_generate_next_token: Optional[Callable] = None,
) -> torch.Tensor:
"""
Generate tokens from a model conditioned on a prompt.
@@ -83,6 +84,9 @@
the top_k probabilities. Default is None
eos_id (Optional[int]): If specified, generation is stopped when the eos token is
generated. Default is None
custom_generate_next_token (Optional[Callable]): If specified, we'll use the custom
generate_next_token function (e.g. a compiled function) when generating tokens,
otherwise we'll use the default `generate_next_token` function. Default is None
Returns:
List: list of generated tokens
@@ -100,6 +104,9 @@
f"{(prompt_length + max_generated_tokens)} - 1"
)

if custom_generate_next_token is None:
custom_generate_next_token = generate_next_token

# generated_tokens is a list of tensors where each tensor contains tokens
# needed for the output
generated_tokens = [prompt]
@@ -120,7 +127,7 @@
# we get the requested number of tokens or we hit eos_id
input_pos = torch.tensor([prompt_length], device=prompt.device)
for _ in range(max_generated_tokens - 1):
token = generate_next_token(
token = custom_generate_next_token(
model=model,
input_pos=input_pos,
x=token.view(1, -1),
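With this hook in place, a caller can inject a compiled step function while untouched call sites keep the default behavior. A usage sketch mirroring the recipe change above (the model, prompt tensor, tokenizer, and sampling values here are assumed, not taken from a specific config):

```python
import torch
from torchtune import utils

# Compile the per-token step once; reuse it for the whole generation.
compiled_next_token = torch.compile(
    utils.generate_next_token, mode="max-autotune", fullgraph=True
)

tokens = utils.generate(
    model=model,              # assumed: an already-initialized model
    prompt=prompt,            # assumed: encoded prompt as a token tensor
    max_generated_tokens=300,
    temperature=0.8,          # assumed sampling values
    top_k=300,
    eos_id=tokenizer.eos_id,  # assumed: an already-initialized tokenizer
    custom_generate_next_token=compiled_next_token,
)
```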
