Integrate INT8 mixed-precision from torchao 0.7 #1552

Open · wants to merge 49 commits into base: main
Changes shown from 5 commits (of 49)
Commits (49):

cf3355e  add int8mp (gau-nernst, Sep 12, 2024)
d3bbaeb  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 13, 2024)
5a61d3e  add a flag (gau-nernst, Sep 13, 2024)
560039d  create a quantizer (gau-nernst, Sep 13, 2024)
2b6e066  add notes on when speedup can be expected (gau-nernst, Sep 13, 2024)
d32f5b8  clarify doc message (gau-nernst, Sep 13, 2024)
60dad97  update docs (gau-nernst, Sep 13, 2024)
8395070  add tiny log (gau-nernst, Sep 13, 2024)
b7b8a7d  update comment (gau-nernst, Sep 13, 2024)
2829b03  add guard on torch version and CUDA sm (gau-nernst, Sep 13, 2024)
688a1c8  add integration test (gau-nernst, Sep 13, 2024)
21391ad  update test (gau-nernst, Sep 13, 2024)
f885d56  use dummy alpaca (gau-nernst, Sep 13, 2024)
7db782c  fix typo (gau-nernst, Sep 14, 2024)
8306f9a  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 14, 2024)
25a2451  convert speed test to smoke test (gau-nernst, Sep 14, 2024)
86d5f04  Merge branch 'int8mp' of github.com:gau-nernst/torchtune into int8mp (gau-nernst, Sep 14, 2024)
6094cdb  fix test (gau-nernst, Sep 14, 2024)
19a2d3e  add ao version guard (gau-nernst, Sep 14, 2024)
faec18d  fix (gau-nernst, Sep 14, 2024)
f4f1945  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 14, 2024)
8fc2826  attempt LoRA (gau-nernst, Sep 14, 2024)
911df57  fix lora (gau-nernst, Sep 15, 2024)
51bbeac  skip LoRA (gau-nernst, Sep 15, 2024)
1e5ae92  skip NF4 (gau-nernst, Sep 15, 2024)
1e4eaf6  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 15, 2024)
30585c2  Merge branch 'main' into int8mp (felipemello1, Oct 3, 2024)
45b4365  typo (felipemello1, Oct 3, 2024)
3e5b040  Merge branch 'main' into int8mp (gau-nernst, Nov 3, 2024)
1ac836a  remove unwanted chnages (gau-nernst, Nov 3, 2024)
5d94cb3  use module swap (gau-nernst, Nov 3, 2024)
06abd88  remove unused import (gau-nernst, Nov 3, 2024)
0ff702e  update docs. change to mixed_precision (gau-nernst, Nov 3, 2024)
05563f2  add test. small fixes (gau-nernst, Nov 3, 2024)
3050c32  add config entries (gau-nernst, Nov 3, 2024)
864c6fb  remove extra compile (gau-nernst, Nov 3, 2024)
1fed859  fix lora finetune (gau-nernst, Nov 3, 2024)
66e8cdd  Merge branch 'main' into int8mp (gau-nernst, Nov 8, 2024)
207308b  Merge branch 'main' into int8mp (gau-nernst, Nov 12, 2024)
0fecc26  fix version check (gau-nernst, Nov 12, 2024)
39e1fc1  dont set inductor config (gau-nernst, Nov 12, 2024)
b2bc5ef  Merge branch 'main' into int8mp (gau-nernst, Dec 5, 2024)
a334986  remove LoRA (gau-nernst, Dec 5, 2024)
d149801  remove PyTorch version check (gau-nernst, Dec 5, 2024)
03a1978  add checks in init. add entries to all applicable configs (gau-nernst, Dec 5, 2024)
35ca06a  Merge branch 'main' into int8mp (gau-nernst, Dec 10, 2024)
0699aa3  add space (gau-nernst, Dec 10, 2024)
be9c0fb  consolidate checks (gau-nernst, Dec 10, 2024)
ca29866  Merge branch 'pytorch:main' into int8mp (gau-nernst, Dec 25, 2024)
6 changes: 6 additions & 0 deletions recipes/full_finetune_single_device.py
@@ -212,6 +212,7 @@ def setup(self, cfg: DictConfig) -> None:
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
compile_model=self._compile,
model_state_dict=ckpt_dict[training.MODEL_KEY],
quantizer_cfg=cfg.get("quantizer", None),
)
self._tokenizer = config.instantiate(cfg.tokenizer)
log.info("Tokenizer is initialized from file.")
@@ -345,6 +346,7 @@ def _setup_model(
enable_activation_checkpointing: bool,
compile_model: bool,
model_state_dict: Dict[str, Any],
quantizer_cfg: Optional[DictConfig] = None,
) -> nn.Module:
"""
Set up the model including enabling activation checkpointing.
@@ -360,6 +362,10 @@
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

if quantizer_cfg is not None:
ebsmothers (Contributor) commented on Oct 3, 2024:

(Context in this comment)

Suggested change:
-if quantizer_cfg is not None:
+if quantizer_cfg is not None and quantizer_cfg.get("enabled", False):

quantizer = config.instantiate(quantizer_cfg)
model = quantizer.prepare(model)

model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
47 changes: 47 additions & 0 deletions torchtune/training/quantization.py
Contributor commented:

n00b question: would it be nice to add some sort of import guard to tell the user they need torchao > N for this? torchao is not a requirement anymore.

cc: @ebsmothers

Contributor commented:

Does this work with older GPUs? Does it work on CPU? Maybe we need something like this:

_SUPPORTS_INT8_MIXED_PRECISION = (
    torch_version_ge("2.5.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (7, 5)
)

gau-nernst (Author) replied:

Good question! I use Triton for this, so it will probably run on any GPU Triton supports (same as torch.compile). Though I think only Ampere (sm80) and above have INT8 tensor cores. To be safe, I think we just guard for sm80 and above.

CPU is not supported. Technically it is possible, but I didn't add it in torchao since I can't reliably test it / don't see it being useful.

This works with PyTorch 2.4 (though FlexAttention requires PyTorch 2.5 🤣).
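
A minimal sketch of the kind of guard discussed above, assuming the sm80 cutoff from this reply rather than the (7, 5) in the earlier suggestion, and using packaging for the version check (not the exact code that landed in the PR):

import torch
from packaging.version import Version

# New-enough PyTorch, a CUDA device, and INT8 tensor cores (Ampere / sm80 or newer).
_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
    Version(torch.__version__) >= Version("2.4.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (8, 0)
)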

gau-nernst (Author) commented:

Btw, doesn't QAT also need some kind of guard like this? 🤔

Contributor replied:

Sorry, I completely missed this convo previously. Now that we've asked users to install torchao manually, I'm taking a similar stance to what we have with PyTorch: people should always be running on the latest stable version. So I'd claim the first two lines of the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING check are not strictly necessary (fine to keep them in though, I don't have a strong preference). But anyway, that should hopefully answer the question about why we don't have similar guards for QAT.

@@ -6,6 +6,11 @@

from typing import Callable, Optional

from torch import nn
from torchao.prototype.quantized_training import (
int8_mixed_precision_training,
Int8MixedPrecisionTrainingConfig,
)
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.quantization.prototype.qat import (
disable_8da4w_fake_quant,
@@ -18,11 +23,14 @@
Int8DynActInt4WeightQATQuantizerModuleSwap,
)

from torchtune.modules import TransformerDecoder


__all__ = [
"get_quantizer_mode",
"Int8DynActInt4WeightQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"Int8MixedPrecisionTrainingQuantizer",
]


@@ -74,6 +82,45 @@ def quantize(self, model):
] = enable_8da4w_fake_quant_module_swap


class Int8MixedPrecisionTrainingQuantizer:
Contributor commented:

Here is my understanding:

  1. torchao has a nice API that does _quantize(model, config).
  2. This PR creates a wrapper specific to INT8. The wrapper is needed for two reasons:
     i) ignore lora_a/lora_b;
     ii) we don't do nested configs.
  3. This means that if we want to support FP8, bitnet, or any other torchao technique, we have to create a custom wrapper instead of doing _quantize(model, torchao_config).

IMO, this makes a lot of sense if:
a) every torchao technique interacts differently with torchtune, e.g. one doesn't work with offloading, another needs some extra work for ckpt, etc., and we can't solve it with a config parser;
b) there are realistically only a couple of quantization methods we will use from torchao (e.g. INT8 and FP8).

But if that's not the case, then we should probably avoid having a custom torchtune wrapper per torchao technique.

I guess we had a similar situation with OptimizerCPUOffload. We ended up giving up on a custom torchtune wrapper and just instantiated directly from the config.

In summary, instead of "Int8MixedPrecisionTrainingQuantizer", should we have a generalist torchtune "PrecisionTrainingQuantizer", just to handle LoRA, etc.?

gau-nernst (Author) replied:

I think PrecisionTrainingQuantizer is not needed atm. We can revisit it once there are more things like this (like you said, realistically I can only think of INT8 and FP8 for now, but that may change).
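
For context, the direct torchao call that the thread weighs against a torchtune-specific wrapper would look roughly like this (a sketch based on the imports in this file, not code from the PR):

from torchao.quantization import quantize_
from torchao.prototype.quantized_training import (
    int8_mixed_precision_training,
    Int8MixedPrecisionTrainingConfig,
)

# Applies INT8 mixed-precision training to every supported linear layer,
# with no torchtune-specific handling (e.g. no skipping of the LM head or LoRA params).
config = Int8MixedPrecisionTrainingConfig(output=True, grad_input=True, grad_weight=True)
quantize_(model, int8_mixed_precision_training(config))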

"""Apply INT8 mixed-precision training. During training, weights and activations
are dynamically quantized to INT8 to utilize INT8 tensor cores. This is also done
in the backward pass.

NOTE: due to the limitations of the current implementation, the following
requirements must be satisfied to enjoy speedup:

1. Must use ``torch.compile()`` (set ``compile=True``).
2. Inputs to the model must not be too dynamic e.g. input sequence length changes
for every batch.

To satisfy (2), you can use :class:`~torchtune.datasets.PackedDataset` (set
``dataset.packed=True``), which ensures input tokens always have fixed length.
"""

def __init__(
self,
output: bool = True,
grad_input: bool = True,
grad_weight: bool = True,
Contributor commented on lines +213 to +217:

I don't think there is a way around us always having our own wrapper, since we need it to deal with LoRA/other details. Therefore, I think we should add "enabled" here and have a config like this:

mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTraining
  arg1: True
  arg2: True
  enabled: False

That way the configs can advertise this functionality.

I also think we should rename quantizer to "mixed_precision" or "mixed_precision_training", so it's self-explanatory.

Let me know if you need help with making the changes or if you disagree with the ideas above.

Thanks again for the PR! @ebsmothers might want to leave a few comments too.
For me, after these changes and the tests pass, I am good to merge.

ebsmothers (Contributor) commented on Oct 3, 2024:

Yeah, personally I like the change to rename quantizer -> mixed_precision (unless the idea is to later support full low-precision training using the same API, in which case it wouldn't make sense).

Edit: after talking with @felipemello1 about this, what do you think about adding the enabled flag to the config and checking it in the recipe? It'd look like this in the recipe. I think this is a good tradeoff because (a) we show the feature in our configs, (b) we don't conditionally define a no-op version of the quantizer (a pattern I really don't like) and are instead explicit about it in the recipe, and (c) it can be easily extended to other mixed precision strategies in the future.
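
A sketch of the recipe-side gating described here, assuming the config group is renamed to mixed_precision and carries an enabled flag (names come from this discussion, not necessarily the merged code):

mixed_precision_cfg = cfg.get("mixed_precision", None)
if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False):
    # The `enabled` key would need to be dropped here or accepted by the quantizer.
    mp_quantizer = config.instantiate(mixed_precision_cfg)
    model = mp_quantizer.prepare(model)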

gau-nernst (Author) replied:

Yeah, I agree the quantizer name can sound confusing right now, since the weights are actually not quantized (well, QAT also does not actually quantize the weights 😄).

Regarding the mixed_precision name, I'm thinking about what other training recipes out there would fall under this category (because it would be weird if we have mixed_precision but only one option). There are FP8 and FP16/BF16 mixed-precision training, but they follow a totally different UX:

  • FP8: the current FP8 training recipes, in torchao or Transformer Engine, are actually mixed-precision training, since master weights are still in high precision, although people usually don't say it is mixed-precision. Currently the FP8 UX is to do module swap (but I think for FP8 all-gather, FP8 also uses tensor subclasses). I think Vasiliy has plans to update the FP8 UX, not exactly sure. But anyway, this is an implementation detail which can be hidden away from torchtune users. (Also, someone should add FP8 to torchtune 😄)
  • BF16 (w/ FP32 master weights): for single-GPU, it would use torch.autocast(dtype=torch.bfloat16). For FSDP2, it would use MixedPrecisionPolicy(param_dtype=torch.bfloat16). This is a very different UX (and mechanism) from INT8 and FP8, and would be hard to unify under the torchtune user-facing API (and I'm not sure if torchtune wants to add this option? Btw, there can be small gaps between full BF16 training (current torchtune) and BF16 mixed-precision training, especially when the LR is small. It would take some effort for someone to investigate this 😅).
  • FP16 (w/ FP32 master weights): probably won't consider this due to the complexity of gradient scaling; plus, finetuning BF16 models with FP16 is a bad idea.

There are also MX FP6/FP4 in the new NVIDIA GPUs, but that's too far in the future. Considering that at least we can have INT8 and FP8 mixed-precision training, using the mixed_precision name seems reasonable!
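
To make the UX gap in the BF16 bullet concrete, a rough sketch (illustrative only; the FSDP2 import path shown is the PyTorch 2.4/2.5 location and may differ in newer releases):

import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# Single GPU: parameters stay in FP32; compute inside the autocast region runs in BF16.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(tokens).sum()  # stand-in for the real loss computation

# FSDP2: the equivalent knob is a MixedPrecisionPolicy passed when sharding the model.
fully_shard(model, mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16))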

felipemello1 (Contributor) commented on Oct 4, 2024:

> also someone should add FP8 to torchtune

Would it be too different from int8? Or could we leverage the same API and replace the quantization module?
BTW, does FP8 work with FSDP? In this case, is there anything that makes it special, since int8 seems to not be compatible?

gau-nernst (Author) replied:

@felipemello1 Currently the torchao API is different: https://github.com/pytorch/ao/tree/main/torchao/float8. But it can be hidden away by torchtune (e.g. a separate Quantizer class, or some other name like a MixedPrecisionTrainer class).

Both FP8 (by the FP8 folks) and INT8 (by me) support FSDP2. The issue is with how torchtune handles distributed weight loading, which is not trivial to handle. FP8 module swap might not have this problem though.

) -> None:
self._config = Int8MixedPrecisionTrainingConfig(
output=output,
grad_input=grad_input,
grad_weight=grad_weight,
)

def prepare(self, model: nn.Module) -> nn.Module:
felipemello1 (Contributor) commented on Nov 11, 2024:

Neither the 3B nor the 8B full single-device recipe is working for me on an A100.

torch        2.6.0.dev20241107+cu124
torchao      0.7.0.dev20241111+cu124
torchtune    0.0.0 /data/users/felipemello/torchtune
torchvision  0.20.0.dev20241107+cu124

I also tried with torch 2.5.1.

I think the issue is with recompiling. It takes several minutes on this step (10+ min). I wonder if it's recompiling non-stop until it runs out of memory and errors with RuntimeError: std::bad_alloc.

8B:

  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
    output = self.chunked_output(h)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
    return [
           ^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
    self.output(chunk)
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc

3B:

  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
    output = self.chunked_output(h)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
    return [
           ^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
    self.output(chunk)
  File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 66, in __call__
    return self.linear(x, self.tied_module.weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 27, in forward
    return F.linear(x, weight)
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc

PS: for 3B we use the TiedLinear implementation:

I printed the shape and dtype, not that it matters:

torch.Size([5, 256, 3072]) torch.bfloat16
torch.Size([128256, 3072]) torch.bfloat16

I also checked for 3B, and filter_fn returns False for it.

gau-nernst (Author) replied:

Regarding std::bad_alloc, I think it's an issue with recent PyTorch nightlies; the torchao CI has been seeing it too (pytorch/ao#1229). The CI currently pins to 20241101 (https://github.com/pytorch/ao/blob/b2642fb33e360ffe478fe19665b1c4efd80537c6/.github/workflows/regression_test.yml#L64). Can you try with this version?

I don't exactly recall, but I think I usually had to kill the first run (it kinda hangs?), and then the subsequent run is fast. Didn't think much of it. I will try to reproduce it again in a fresh environment to see if my memory serves me right.

gau-nernst (Author) replied:

OK, I additionally set quantize_(set_inductor_config=False) to avoid exhaustive torch.compile tuning. Can you try again? It seemed fast to me (<2 min). Worked with both Llama3.1-8B and Llama3.2-3B, on torch==2.5.1+cu121 and torchao==0.7.0.dev20241112+cu121.

The command:

tune run full_finetune_single_device --config llama3_2/3B_full_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True mixed_precision._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer mixed_precision.enabled=True
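
The flag presumably corresponds to the later "dont set inductor config" commit (39e1fc1) and ends up in the torchao quantize_ call, roughly like this (a sketch of the call site, not the exact diff shown below):

quantize_fn = int8_mixed_precision_training(self._config)
# Skip torchao's aggressive inductor autotune defaults to avoid very long compile times.
quantize_(model.layers, quantize_fn, set_inductor_config=False)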

# don't apply INT8 mixed-precision training to LM head
# since speed is slightly lower.
quantize_fn = int8_mixed_precision_training(self._config)
if isinstance(model, TransformerDecoder):
quantize_(model.layers, quantize_fn)
else:
quantize_(model, quantize_fn)
return model


def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
"""Given a quantizer object, returns a string that specifies the type of quantization.
