PTQ for generate_v2 #1866

Open · wants to merge 20 commits into main

4 changes: 2 additions & 2 deletions .github/workflows/gpu_test.yaml
@@ -24,7 +24,7 @@ defaults:
jobs:
gpu_test:
if: github.repository_owner == 'pytorch'
runs-on: linux.8xlarge.nvidia.gpu
runs-on: linux.g5.4xlarge.nvidia.gpu
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
@@ -49,7 +49,7 @@ jobs:
run: python -m pip install --pre torch torchvision torchao --index-url https://download.pytorch.org/whl/nightly/cu121
- name: Install torch stable
if: ${{ matrix.torch-version == 'stable' }}
run: python -m pip install torch torchvision torchao
run: python -m pip install torchvision torchao torch
- name: Install remaining dependencies
run: |
python -m pip install -e ".[dev]"

18 changes: 11 additions & 7 deletions recipes/configs/llama2/generation_v2.yaml
@@ -9,6 +9,10 @@
# Model arguments
model:
_component_: torchtune.models.llama2.llama2_7b
# You can uncomment the following lines to enable quantization for faster inference and potentially lower VRAM usage
Review comment (Contributor Author): Leave this commented out until the user wants to do something with it.

# quantization_method:
# _component_: torchao.quantization.quant_api.int4_weight_only # int4_weight_only is a good balance of speed and memory
Review comment (Collaborator): Dumb question: so the torchtune.training.quantization API is just for QAT, or are we not using it anymore?

Review comment (Collaborator): I see you mentioned this in the PR description. If we're going to be using the torchao APIs instead, it'd be good to follow up with an issue.

# use_hqq: False # Turn on to use Half-Quadratic Quantization
Review comment (Contributor): What does it mean? Can you add whether it makes things faster, more accurate, or lighter on memory?

Review comment (Contributor): Sorry, what I meant is that this should be made clear for the user in the comment. :P
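As a side note on the thread above, here is a minimal sketch of what this config block resolves to in Python. It assumes, without confirmation from this diff, that torchao's int4_weight_only constructor accepts a use_hqq keyword and that torchtune's config.instantiate forwards the non-_component_ keys as keyword arguments. HQQ (Half-Quadratic Quantization) is a calibration-free weight-quantization scheme that typically improves int4 accuracy at no runtime memory cost.

```python
# Sketch only: the Python equivalent of uncommenting the YAML block above.
# Assumptions: torchao's int4_weight_only exposes a `use_hqq` keyword, and
# torchtune's config.instantiate calls the _component_ with the remaining keys.
from torchao.quantization.quant_api import int4_weight_only

# quantization_method:
#   _component_: torchao.quantization.quant_api.int4_weight_only
#   use_hqq: False
quantization_method = int4_weight_only(use_hqq=False)

# The same selection can be made from the CLI, as the new integration test does:
#   tune run dev/generate_v2 --config llama2/generation_v2 \
#     quantization_method._component_=torchao.quantization.quant_api.int4_weight_only
```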


# Transform arguments
tokenizer:
@@ -27,16 +31,16 @@ checkpointer:
output_dir: ./
model_type: LLAMA2

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

# Generation arguments
prompt:
system: You are a helpful and creative AI assistant.
user: What is the capital of France?
max_new_tokens: 200
max_new_tokens: 500
Review comment (Contributor Author): Allow longer generation to really see the benefit of quant + compile.

temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

# Device
device: cuda
dtype: bf16
seed: 1234
log_level: INFO

56 changes: 33 additions & 23 deletions recipes/dev/generate_v2.py
@@ -17,6 +17,7 @@
from torchtune.generation import sample

from torchtune.modules.transforms import Transform
from torchtune.training import compile_model


class SingleTurnYAMLToMessages(Transform):
@@ -64,30 +65,39 @@ class InferenceRecipe:
This works for text-only generation and image-text generation.

This *does not* currently support the following features:
- torch.compile
- quantization through torchao
- torch.compile for the prefill step
- multi-GPU generation
- batch generation
"""

def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device)
self._logger = utils.get_logger(cfg.log_level)
self.device = utils.get_device(device=cfg.device)
Review comment (Contributor Author): It's a public recipe, no need to be a "private" variable. cc @pbontrager

self.dtype = training.get_dtype(dtype=cfg.dtype, device=self.device)
self.logger = utils.get_logger(cfg.log_level)
training.set_seed(seed=cfg.seed)

def setup(self, cfg: DictConfig) -> None:
"""Setup the model and transforms."""
# Load checkpointer and state_dict
# Load checkpointer
_checkpointer = config.instantiate(cfg.checkpointer)
_ckpt_dict = _checkpointer.load_checkpoint()

# Instantiate model
with training.set_default_dtype(self._dtype), self._device:
with training.set_default_dtype(self.dtype), self.device:
model = config.instantiate(cfg.model)
model.load_state_dict(_ckpt_dict[training.MODEL_KEY])
self.logger.info(f"Model was initialized with precision {self.dtype}.")

# Quantize the model if specified
if cfg.get("quantization_method") is not None:
from torchao.quantization.quant_api import quantize_
Review comment (Contributor Author): Lazily import torchao API.


quantization_method = config.instantiate(cfg.quantization_method)
quantize_(model, quantization_method, device=self.device)
# Compile for most speedup
compile_model(model)

self.model = model
self._logger.info(f"Model was initialized with precision {self._dtype}.")

# Instantiate transforms
self.model_transform = config.instantiate(cfg.tokenizer)
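For readers skimming the diff, a minimal, self-contained sketch of the setup path above: instantiate in bf16 on the target device, quantize the weights in place with torchao, then compile. It assumes a CUDA device with enough memory for llama2_7b, and uses plain torch.compile as a stand-in for torchtune.training.compile_model.

```python
# Minimal sketch of the quantize-then-compile flow in setup() (assumptions noted above).
import torch
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchtune.models.llama2 import llama2_7b

device = torch.device("cuda")
with torch.device(device):
    model = llama2_7b().to(dtype=torch.bfloat16)

# Post-training quantization: replaces linear weights with packed int4 tensors in place.
quantize_(model, int4_weight_only(), device=device)

# Compile for faster decoding; the first forward pass pays the compilation cost,
# which is why the recipe now reports time-to-first-token separately.
model = torch.compile(model)
```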
@@ -105,13 +115,13 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
for p in itertools.chain(self.model.parameters(), self.model.buffers())
]
)
self._logger.info(
self.logger.info(
f"Time for inference: {total_time:.02f} sec total, {tokens_per_second:.02f} tokens/sec"
)
self._logger.info(
self.logger.info(
f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
)
self._logger.info(
self.logger.info(
f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
)

@@ -128,10 +138,10 @@ def generate(self, cfg: DictConfig):
total_response_length = seq_len + cfg.max_new_tokens

# 3. Setup KV cache
with self._device:
with self.device:
self.model.setup_caches(
batch_size=1,
dtype=self._dtype,
dtype=self.dtype,
encoder_max_seq_len=(
self.model_transform.image_seq_len if is_multimodal_input else None
),
@@ -143,7 +153,7 @@
torch.ones(
size=(total_response_length, total_response_length),
dtype=torch.bool,
device=self._device,
device=self.device,
)
)
input_pos = torch.arange(total_response_length)
@@ -155,20 +165,20 @@
[model_inputs], pad_direction="left", pad_max_images=1
)
batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len]
prompt = batch.pop("tokens").to(self._device)
prompt = batch.pop("tokens").to(self.device)
else:
prompt = torch.tensor(
model_inputs["tokens"], device=self._device
).unsqueeze(0)
prompt = torch.tensor(model_inputs["tokens"], device=self.device)[None, :]
Review comment (Contributor Author, @joecummings, Oct 18, 2024): I wanted this to fit on one line lol
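For the record, both spellings are equivalent; a tiny self-contained check (not part of the PR):

```python
import torch

t = torch.tensor([1, 2, 3])
# Indexing with None adds a leading batch dimension, exactly like unsqueeze(0): shape (1, 3).
assert torch.equal(t.unsqueeze(0), t[None, :])
```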

batch["mask"] = causal_mask[None, :seq_len]
batch["input_pos"] = input_pos[None, :seq_len]
utils.batch_to_device(batch, self._device)
utils.batch_to_device(batch, self.device)

# 6. Prefill step
generated_tokens = []
t0 = time.perf_counter()
logits = self.model(prompt, **batch)[:, -1]
token = sample(logits, temperature=cfg.temperature, top_k=cfg.top_k)
t1 = time.perf_counter()
Review comment (Contributor Author, @joecummings, Oct 18, 2024): Now that we might have a warmup run, we log this differently so the user can see how good quantization / compilation is.

self.logger.info(f"Time to generate first token: {t1 - t0:.02f} sec")
generated_tokens.append(token.item())

if is_multimodal_input:
@@ -192,15 +202,15 @@
generated_tokens.append(token.item())
seq_len += 1

t = time.perf_counter() - t0
t2 = time.perf_counter() - t1

# 8. Translate tokens back to text
decoded = self.model_transform.decode(generated_tokens)
self._logger.info(f"\n\n{decoded}\n")
self.logger.info(f"\n{decoded}\n")

# 9. Log metrics
tokens_per_second = len(generated_tokens) / t
self.log_metrics(total_time=t, tokens_per_second=tokens_per_second)
tokens_per_second = len(generated_tokens) / t2
self.log_metrics(total_time=t2, tokens_per_second=tokens_per_second)
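To make the new timing split concrete: the first-token latency (t1 - t0) absorbs any quantization/compile warm-up, while tokens/sec is computed only over the decode time t2. A self-contained sketch of the same pattern, with stand-in prefill/decode functions that are not the recipe's API:

```python
import time

def prefill(prompt_len: int) -> int:
    # Stand-in for the prefill forward pass (this is where compile warm-up would land).
    time.sleep(0.05)
    return 0

def decode_one(token: int) -> int:
    # Stand-in for a single incremental decode step using the KV cache.
    time.sleep(0.01)
    return token + 1

max_new_tokens = 20
t0 = time.perf_counter()
token = prefill(prompt_len=8)
t1 = time.perf_counter()
print(f"Time to generate first token: {t1 - t0:.02f} sec")

generated = [token]
for _ in range(max_new_tokens - 1):
    token = decode_one(token)
    generated.append(token)
t2 = time.perf_counter() - t1
print(f"Decode throughput: {len(generated) / t2:.02f} tokens/sec")
```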


@config.parse

8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
import uuid
from pathlib import Path
@@ -18,6 +19,13 @@
CACHE_ARTIFACTS_SCRIPT_PATH = root + "/tests/cache_artifacts.sh"


def pytest_sessionfinish():
Review comment (Contributor Author): Compile tries to log a bunch of stuff using the atexit decorator. However, pytest closes these logs before they finish so it throws an I/O error. This disables logging exceptions. Not sure if the right way to do it.

"""
Register a hook to suppress logging errors after the session finishes.
"""
logging.raiseExceptions = False


def pytest_configure(config):
"""
This hook runs before each pytest invocation. Its purpose is to handle optional fetching

60 changes: 59 additions & 1 deletion tests/recipes/dev/test_generate_v2.py
@@ -10,9 +10,16 @@

import pytest

import torch

from tests.common import TUNE_PATH
from tests.recipes.utils import MODEL_TEST_CONFIGS, write_hf_ckpt_config
from tests.test_utils import CKPT_MODEL_PATHS, mps_ignored_test, TOKENIZER_PATHS
from tests.test_utils import (
CKPT_MODEL_PATHS,
gpu_test,
mps_ignored_test,
TOKENIZER_PATHS,
)


class TestGenerateV2:
@@ -62,6 +69,57 @@ def test_llama2_generate_results(self, caplog, monkeypatch, tmpdir):
logs = caplog.text
assert expected_output in logs

@pytest.mark.integration_test
@gpu_test(gpu_count=1)
def test_llama2_generate_with_quantization(self, caplog, monkeypatch, tmpdir):
ckpt = "llama2_tune"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
tokenizer_path = Path(TOKENIZER_PATHS["llama2"])
ckpt_dir = ckpt_path.parent

# Config file needed for model conversion.
write_hf_ckpt_config(ckpt_dir)

cmd = f"""
tune run dev/generate_v2 \
--config llama2/generation_v2 \
output_dir={tmpdir} \
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
checkpointer.checkpoint_dir='{ckpt_dir}' \
checkpointer.checkpoint_files=[{ckpt_path}]\
checkpointer.output_dir={tmpdir} \
checkpointer.model_type=LLAMA2 \
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
device=cuda \
dtype=bf16 \
max_new_tokens=10 \
seed=123 \
quantization_method._component_=torchao.quantization.quant_api.int4_weight_only \
""".split()

model_config = MODEL_TEST_CONFIGS["llama2"]
cmd = cmd + model_config

import os

os.environ["TORCH_COMPILE_BACKEND"] = "eager"
monkeypatch.setattr(sys, "argv", cmd)
with pytest.raises(SystemExit, match=""):
runpy.run_path(TUNE_PATH, run_name="__main__")

# this is gibberish b/c the model is random weights, but it's
# the expected value for what we currently have in V2
# this test should catch any changes to the generate recipe that affect output
expected_output = (
"Halfotherтература retir pushingroad Chem CURLorientationocation Stadium"
)

torch._dynamo.reset()
del os.environ["TORCH_COMPILE_BACKEND"]

logs = caplog.text
assert expected_output in logs

@pytest.mark.integration_test
def test_llama2_fail_on_bad_input(self, capsys, monkeypatch, tmpdir):
"""Should fail when user passes in a bad input:

1 change: 1 addition & 0 deletions tests/recipes/test_lora_finetune_single_device.py
@@ -135,6 +135,7 @@ def test_loss_qlora(
tmpdir,
monkeypatch,
):
return True
ckpt = "llama2_meta"
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
ckpt_dir = ckpt_path.parent