Integrate INT8 mixed-precision from torchao 0.7 #1552

Open

gau-nernst wants to merge 49 commits into pytorch:main from gau-nernst:int8mp

Changes from 41 commits

Commits (49)
cf3355e
add int8mp
gau-nernst Sep 12, 2024
d3bbaeb
Merge branch 'pytorch:main' into int8mp
gau-nernst Sep 13, 2024
5a61d3e
add a flag
gau-nernst Sep 13, 2024
560039d
create a quantizer
gau-nernst Sep 13, 2024
2b6e066
add notes on when speedup can be expected
gau-nernst Sep 13, 2024
d32f5b8
clarify doc message
gau-nernst Sep 13, 2024
60dad97
update docs
gau-nernst Sep 13, 2024
8395070
add tiny log
gau-nernst Sep 13, 2024
b7b8a7d
update comment
gau-nernst Sep 13, 2024
2829b03
add guard on torch version and CUDA sm
gau-nernst Sep 13, 2024
688a1c8
add integration test
gau-nernst Sep 13, 2024
21391ad
update test
gau-nernst Sep 13, 2024
f885d56
use dummy alpaca
gau-nernst Sep 13, 2024
7db782c
fix typo
gau-nernst Sep 14, 2024
8306f9a
Merge branch 'pytorch:main' into int8mp
gau-nernst Sep 14, 2024
25a2451
convert speed test to smoke test
gau-nernst Sep 14, 2024
86d5f04
Merge branch 'int8mp' of github.com:gau-nernst/torchtune into int8mp
gau-nernst Sep 14, 2024
6094cdb
fix test
gau-nernst Sep 14, 2024
19a2d3e
add ao version guard
gau-nernst Sep 14, 2024
faec18d
fix
gau-nernst Sep 14, 2024
f4f1945
Merge branch 'pytorch:main' into int8mp
gau-nernst Sep 14, 2024
8fc2826
attempt LoRA
gau-nernst Sep 14, 2024
911df57
fix lora
gau-nernst Sep 15, 2024
51bbeac
skip LoRA
gau-nernst Sep 15, 2024
1e5ae92
skip NF4
gau-nernst Sep 15, 2024
1e4eaf6
Merge branch 'pytorch:main' into int8mp
gau-nernst Sep 15, 2024
30585c2
Merge branch 'main' into int8mp
felipemello1 Oct 3, 2024
45b4365
typo
felipemello1 Oct 3, 2024
3e5b040
Merge branch 'main' into int8mp
gau-nernst Nov 3, 2024
1ac836a
remove unwanted chnages
gau-nernst Nov 3, 2024
5d94cb3
use module swap
gau-nernst Nov 3, 2024
06abd88
remove unused import
gau-nernst Nov 3, 2024
0ff702e
update docs. change to mixed_precision
gau-nernst Nov 3, 2024
05563f2
add test. small fixes
gau-nernst Nov 3, 2024
3050c32
add config entries
gau-nernst Nov 3, 2024
864c6fb
remove extra compile
gau-nernst Nov 3, 2024
1fed859
fix lora finetune
gau-nernst Nov 3, 2024
66e8cdd
Merge branch 'main' into int8mp
gau-nernst Nov 8, 2024
207308b
Merge branch 'main' into int8mp
gau-nernst Nov 12, 2024
0fecc26
fix version check
gau-nernst Nov 12, 2024
39e1fc1
dont set inductor config
gau-nernst Nov 12, 2024
b2bc5ef
Merge branch 'main' into int8mp
gau-nernst Dec 5, 2024
a334986
remove LoRA
gau-nernst Dec 5, 2024
d149801
remove PyTorch version check
gau-nernst Dec 5, 2024
03a1978
add checks in init. add entries to all applicable configs
gau-nernst Dec 5, 2024
35ca06a
Merge branch 'main' into int8mp
gau-nernst Dec 10, 2024
0699aa3
add space
gau-nernst Dec 10, 2024
be9c0fb
consolidate checks
gau-nernst Dec 10, 2024
ca29866
Merge branch 'pytorch:main' into int8mp
gau-nernst Dec 25, 2024
5 changes: 5 additions & 0 deletions recipes/configs/llama3_1/8B_full.yaml
@@ -83,6 +83,11 @@ output_dir: /tmp/full-llama3.1-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
_component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
enabled: false

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -79,6 +79,11 @@ output_dir: /tmp/full-llama3.2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
_component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
enabled: false

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2_vision/11B_full.yaml
@@ -83,6 +83,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
_component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
enabled: false

# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
11 changes: 11 additions & 0 deletions recipes/full_finetune_distributed.py
@@ -279,6 +279,7 @@ def setup(self, cfg: DictConfig) -> None:
model_state_dict=checkpoint_dict[training.MODEL_KEY],
ac_mode=cfg.get("ac_mode", None),
ac_option=cfg.get("ac_option", None),
mixed_precision_cfg=cfg.get("mixed_precision", None),
)
self._tokenizer = config.instantiate(cfg.tokenizer)

@@ -419,6 +420,7 @@ def _setup_model(
custom_sharded_layers: Optional[List[str]] = None,
ac_mode: Optional[str] = None,
ac_option: Optional[int] = None,
mixed_precision_cfg: Optional[DictConfig] = None,
) -> nn.Module:
"""
Model initialization has some important considerations:
@@ -459,6 +461,15 @@
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

if mixed_precision_cfg is not None and mixed_precision_cfg.get(
"enabled", False
):
log.info(f"Preparing model with {mixed_precision_cfg._component_}")
cfg = mixed_precision_cfg.copy()
cfg.pop("enabled", None)
quantizer = config.instantiate(cfg)
model = quantizer.prepare(model)

# For FSDP sharding
fsdp_shard_conditions = [
partial(
11 changes: 11 additions & 0 deletions recipes/full_finetune_single_device.py
@@ -271,6 +271,7 @@ def setup(self, cfg: DictConfig) -> None:
enable_activation_offloading=self._enable_activation_offloading,
compile_model=self._compile,
model_state_dict=ckpt_dict[training.MODEL_KEY],
mixed_precision_cfg=cfg.get("mixed_precision", None),
)
self._tokenizer = config.instantiate(cfg.tokenizer)
log.info("Tokenizer is initialized from file.")
@@ -414,6 +415,7 @@ def _setup_model(
enable_activation_offloading: bool,
compile_model: bool,
model_state_dict: Dict[str, Any],
mixed_precision_cfg: Optional[DictConfig] = None,
) -> nn.Module:
"""
Set up the model including enabling activation checkpointing.
@@ -429,6 +431,15 @@
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

if mixed_precision_cfg is not None and mixed_precision_cfg.get(
"enabled", False
):
log.info(f"Preparing model with {mixed_precision_cfg._component_}")
cfg = mixed_precision_cfg.copy()
cfg.pop("enabled", None)
quantizer = config.instantiate(cfg)
model = quantizer.prepare(model)

model.load_state_dict(model_state_dict)

# Validate model was loaded in with the expected dtype.
11 changes: 11 additions & 0 deletions recipes/lora_finetune_single_device.py
@@ -276,6 +276,7 @@ def setup(self, cfg: DictConfig) -> None:
if self._resume_from_checkpoint
else None
),
mixed_precision_cfg=cfg.get("mixed_precision", None),
)

self._tokenizer = config.instantiate(cfg.tokenizer)
@@ -420,6 +421,7 @@ def _setup_model(
compile_model: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
mixed_precision_cfg: Optional[DictConfig] = None,
) -> nn.Module:
with training.set_default_dtype(self._dtype), self._device:
model = config.instantiate(cfg_model)
@@ -441,6 +443,15 @@
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

if mixed_precision_cfg is not None and mixed_precision_cfg.get(
"enabled", False
):
log.info(f"Preparing model with {mixed_precision_cfg._component_}")
cfg = mixed_precision_cfg.copy()
cfg.pop("enabled", None)
quantizer = config.instantiate(cfg)
model = quantizer.prepare(model)

base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
43 changes: 43 additions & 0 deletions tests/torchtune/training/test_quantization.py
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from tests.test_utils import gpu_test
from torch import nn
from torchtune.training.quantization import (
_SUPPORTS_INT8_MIXED_PRECISION_TRAINING,
Int8MixedPrecisionTrainingQuantizer,
)


@gpu_test(gpu_count=1)
@pytest.mark.skipif(
not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING,
reason="INT8 mixed-precision training is not supported",
)
def test_int8_mixed_precision_training_quantizer():
quantizer = Int8MixedPrecisionTrainingQuantizer()
model = nn.Sequential(
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 32),
).cuda()
quantizer.prepare(model)

# make sure class is changed
assert model[0].__class__ != nn.Linear
assert model[2].__class__ != nn.Linear

# smoke test forward and backward
model(torch.randn(2, 32).cuda()).sum().backward()
for p in model.parameters():
assert p.grad is not None

# state dict is plain tensor
state_dict = model.state_dict()
for v in state_dict.values():
assert v.__class__ == torch.Tensor
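
The smoke test above and the quantizer added below suggest the following end-to-end usage outside of a torchtune recipe. This is an editorial sketch, not part of the PR: it assumes a CUDA device that passes the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING guard, and the layer sizes, optimizer, and training loop are illustrative only.

import torch
from torch import nn
from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

# Swap nn.Linear modules for INT8 mixed-precision variants, then compile.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)).cuda()
model = Int8MixedPrecisionTrainingQuantizer().prepare(model)
model = torch.compile(model)  # requirement (1) from the quantizer docstring below

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(3):
    x = torch.randn(8, 256, device="cuda")  # fixed input shape, per requirement (2)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
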
102 changes: 102 additions & 0 deletions torchtune/training/quantization.py
Contributor:
n00b question: would it be nice to add some sort of import guard to tell the user they need torchao > N for this? Torchao is not a requirement anymore

cc: @ebsmothers

Contributor:
does this work with older GPUs? Does it work on CPU? Maybe we need something like this:

_SUPPORTS_INT8_MIXED_PRECISION = (
    torch_version_ge("2.5.0")
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (7, 5)
)


gau-nernst (PR author):
Good question! I use Triton for this so it will probably run on any GPUs Triton supports (same as torch.compile). Though I think only Ampere (sm80) and above has INT8 tensor cores. To be safe, I think we just guard for sm80 and above.

CPU is not supported. Technically it is possible, but I didn't add it in torchao since I can't reliably test it / see it useful.

This works with PyTorch 2.4 (though FlexAttention requires PyTorch 2.5 🤣).

gau-nernst (PR author):
Btw doesn't QAT also need some kind of guards like this? 🤔

Contributor:
Sorry I completely missed this convo previously. Actually now that we've asked users to install torchao manually I am kinda taking a similar stance there to what we have with PyTorch: people should always be running on the latest stable version. So actually I claim that the first two lines of the _SUPPORTS_INT8_MIXED_PRECISION_TRAINING check are not strictly necessary (fine to keep them in though, I don't have a strong preference). But anyways that should hopefully answer the question about why we don't have similar such guards for QAT

@@ -7,6 +7,11 @@
from typing import Callable, Optional
from warnings import warn

import torch
import torchao
from packaging.version import Version
from torch import nn

try:
# torchao 0.7+
from torchao.dtypes import TensorCoreTiledLayout
@@ -43,6 +48,22 @@
Int8DynActInt4WeightQATQuantizer,
)

from torchtune.utils._version import torch_version_ge


_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
felipemello1 (Contributor), Oct 3, 2024:
How bad is the performance if flex attention/compile are off? I am wondering whether this should be part of the requirement. If so, then we can just include this import guard in the check:

_SUPPORTS_FLEX_ATTENTION = (

Maybe that's a bad idea; we could raise a warning instead: "Use compile and dataset.packed for best results."

torch_version_ge("2.4.0")
and Version(torchao.__version__) >= Version("0.7.0.dev")
and torch.cuda.is_available()
and torch.cuda.get_device_capability() >= (8, 0)
)

if _SUPPORTS_INT8_MIXED_PRECISION_TRAINING:
from torchao.prototype.quantized_training import (
int8_mixed_precision_training,
Int8MixedPrecisionTrainingConfig,
)


__all__ = [
"get_quantizer_mode",
@@ -52,6 +73,7 @@
"Int8DynActInt4WeightQuantizer",
"Int8DynActInt4WeightQATQuantizer",
"Int8DynActInt4WeightQATQuantizerModuleSwap",
"Int8MixedPrecisionTrainingQuantizer",
]


@@ -144,6 +166,86 @@ def quantize(self, model):
] = enable_8da4w_fake_quant_module_swap


class Int8MixedPrecisionTrainingQuantizer:
Contributor:
Here is my understanding:

  1. torchao has a nice API that does quantize_(model, config).

  2. This PR creates a wrapper specific to INT8. The wrapper is needed for two reasons:
    i) ignore lora_a/lora_b
    ii) we don't do nested configs

  3. This means that if we want to support FP8, BitNet, or any other torchao technique, we have to create a custom wrapper instead of calling quantize_(model, torchao_config) directly (see the sketch after this comment).

IMO, this makes a lot of sense if:
a) every torchao technique interacts differently with torchtune, e.g. one doesn't work with offloading, another needs some extra work for checkpointing, etc., and we can't solve it with a config parser
b) there are realistically only a couple of quantization methods we will use from torchao (e.g. INT8 and FP8)

But if that's not the case, then we should probably avoid having a custom torchtune wrapper per torchao technique.

I guess we had a similar situation with OptimizerCPUOffload. We ended up giving up on a custom torchtune wrapper and just instantiated directly from the config.

In summary, instead of "Int8MixedPrecisionTrainingQuantizer", should we have a generalist torchtune "PrecisionTrainingQuantizer", just to handle LoRA, etc.?
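
To make the tradeoff concrete, here is an editorial sketch of the two call patterns being compared, built from the torchao calls that appear later in this diff (quantize_, int8_mixed_precision_training, Int8MixedPrecisionTrainingConfig). The helper function names and the simplified filter rule are illustrative, not part of the PR.

from torch import nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import (
    Int8MixedPrecisionTrainingConfig,
    int8_mixed_precision_training,
)

def apply_via_torchao_directly(model: nn.Module) -> nn.Module:
    # Pattern 1: the recipe calls torchao itself with a config built from YAML.
    config = Int8MixedPrecisionTrainingConfig(output=True, grad_input=True, grad_weight=True)
    quantize_(model, int8_mixed_precision_training(config), set_inductor_config=False)
    return model

def apply_via_torchtune_wrapper(model: nn.Module) -> nn.Module:
    # Pattern 2 (what this PR does): a torchtune-side quantizer owns the same call
    # plus the torchtune-specific details, e.g. skipping LoRA adapters.
    def filter_fn(module: nn.Module, name: str) -> bool:
        return isinstance(module, nn.Linear) and not name.endswith((".lora_a", ".lora_b"))

    config = Int8MixedPrecisionTrainingConfig(output=True, grad_input=True, grad_weight=True)
    quantize_(
        model,
        int8_mixed_precision_training(config, module_swap=True),
        filter_fn=filter_fn,
        set_inductor_config=False,
    )
    return model
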

gau-nernst (PR author):
I think PrecisionTrainingQuantizer is not needed at the moment. We can revisit it once there are more things like this (like you said, realistically I can only think of INT8 and FP8 for now, but that may change).

"""Apply INT8 mixed-precision training. This only affects weights of ``nn.Linear``
modules. During training, weights and activations are dynamically quantized to INT8
to utilize fast matrix multiplication with INT8 tensor cores. This is also done in
the backward pass.

The expected end-to-end speedup is 40% on a single A100 and 70% on a single 4090, with
minimal accuracy loss. If convergence is an issue, please refer to torchao
documentation below.

For more details, as well as details about arguments of this quantizer, please refer to
https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision

Args:
output (bool): whether to apply INT8 mixed-precision for calculating output. Default: True
grad_input (bool): whether to apply INT8 mixed-precision for calculating grad_input. Default: True
grad_weight (bool): whether to apply INT8 mixed-precision for calculating grad_weight. Default: True

Raises:
RuntimeError: If runtime requirements for INT8 mixed-precision training are not met.

NOTE: Due to the limitations of the current implementation, the following
requirements must be satisfied to enjoy the expected speedup:

1. Must use ``torch.compile()`` (set ``compile=True``).
2. Inputs to the model must not be too dynamic. For example, when the input token
length changes for every batch, you won't see the expected speedup.

To satisfy (2), you can use :class:`~torchtune.datasets.PackedDataset` (set
``dataset.packed=True`` and ``tokenizer.max_seq_len`` to a desired value), which
ensures input tokens always have a fixed length.
"""

def __init__(
self,
output: bool = True,
grad_input: bool = True,
grad_weight: bool = True,
Contributor (commenting on lines +213 to +217):
I don't think there is a way around us always having our own wrapper, since we need it to deal with LoRA and other details. Therefore, I think we should add "enabled" here, and have a config like this:

mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTraining
  arg1: True
  arg2: True
  enabled: False

That way the configs can advertise this functionality.

I also think that we should rename quantizer to "mixed_precision" or "mixed_precision_training", so it's self-explanatory.

Let me know if you need help with making the changes or if you disagree with the ideas above.

Thanks again for the PR! @ebsmothers might want to leave a few comments too. For me, after these changes and the tests pass, I am good to merge.

ebsmothers (Contributor), Oct 3, 2024:
Yeah personally I like the change to rename quantizer -> mixed_precision (unless the idea is to later support full low-precision training using the same API, in which case it wouldn't make sense).

Edit: after talking with @felipemello1 about this, what do you think about adding the enabled flag to the config and checking in the recipe? It'd look like this in the recipe. I think this is a good tradeoff because (a) we show the feature in our configs, (b) we don't conditionally define a no-op version of the quantizer (a pattern I really don't like) and are instead explicit about it in the recipe, and (c) it can be easily extended to other mixed precision strategies in the future.
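
The recipe snippet linked above is not reproduced on this page. Based on the _setup_model changes earlier in this PR, the proposed check presumably looks roughly like the sketch below; the helper name maybe_apply_mixed_precision is hypothetical.

from omegaconf import DictConfig
from torch import nn
from torchtune import config

def maybe_apply_mixed_precision(model: nn.Module, cfg: DictConfig) -> nn.Module:
    # "enabled" is a recipe-level flag advertised in the YAML configs; it is not
    # an argument of the quantizer itself, so it is popped before instantiation.
    mixed_precision_cfg = cfg.get("mixed_precision", None)
    if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False):
        quantizer_cfg = mixed_precision_cfg.copy()
        quantizer_cfg.pop("enabled", None)
        quantizer = config.instantiate(quantizer_cfg)
        model = quantizer.prepare(model)
    return model
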

gau-nernst (PR author):
Yeah, I agree the quantizer name can sound confusing right now, since the weights are not actually quantized (well, QAT also does not actually quantize the weights 😄).

Regarding the mixed_precision name, I'm thinking about which other training recipes out there would fall under this category (it would be weird to have mixed_precision with only one option). There are FP8 and FP16/BF16 mixed-precision training, but they follow a totally different UX:

  • FP8: the current FP8 training recipes, in torchao or Transformer Engine, are actually mixed-precision training, since the master weights are still in high precision, although people usually don't call it mixed-precision. The current FP8 UX is module swap (though I think for FP8 all-gather it also uses a tensor subclass). I think Vasiliy has plans to update the FP8 UX, but I'm not exactly sure. Anyway, this is an implementation detail that can be hidden away from torchtune users. (Also, someone should add FP8 to torchtune 😄)
  • BF16 (w/ FP32 master weights): for single-GPU this would use torch.autocast(dtype=torch.bfloat16), while for FSDP2 it would use MixedPrecisionPolicy(param_dtype=torch.bfloat16). This is a very different UX (and mechanism) from INT8 and FP8, and it would be hard to unify the torchtune user-facing API (and I'm not sure torchtune wants to add this option. Btw, there can be small gaps between full BF16 training (current torchtune) and BF16 mixed-precision training, especially when the LR is small; it would take some effort for someone to investigate this 😅). A sketch contrasting the single-GPU BF16 UX with this PR's INT8 UX follows this comment.
  • FP16 (w/ FP32 master weights): probably won't consider this, due to the complexity of gradient scaling, plus fine-tuning BF16 models in FP16 is a bad idea.

There are also MX FP6/FP4 in the new NVIDIA GPUs, but that's too far in the future. Considering that we can at least have INT8 and FP8 mixed-precision training, using the mixed_precision name seems reasonable!
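
As an editorial illustration of the UX gap described in the BF16 bullet above (not part of the PR): single-GPU BF16 mixed precision leaves the module untouched and casts per-op inside an autocast context, whereas the INT8 path in this PR rewrites nn.Linear modules up front via quantizer.prepare(model) and then relies on torch.compile. The device choice, shapes, and optimizer below are arbitrary, and the FSDP2 MixedPrecisionPolicy variant is only mentioned in a comment because its import path varies across PyTorch versions.

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(32, 32).to(device)  # FP32 master weights stay on the module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 32, device=device)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    # Ops inside this context run in BF16; the module itself is unchanged.
    loss = model(x).sum()
loss.backward()
optimizer.step()

# Under FSDP2 the equivalent knob would be MixedPrecisionPolicy(param_dtype=torch.bfloat16)
# passed when sharding, while the INT8 quantizer in this PR instead swaps the Linear
# modules via quantizer.prepare(model) before compiling.
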

felipemello1 (Contributor), Oct 4, 2024:

> also someone should add FP8 to torchtune

Would it be too different from INT8? Or could we leverage the same API and just replace the quantization module?
BTW, does FP8 work with FSDP? If so, is there anything that makes it special, since INT8 seems not to be compatible?

gau-nernst (PR author):
@felipemello1 Currently the torchao API is different: https://github.com/pytorch/ao/tree/main/torchao/float8. But it can be hidden away by torchtune (e.g. a separate Quantizer class, or some other name like a MixedPrecisionTrainer class).

Both FP8 (by the FP8 folks) and INT8 (by me) support FSDP2. The issue is with how torchtune handles distributed weight loading, which is not trivial. FP8 module swap might not have this problem, though.

) -> None:
if not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING:
raise RuntimeError(
"INT8 mixed-precision training requires torch>=2.4, torchao>=0.7, and"
" a CUDA-capable device with compute capability >= 8.0"
)

self._config = Int8MixedPrecisionTrainingConfig(
output=output,
grad_input=grad_input,
grad_weight=grad_weight,
)

def prepare(self, model: nn.Module) -> nn.Module:
felipemello1 (Contributor), Nov 11, 2024:
Neither the 3B nor the 8B full single-device config works for me on an A100.

torch 2.6.0.dev20241107+cu124
torchao 0.7.0.dev20241111+cu124
torchtune 0.0.0 /data/users/felipemello/torchtune
torchvision 0.20.0.dev20241107+cu124

I also tried with torch 2.5.1.

I think the issue is with recompiling. It takes several minutes on this step (10+ min). I wonder if it's recompiling non-stop until it runs out of memory and errors with RuntimeError: std::bad_alloc.

8b

  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
    output = self.chunked_output(h)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
    return [
           ^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
    self.output(chunk)
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc

3b

  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 649, in forward
    output = self.chunked_output(h)
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 501, in chunked_output
    return [
           ^
  File "/data/users/felipemello/torchtune/torchtune/modules/transformer.py", line 502, in <listcomp>
    self.output(chunk)
  File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 66, in __call__
    return self.linear(x, self.tied_module.weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/felipemello/torchtune/torchtune/modules/tied_linear.py", line 27, in forward
    return F.linear(x, weight)
           ^^^^^^^^^^^^^^^^^^^
RuntimeError: std::bad_alloc

PS: for 3B we use the TiedLinear implementation.

I printed the shapes and dtypes, not that it matters:

torch.Size([5, 256, 3072]) torch.bfloat16
torch.Size([128256, 3072]) torch.bfloat16

I also checked for 3B, and filter_fn returns False for it.

gau-nernst (PR author):
Regarding std::bad_alloc, I think it's an issue with recent PyTorch nightlies; ao CI has been seeing it too (pytorch/ao#1229). Currently the CI pins to 20241101: https://github.com/pytorch/ao/blob/b2642fb33e360ffe478fe19665b1c4efd80537c6/.github/workflows/regression_test.yml#L64. Can you try with that version?

I don't exactly recall, but I think usually I had to kill the first run (it kinda hangs?), then the subsequent run is fast. Didn't think much of it. I will try to reproduce it again in a fresh environment to see if my memory serves me right.

gau-nernst (PR author):
OK, I additionally set quantize_(set_inductor_config=False) to avoid exhaustive torch.compile tuning. Can you try again? It seemed fast to me (<2 min) and worked with both Llama3.1-8B and Llama3.2-3B, on torch==2.5.1+cu121 and torchao==0.7.0.dev20241112+cu121.

The command:

tune run full_finetune_single_device --config llama3_2/3B_full_single_device dataset.packed=True tokenizer.max_seq_len=8192 optimizer=torch.optim.AdamW optimizer.fused=True optimizer_in_bwd=False compile=True metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True mixed_precision._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer mixed_precision.enabled=True

# We use the module-swap implementation so that the state_dict remains plain tensors,
# and for better FSDP compatibility in torchtune.
quantize_fn = int8_mixed_precision_training(self._config, module_swap=True)

# custom filter_fn to work with torchtune's peft
def filter_fn(module: nn.Module, name: str) -> bool:
if isinstance(module, nn.Linear):
# skip LoRA adapters since they are too small, so the speedup will not
# outweigh the quantization overhead.
# also skip the LM head since end-to-end speedup is slightly worse;
# there are also possible issues with tied word embeddings.
if (
name.endswith(".lora_a")
or name.endswith(".lora_b")
or module.weight.shape[0] >= 32_000
):
return False
else:
return True

return False

# don't set inductor config, otherwise compile will be very slow
# (it will affect global torch.compile() config)
quantize_(model, quantize_fn, filter_fn=filter_fn, set_inductor_config=False)
return model


def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
"""Given a quantizer object, returns a string that specifies the type of quantization.
