From cf3355e03aa374ba889f6c3626bb4a8e97a8114d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 12 Sep 2024 10:41:33 +0800 Subject: [PATCH 01/37] add int8mp --- recipes/full_finetune_single_device.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 94f804ef85..96a97600cd 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -372,6 +372,11 @@ def _setup_model( memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) + from torchao.prototype.quantized_training import int8_mixed_precision_training + from torchao.quantization import quantize_ + + quantize_(model.layers, int8_mixed_precision_training()) + return model def _setup_optimizer( From 5a61d3e2db18784ed985be1e4b8064edb3091c19 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 08:19:46 +0000 Subject: [PATCH 02/37] add a flag --- recipes/full_finetune_single_device.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 96a97600cd..26a3d621fa 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -16,6 +16,8 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler +from torchao import quantize_ +from torchao.prototype.quantized_training import int8_mixed_precision_training from torchtune import config, modules, training, utils from torchtune.data import padded_collate_packed, padded_collate_sft @@ -212,6 +214,7 @@ def setup(self, cfg: DictConfig) -> None: enable_activation_checkpointing=cfg.enable_activation_checkpointing, compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], + int8_mixed_precision_training=cfg.get("int8_mixed_precision_training", False), ) self._tokenizer = config.instantiate(cfg.tokenizer) log.info("Tokenizer is initialized from file.") @@ -345,6 +348,7 @@ def _setup_model( enable_activation_checkpointing: bool, compile_model: bool, model_state_dict: Dict[str, Any], + int8_mixed_precision_training: bool = False, ) -> nn.Module: """ Set up the model including enabling activation checkpointing. @@ -360,6 +364,10 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) + if int8_mixed_precision_training: + # don't apply to LM head + quantize_(model.layers, int8_mixed_precision_training()) + model.load_state_dict(model_state_dict) # Validate model was loaded in with the expected dtype. @@ -372,11 +380,6 @@ def _setup_model( memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) - from torchao.prototype.quantized_training import int8_mixed_precision_training - from torchao.quantization import quantize_ - - quantize_(model.layers, int8_mixed_precision_training()) - return model def _setup_optimizer( From 560039d8a3e5c8b96636d22c49d87fbf608fa276 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 20:55:56 +0800 Subject: [PATCH 03/37] create a quantizer --- recipes/full_finetune_single_device.py | 12 ++++----- torchtune/training/quantization.py | 36 ++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 26a3d621fa..91d9b4ced5 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -16,8 +16,6 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchao import quantize_ -from torchao.prototype.quantized_training import int8_mixed_precision_training from torchtune import config, modules, training, utils from torchtune.data import padded_collate_packed, padded_collate_sft @@ -214,7 +212,7 @@ def setup(self, cfg: DictConfig) -> None: enable_activation_checkpointing=cfg.enable_activation_checkpointing, compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], - int8_mixed_precision_training=cfg.get("int8_mixed_precision_training", False), + quantizer_cfg=cfg.get("quantizer", None), ) self._tokenizer = config.instantiate(cfg.tokenizer) log.info("Tokenizer is initialized from file.") @@ -348,7 +346,7 @@ def _setup_model( enable_activation_checkpointing: bool, compile_model: bool, model_state_dict: Dict[str, Any], - int8_mixed_precision_training: bool = False, + quantizer_cfg: Optional[DictConfig] = None, ) -> nn.Module: """ Set up the model including enabling activation checkpointing. @@ -364,9 +362,9 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if int8_mixed_precision_training: - # don't apply to LM head - quantize_(model.layers, int8_mixed_precision_training()) + if quantizer_cfg is not None: + quantizer = config.instantiate(quantizer_cfg) + model = quantizer.prepare(model) model.load_state_dict(model_state_dict) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 7046f90b64..e244030937 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -6,6 +6,11 @@ from typing import Callable, Optional +from torch import nn +from torchao.prototype.quantized_training import ( + int8_mixed_precision_training, + Int8MixedPrecisionTrainingConfig, +) from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_ from torchao.quantization.prototype.qat import ( disable_8da4w_fake_quant, @@ -18,11 +23,14 @@ Int8DynActInt4WeightQATQuantizerModuleSwap, ) +from torchtune.modules import TransformerDecoder + __all__ = [ "get_quantizer_mode", "Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightQATQuantizer", + "Int8MixedPrecisionTrainingQuantizer", ] @@ -74,6 +82,34 @@ def quantize(self, model): ] = enable_8da4w_fake_quant_module_swap +class Int8MixedPrecisionTrainingQuantizer: + """Apply INT8 mixed-precision training. During training, weights and activations + are dynamically quantized to INT8 to utilize INT8 tensor cores. This is also done + in the backward pass.""" + + def __init__( + self, + output: bool = True, + grad_input: bool = True, + grad_weight: bool = True, + ) -> None: + self._config = Int8MixedPrecisionTrainingConfig( + output=output, + grad_input=grad_input, + grad_weight=grad_weight, + ) + + def prepare(self, model: nn.Module) -> nn.Module: + # don't apply INT8 mixed-precision training to LM head + # since speed is slightly lower. + quantize_fn = int8_mixed_precision_training(self._config) + if isinstance(model, TransformerDecoder): + quantize_(model.layers, quantize_fn) + else: + quantize_(model, quantize_fn) + return model + + def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]: """Given a quantizer object, returns a string that specifies the type of quantization. From 2b6e066aae3a75e66e52338ad673ce2fd6eb2dea Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 21:45:29 +0800 Subject: [PATCH 04/37] add notes on when speedup can be expected --- torchtune/training/quantization.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index e244030937..7823442546 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -85,7 +85,18 @@ def quantize(self, model): class Int8MixedPrecisionTrainingQuantizer: """Apply INT8 mixed-precision training. During training, weights and activations are dynamically quantized to INT8 to utilize INT8 tensor cores. This is also done - in the backward pass.""" + in the backward pass. + + NOTE: due to the limitations of the current implementation, the following + requirements must be satisfied to enjoy speedup: + + 1. Must use ``torch.compile()`` (set ``compile=True``). + 2. Inputs to the model must not be too dynamic e.g. input sequence length changes + for every batch. + + To satisfy (2), you can use :class:`~torchtune.datasets.PackedDataset` (set + ``dataset.packed=True``), which ensures input tokens always have fixed length. + """ def __init__( self, From d32f5b84d6fa36a8b5e0cf423280c0debb3afb25 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 22:01:30 +0800 Subject: [PATCH 05/37] clarify doc message --- torchtune/training/quantization.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 7823442546..a3c7bc1889 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -87,12 +87,12 @@ class Int8MixedPrecisionTrainingQuantizer: are dynamically quantized to INT8 to utilize INT8 tensor cores. This is also done in the backward pass. - NOTE: due to the limitations of the current implementation, the following + NOTE: Due to the limitations of the current implementation, the following requirements must be satisfied to enjoy speedup: 1. Must use ``torch.compile()`` (set ``compile=True``). - 2. Inputs to the model must not be too dynamic e.g. input sequence length changes - for every batch. + 2. Inputs to the model must not be too dynamic. For example, when input tokens + length changes for every batch, you won't see the expected speedup. To satisfy (2), you can use :class:`~torchtune.datasets.PackedDataset` (set ``dataset.packed=True``), which ensures input tokens always have fixed length. From 60dad97e7348c295927df5d7daed5984db9dc2d8 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 23:14:05 +0800 Subject: [PATCH 06/37] update docs --- torchtune/training/quantization.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index a3c7bc1889..cf40a76555 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -83,19 +83,33 @@ def quantize(self, model): class Int8MixedPrecisionTrainingQuantizer: - """Apply INT8 mixed-precision training. During training, weights and activations - are dynamically quantized to INT8 to utilize INT8 tensor cores. This is also done - in the backward pass. + """Apply INT8 mixed-precision training. This only affects weights of ``nn.Linear`` + modules. During training, weights and activations are dynamically quantized to INT8 + to utilize fast matrix multiplication with INT8 tensor cores. This is also done in + the backward pass. + + The expected end2end speedup is 40% on a single A100 and 70% on a single 4090, with + minimal accuracy loss. If convergence is an issue, please refer to torchao + documentation below. + + For more details, as well as details about arguments of this quantizer, please refer to + https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision + + Args: + output (bool): whether to apply INT8 mixed-precision for calculating output. + grad_input (bool): whether to apply INT8 mixed-precision for calculating grad_input. + grad_weight (bool): whether to apply INT8 mixed-precision for calculating grad_weight. NOTE: Due to the limitations of the current implementation, the following - requirements must be satisfied to enjoy speedup: + requirements must be satisfied to enjoy the expected speedup: 1. Must use ``torch.compile()`` (set ``compile=True``). 2. Inputs to the model must not be too dynamic. For example, when input tokens length changes for every batch, you won't see the expected speedup. To satisfy (2), you can use :class:`~torchtune.datasets.PackedDataset` (set - ``dataset.packed=True``), which ensures input tokens always have fixed length. + ``dataset.packed=True`` and ``tokenizer.max_seq_len`` to a desired value.), which + ensures input tokens always have fixed length. """ def __init__( From 8395070e2a7fd0a57ffca512c550ff3e5030f628 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 23:19:24 +0800 Subject: [PATCH 07/37] add tiny log --- recipes/full_finetune_single_device.py | 1 + 1 file changed, 1 insertion(+) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 91d9b4ced5..3f61733845 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -363,6 +363,7 @@ def _setup_model( ) if quantizer_cfg is not None: + log.info(f"Preparing model with {quantizer_cfg._component_}") quantizer = config.instantiate(quantizer_cfg) model = quantizer.prepare(model) From b7b8a7dd6961614556ae4b5e7007bfd0b584c55c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 23:36:08 +0800 Subject: [PATCH 08/37] update comment --- torchtune/training/quantization.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index cf40a76555..9b505bc1b2 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -125,8 +125,9 @@ def __init__( ) def prepare(self, model: nn.Module) -> nn.Module: - # don't apply INT8 mixed-precision training to LM head - # since speed is slightly lower. + # Don't apply INT8 mixed-precision training to LM head since end2end speedup + # will be slightly worse. There are also possible issues with tied word + # embeddings. quantize_fn = int8_mixed_precision_training(self._config) if isinstance(model, TransformerDecoder): quantize_(model.layers, quantize_fn) From 2829b035c33d48c053146e229e881eec01e6170f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 13 Sep 2024 23:48:38 +0800 Subject: [PATCH 09/37] add guard on torch version and CUDA sm --- torchtune/training/quantization.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 9b505bc1b2..93c31cb4ad 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -6,6 +6,7 @@ from typing import Callable, Optional +import torch from torch import nn from torchao.prototype.quantized_training import ( int8_mixed_precision_training, @@ -24,6 +25,15 @@ ) from torchtune.modules import TransformerDecoder +from torchtune.utils._version import torch_version_ge + + +# TODO: add guard to torchao version +_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = ( + torch_version_ge("2.4.0") + and torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (8, 0) +) __all__ = [ @@ -100,6 +110,9 @@ class Int8MixedPrecisionTrainingQuantizer: grad_input (bool): whether to apply INT8 mixed-precision for calculating grad_input. grad_weight (bool): whether to apply INT8 mixed-precision for calculating grad_weight. + Raises: + RuntimeError: If runtime requirements for INT8 mixed-precision training are not met. + NOTE: Due to the limitations of the current implementation, the following requirements must be satisfied to enjoy the expected speedup: @@ -118,6 +131,12 @@ def __init__( grad_input: bool = True, grad_weight: bool = True, ) -> None: + if not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING: + raise RuntimeError( + "INT8 mixed-precision training requires torch>=2.4 and a CUDA capable" + " device with compute capability >= 8.0" + ) + self._config = Int8MixedPrecisionTrainingConfig( output=output, grad_input=grad_input, From 688a1c87c71c343d00f266be346e099dbe6e0f9c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 00:01:34 +0800 Subject: [PATCH 10/37] add integration test --- .../test_full_finetune_single_device.py | 88 +++++++++++++++++++ tests/test_utils.py | 18 +++- 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 9b8a75ceb9..1ea75cc2f1 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -28,6 +28,7 @@ CKPT_MODEL_PATHS, gen_log_file_name, get_loss_values_from_metric_logger, + get_tps_values_from_metric_logger, TOKENIZER_PATHS, ) @@ -263,3 +264,90 @@ def test_gradient_accumulation(self, tmpdir, monkeypatch): accum_loss = np.mean(get_loss_values_from_metric_logger(grad_accum_log_file)) torch.testing.assert_close(no_accum_loss, accum_loss, atol=1e-5, rtol=1e-5) + + +class TestFullFinetuneInt8MixedPrecisionTraining: + def _get_test_config_overrides(self): + return [ + "dataset=tests.recipes.utils.DummyDataset", + "dataset.train_on_input=False", + "seed=9", + "epochs=1", + "max_steps_per_epoch=5", + "optimizer=torch.optim.AdamW", + "optimizer_in_bwd=False", + "compile=True", + ] + + @pytest.mark.integration_test + def test_speed(self, tmpdir, monkeypatch): + model_type = "llama3" + ckpt_type = "tune" + ckpt_component = CKPT_COMPONENT_MAP["tune"] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file_baseline = gen_log_file_name(tmpdir, suffix="baseline") + log_file_int8mp = gen_log_file_name(tmpdir, suffix="int8mp") + + model_config = MODEL_TEST_CONFIGS[model_type] + + # set dataset.packed=True to have fixed input seq len + cmd1 = f""" + tune run full_finetune_single_device \ + --config llama3/8B_full_single_device \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + tokenizer.max_seq_len=4096 \ + dataset.packed=True \ + metric_logger.filename={log_file_baseline} \ + compile=True \ + """.split() + cmd1 = cmd1 + self._get_test_config_overrides() + model_config + + # Make sure to clear compile state in between tests + torch._dynamo.reset() + monkeypatch.setattr(sys, "argv", cmd1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + quantizer = ( + "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer" + ) + cmd2 = f""" + tune run full_finetune_single_device \ + --config llama3/8B_full_single_device \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + tokenizer.max_seq_len=4096 \ + dataset.packed=True \ + metric_logger.filename={log_file_int8mp} \ + compile=True \ + quantizer._component=quantizer._component_={quantizer} \ + """.split() + cmd2 = cmd2 + self._get_test_config_overrides() + model_config + + torch._dynamo.reset() + monkeypatch.setattr(sys, "argv", cmd2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # skip the first iteration since it includes compile time + tps_baseline = get_tps_values_from_metric_logger(log_file_baseline)[1:] + tps_int8mp = get_tps_values_from_metric_logger(log_file_int8mp)[1:] + + # check that it is at least 20% faster + assert np.mean(tps_int8mp) > np.mean(tps_baseline) * 1.2 diff --git a/tests/test_utils.py b/tests/test_utils.py index 211e783c3f..458a8c48ac 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,7 +13,7 @@ from functools import partial from io import StringIO from pathlib import Path -from typing import Any, Dict, Generator, List, Mapping, Optional, TextIO, Tuple, Union +from typing import Any, Generator, List, Mapping, Optional, TextIO, Tuple, Union import pytest @@ -332,7 +332,7 @@ def gpu_test(gpu_count: int = 1): return pytest.mark.skipif(local_gpu_count < gpu_count, reason=message) -def get_loss_values_from_metric_logger(log_file_path: str) -> Dict[str, float]: +def get_loss_values_from_metric_logger(log_file_path: str) -> list[float]: """ Given an output directory containing metric logger .txt file, parse the .txt and return a list of losses from each logged iteration. @@ -343,6 +343,20 @@ def get_loss_values_from_metric_logger(log_file_path: str) -> Dict[str, float]: return losses +def get_tps_values_from_metric_logger(log_file_path: str) -> list[float]: + """ + Given an output directory containing metric logger .txt file, + parse the .txt and return a list of tokens per second (tps) values + from each logged iteration. + """ + with open(log_file_path, "r") as f: + logs = f.read() + tps_values = [ + float(x) for x in re.findall(r"tokens_per_second_per_gpu:(\d+\.\d+)", logs) + ] + return tps_values + + def gen_log_file_name(tmpdir, suffix: Optional[str] = None) -> str: """ Take the tmpdir and just append a non-path version of it as the From 21391add834fdb40e683e1709c953a762d95898c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 00:06:47 +0800 Subject: [PATCH 11/37] update test --- tests/recipes/test_full_finetune_single_device.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 1ea75cc2f1..f920b6ba35 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -346,8 +346,9 @@ def test_speed(self, tmpdir, monkeypatch): runpy.run_path(TUNE_PATH, run_name="__main__") # skip the first iteration since it includes compile time - tps_baseline = get_tps_values_from_metric_logger(log_file_baseline)[1:] - tps_int8mp = get_tps_values_from_metric_logger(log_file_int8mp)[1:] + tps_baseline = np.mean(get_tps_values_from_metric_logger(log_file_baseline)[1:]) + tps_int8mp = np.mean(get_tps_values_from_metric_logger(log_file_int8mp)[1:]) # check that it is at least 20% faster - assert np.mean(tps_int8mp) > np.mean(tps_baseline) * 1.2 + speedup = tps_int8mp / tps_baseline + assert speedup > 1.2, speedup From f885d56921d8e35409dc980fe6784ed5cdc923de Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 00:37:41 +0800 Subject: [PATCH 12/37] use dummy alpaca --- tests/recipes/test_full_finetune_single_device.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index f920b6ba35..1e8bd849b6 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -269,15 +269,13 @@ def test_gradient_accumulation(self, tmpdir, monkeypatch): class TestFullFinetuneInt8MixedPrecisionTraining: def _get_test_config_overrides(self): return [ - "dataset=tests.recipes.utils.DummyDataset", - "dataset.train_on_input=False", "seed=9", "epochs=1", "max_steps_per_epoch=5", "optimizer=torch.optim.AdamW", "optimizer_in_bwd=False", "compile=True", - ] + ] + dummy_alpaca_dataset_config() @pytest.mark.integration_test def test_speed(self, tmpdir, monkeypatch): @@ -305,7 +303,7 @@ def test_speed(self, tmpdir, monkeypatch): checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ - tokenizer.max_seq_len=4096 \ + tokenizer.max_seq_len=256 \ dataset.packed=True \ metric_logger.filename={log_file_baseline} \ compile=True \ From 7db782c8414882e4edc70066fc648d241e50dbba Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 08:17:40 +0800 Subject: [PATCH 13/37] fix typo --- tests/recipes/test_full_finetune_single_device.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 1e8bd849b6..17189958be 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -334,7 +334,7 @@ def test_speed(self, tmpdir, monkeypatch): dataset.packed=True \ metric_logger.filename={log_file_int8mp} \ compile=True \ - quantizer._component=quantizer._component_={quantizer} \ + quantizer._component_={quantizer} \ """.split() cmd2 = cmd2 + self._get_test_config_overrides() + model_config From 25a24519d0488e59fff1cf58e0cde024438a0251 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 08:56:12 +0800 Subject: [PATCH 14/37] convert speed test to smoke test --- .../test_full_finetune_single_device.py | 84 ++++++++----------- tests/test_utils.py | 14 ---- 2 files changed, 35 insertions(+), 63 deletions(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 17189958be..eda71805e6 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -28,9 +28,9 @@ CKPT_MODEL_PATHS, gen_log_file_name, get_loss_values_from_metric_logger, - get_tps_values_from_metric_logger, TOKENIZER_PATHS, ) +from torchtune.models.llama3 import llama3 class TestFullFinetuneSingleDeviceRecipe: @@ -278,75 +278,61 @@ def _get_test_config_overrides(self): ] + dummy_alpaca_dataset_config() @pytest.mark.integration_test - def test_speed(self, tmpdir, monkeypatch): + def test_smoke(self, tmpdir, monkeypatch): model_type = "llama3" - ckpt_type = "tune" ckpt_component = CKPT_COMPONENT_MAP["tune"] - ckpt = model_type + "_" + ckpt_type - ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) tokenizer_path = Path(TOKENIZER_PATHS[model_type]) - ckpt_dir = ckpt_path.parent - log_file_baseline = gen_log_file_name(tmpdir, suffix="baseline") - log_file_int8mp = gen_log_file_name(tmpdir, suffix="int8mp") - - model_config = MODEL_TEST_CONFIGS[model_type] - - # set dataset.packed=True to have fixed input seq len - cmd1 = f""" - tune run full_finetune_single_device \ - --config llama3/8B_full_single_device \ - output_dir={tmpdir} \ - checkpointer._component_={ckpt_component} \ - checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}]\ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type={model_type.upper()} \ - tokenizer.path='{tokenizer_path}' \ - tokenizer.prompt_template=null \ - tokenizer.max_seq_len=256 \ - dataset.packed=True \ - metric_logger.filename={log_file_baseline} \ - compile=True \ - """.split() - cmd1 = cmd1 + self._get_test_config_overrides() + model_config + log_file = gen_log_file_name(tmpdir) - # Make sure to clear compile state in between tests - torch._dynamo.reset() - monkeypatch.setattr(sys, "argv", cmd1) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") + # MODEL_TEST_CONFIGS["llama3"] doesn't work with FlexAttention + # because some dims are not multiple of 128 + # create a dummy model and save state_dict so it works with torchtune + model_config = [ + "model._component_=torchtune.models.llama3.llama3", + "model.vocab_size=128_256", + "model.num_layers=2", + "model.num_heads=2", + "model.num_kv_heads=2", + "model.embed_dim=256", + "model.max_seq_len=1024", + ] + dummy_model = llama3( + vocab_size=128_256, + num_layers=2, + num_heads=8, + num_kv_heads=8, + embed_dim=1024, + max_seq_len=1024, + ) + ckpt_dir = tmpdir / "ckpt_dir" + ckpt_dir.mkdir() + torch.save(dummy_model.state_dict(), ckpt_dir / "model.pt") quantizer = ( "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer" ) - cmd2 = f""" + + # set dataset.packed=True to have fixed input seq len + cmd = f""" tune run full_finetune_single_device \ --config llama3/8B_full_single_device \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.checkpoint_files=[model.pt]\ checkpointer.output_dir={tmpdir} \ checkpointer.model_type={model_type.upper()} \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ - tokenizer.max_seq_len=4096 \ + tokenizer.max_seq_len=256 \ dataset.packed=True \ - metric_logger.filename={log_file_int8mp} \ + metric_logger.filename={log_file} \ compile=True \ - quantizer._component_={quantizer} \ + quantizer._component_={quantizer}" \ """.split() - cmd2 = cmd2 + self._get_test_config_overrides() + model_config + cmd = cmd + self._get_test_config_overrides() + model_config torch._dynamo.reset() - monkeypatch.setattr(sys, "argv", cmd2) + monkeypatch.setattr(sys, "argv", cmd) with pytest.raises(SystemExit, match=""): runpy.run_path(TUNE_PATH, run_name="__main__") - - # skip the first iteration since it includes compile time - tps_baseline = np.mean(get_tps_values_from_metric_logger(log_file_baseline)[1:]) - tps_int8mp = np.mean(get_tps_values_from_metric_logger(log_file_int8mp)[1:]) - - # check that it is at least 20% faster - speedup = tps_int8mp / tps_baseline - assert speedup > 1.2, speedup diff --git a/tests/test_utils.py b/tests/test_utils.py index 458a8c48ac..66728eeee5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -343,20 +343,6 @@ def get_loss_values_from_metric_logger(log_file_path: str) -> list[float]: return losses -def get_tps_values_from_metric_logger(log_file_path: str) -> list[float]: - """ - Given an output directory containing metric logger .txt file, - parse the .txt and return a list of tokens per second (tps) values - from each logged iteration. - """ - with open(log_file_path, "r") as f: - logs = f.read() - tps_values = [ - float(x) for x in re.findall(r"tokens_per_second_per_gpu:(\d+\.\d+)", logs) - ] - return tps_values - - def gen_log_file_name(tmpdir, suffix: Optional[str] = None) -> str: """ Take the tmpdir and just append a non-path version of it as the From 6094cdb804a2f44f62b1ab3f5ca523c59846c392 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 09:07:43 +0800 Subject: [PATCH 15/37] fix test --- tests/recipes/test_full_finetune_single_device.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index eda71805e6..4e331a6d88 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -299,19 +299,15 @@ def test_smoke(self, tmpdir, monkeypatch): dummy_model = llama3( vocab_size=128_256, num_layers=2, - num_heads=8, - num_kv_heads=8, - embed_dim=1024, + num_heads=2, + num_kv_heads=2, + embed_dim=256, max_seq_len=1024, ) ckpt_dir = tmpdir / "ckpt_dir" ckpt_dir.mkdir() torch.save(dummy_model.state_dict(), ckpt_dir / "model.pt") - quantizer = ( - "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer" - ) - # set dataset.packed=True to have fixed input seq len cmd = f""" tune run full_finetune_single_device \ @@ -328,7 +324,7 @@ def test_smoke(self, tmpdir, monkeypatch): dataset.packed=True \ metric_logger.filename={log_file} \ compile=True \ - quantizer._component_={quantizer}" \ + quantizer._component_=torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer \ """.split() cmd = cmd + self._get_test_config_overrides() + model_config From 19a2d3ef0bf42e6f8c333d4b139e0c5a9ad0c732 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 09:15:30 +0800 Subject: [PATCH 16/37] add ao version guard --- torchtune/training/quantization.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 93c31cb4ad..500af9bda4 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -8,10 +8,6 @@ import torch from torch import nn -from torchao.prototype.quantized_training import ( - int8_mixed_precision_training, - Int8MixedPrecisionTrainingConfig, -) from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_ from torchao.quantization.prototype.qat import ( disable_8da4w_fake_quant, @@ -25,16 +21,25 @@ ) from torchtune.modules import TransformerDecoder +from torchtune.modules.low_precision._utils import _get_torchao_version from torchtune.utils._version import torch_version_ge -# TODO: add guard to torchao version +_TORCHAO_VERSION, _ = _get_torchao_version() + _SUPPORTS_INT8_MIXED_PRECISION_TRAINING = ( torch_version_ge("2.4.0") + and _TORCHAO_VERSION.split(".") >= ("0", "5", "0") and torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) ) +if _SUPPORTS_INT8_MIXED_PRECISION_TRAINING: + from torchao.prototype.quantized_training import ( + int8_mixed_precision_training, + Int8MixedPrecisionTrainingConfig, + ) + __all__ = [ "get_quantizer_mode", @@ -133,8 +138,8 @@ def __init__( ) -> None: if not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING: raise RuntimeError( - "INT8 mixed-precision training requires torch>=2.4 and a CUDA capable" - " device with compute capability >= 8.0" + "INT8 mixed-precision training requires torch>=2.4, torchao>=0.5, and" + " a CUDA capable device with compute capability >= 8.0" ) self._config = Int8MixedPrecisionTrainingConfig( From faec18da1855dc52076a8443d6d8b9b6af070bd1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 14 Sep 2024 09:16:05 +0800 Subject: [PATCH 17/37] fix --- torchtune/training/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 500af9bda4..f99d0b0c6f 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -29,7 +29,7 @@ _SUPPORTS_INT8_MIXED_PRECISION_TRAINING = ( torch_version_ge("2.4.0") - and _TORCHAO_VERSION.split(".") >= ("0", "5", "0") + and tuple(_TORCHAO_VERSION.split(".")) >= ("0", "5", "0") and torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) ) From 8fc2826b609fa8e0db63653ac566c49b7bf69d69 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 15 Sep 2024 07:58:13 +0800 Subject: [PATCH 18/37] attempt LoRA --- recipes/lora_finetune_single_device.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 0862675a77..6d65b2aa20 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -391,6 +391,20 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) + use_int8_mixed_precision_training = False + if use_int8_mixed_precision_training: + from torchao import quantize_ + from torchao.prototype.quantized_training import ( + int8_mixed_precision_training, + ) + + from torchtune.modules.peft import DoRALinear, LoRALinear + + def filter_fn(module, name): + return isinstance(module, LoRALinear, DoRALinear) + + quantize_(model, int8_mixed_precision_training(), filter_fn=filter_fn) + base_missing, base_unexpected = model.load_state_dict( base_model_state_dict, strict=False ) From 911df57585e7d6a5593a87fc41bc0e4f52dd84de Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 15 Sep 2024 08:04:17 +0800 Subject: [PATCH 19/37] fix lora --- recipes/lora_finetune_single_device.py | 19 ++++++------------- torchtune/training/quantization.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 6d65b2aa20..38b0236063 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -229,6 +229,7 @@ def setup(self, cfg: DictConfig) -> None: if self._resume_from_checkpoint else None ), + quantizer_cfg=cfg.get("quantizer", None), ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -370,6 +371,7 @@ def _setup_model( compile_model: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, + quantizer_cfg: Optional[DictConfig] = None, ) -> nn.Module: with training.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) @@ -391,19 +393,10 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - use_int8_mixed_precision_training = False - if use_int8_mixed_precision_training: - from torchao import quantize_ - from torchao.prototype.quantized_training import ( - int8_mixed_precision_training, - ) - - from torchtune.modules.peft import DoRALinear, LoRALinear - - def filter_fn(module, name): - return isinstance(module, LoRALinear, DoRALinear) - - quantize_(model, int8_mixed_precision_training(), filter_fn=filter_fn) + if quantizer_cfg is not None: + log.info(f"Preparing model with {quantizer_cfg._component_}") + quantizer = config.instantiate(quantizer_cfg) + model = quantizer.prepare(model) base_missing, base_unexpected = model.load_state_dict( base_model_state_dict, strict=False diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index f99d0b0c6f..9e1eedbef0 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -22,6 +22,7 @@ from torchtune.modules import TransformerDecoder from torchtune.modules.low_precision._utils import _get_torchao_version +from torchtune.modules.peft import DoRALinear, LoRALinear from torchtune.utils._version import torch_version_ge @@ -149,14 +150,19 @@ def __init__( ) def prepare(self, model: nn.Module) -> nn.Module: + quantize_fn = int8_mixed_precision_training(self._config) + + # custom filter_fn to work with torchtune's peft + def filter_fn(module, name): + return isinstance(module, (nn.Linear, LoRALinear, DoRALinear)) + # Don't apply INT8 mixed-precision training to LM head since end2end speedup # will be slightly worse. There are also possible issues with tied word # embeddings. - quantize_fn = int8_mixed_precision_training(self._config) if isinstance(model, TransformerDecoder): - quantize_(model.layers, quantize_fn) + quantize_(model.layers, quantize_fn, filter_fn=filter_fn) else: - quantize_(model, quantize_fn) + quantize_(model, quantize_fn, filter_fn=filter_fn) return model From 51bbeac1bf752b2664de9fa2139e49a6ddc4aba5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 15 Sep 2024 08:27:52 +0800 Subject: [PATCH 20/37] skip LoRA --- torchtune/training/quantization.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 9e1eedbef0..7f258853c3 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -153,8 +153,11 @@ def prepare(self, model: nn.Module) -> nn.Module: quantize_fn = int8_mixed_precision_training(self._config) # custom filter_fn to work with torchtune's peft - def filter_fn(module, name): - return isinstance(module, (nn.Linear, LoRALinear, DoRALinear)) + def filter_fn(module: nn.Module, name: str) -> bool: + if isinstance(module, nn.Linear): + return not (name.endswith(".lora_a") or name.endswith(".lora_b")) + + return isinstance(module, (LoRALinear, DoRALinear)) # Don't apply INT8 mixed-precision training to LM head since end2end speedup # will be slightly worse. There are also possible issues with tied word From 1e5ae927cc060a7b6527fee90db8562064c8877b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 15 Sep 2024 08:54:02 +0800 Subject: [PATCH 21/37] skip NF4 --- torchtune/training/quantization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 7f258853c3..ea94f6fda1 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -157,7 +157,10 @@ def filter_fn(module: nn.Module, name: str) -> bool: if isinstance(module, nn.Linear): return not (name.endswith(".lora_a") or name.endswith(".lora_b")) - return isinstance(module, (LoRALinear, DoRALinear)) + if isinstance(module, (LoRALinear, DoRALinear)): + return not module._quantize_base # doesn't work with NF4 yet + + return False # Don't apply INT8 mixed-precision training to LM head since end2end speedup # will be slightly worse. There are also possible issues with tied word From 45b4365f84ee99b14abbda16e5dcdbd2adcb17e7 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 3 Oct 2024 13:56:22 -0400 Subject: [PATCH 22/37] typo --- torchtune/training/quantization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 1b394967b8..3777db638e 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -108,9 +108,7 @@ class Int4WeightOnlyQuantizer: to linear layers in the model using the efficient tinygemm kernel. """ - def __init__(self, groupsize: int = 128, i - - er_k_tiles: int = 8): + def __init__(self, groupsize: int = 128, inner_k_tiles: int = 8): self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles From 1ac836aca9810aa126aeba67b83d7fd09395a3e5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 19:01:19 +0800 Subject: [PATCH 23/37] remove unwanted chnages --- tests/recipes/test_full_finetune_single_device.py | 1 - tests/test_utils.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/recipes/test_full_finetune_single_device.py b/tests/recipes/test_full_finetune_single_device.py index 8f4ff8fde7..819c70fdf0 100644 --- a/tests/recipes/test_full_finetune_single_device.py +++ b/tests/recipes/test_full_finetune_single_device.py @@ -28,7 +28,6 @@ get_loss_values_from_metric_logger, TOKENIZER_PATHS, ) -from torchtune.models.llama3 import llama3 class TestFullFinetuneSingleDeviceRecipe: diff --git a/tests/test_utils.py b/tests/test_utils.py index 2599a287f3..bcb26285a1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,7 +13,7 @@ from functools import partial from io import StringIO from pathlib import Path -from typing import Any, Generator, List, Mapping, Optional, TextIO, Tuple, Union +from typing import Any, Dict, Generator, List, Mapping, Optional, TextIO, Tuple, Union import pytest @@ -306,7 +306,7 @@ def gpu_test(gpu_count: int = 1): return pytest.mark.skipif(local_gpu_count < gpu_count, reason=message) -def get_loss_values_from_metric_logger(log_file_path: str) -> list[float]: +def get_loss_values_from_metric_logger(log_file_path: str) -> Dict[str, float]: """ Given an output directory containing metric logger .txt file, parse the .txt and return a list of losses from each logged iteration. From 5d94cb39594739e9ff8b77753e43facf10dbfbb1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 19:07:21 +0800 Subject: [PATCH 24/37] use module swap --- torchtune/training/quantization.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 86b85ef44a..a931a45326 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -8,7 +8,9 @@ from warnings import warn import torch +import torchao from torch import nn +from packaging.version import Version try: # torchao 0.7+ @@ -53,7 +55,7 @@ _SUPPORTS_INT8_MIXED_PRECISION_TRAINING = ( torch_version_ge("2.4.0") - # and tuple(_TORCHAO_VERSION.split(".")) >= ("0", "5", "0") + and Version(torchao.__version__) >= Version("0.7") and torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) ) @@ -207,8 +209,8 @@ def __init__( ) -> None: if not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING: raise RuntimeError( - "INT8 mixed-precision training requires torch>=2.4, torchao>=0.5, and" - " a CUDA capable device with compute capability >= 8.0" + "INT8 mixed-precision training requires torch>=2.4, torchao>=0.7, and" + " a CUDA-capable device with compute capability >= 8.0" ) self._config = Int8MixedPrecisionTrainingConfig( @@ -218,21 +220,16 @@ def __init__( ) def prepare(self, model: nn.Module) -> nn.Module: - quantize_fn = int8_mixed_precision_training(self._config) + quantize_fn = int8_mixed_precision_training(self._config, module_swap=True) # custom filter_fn to work with torchtune's peft def filter_fn(module: nn.Module, name: str) -> bool: if isinstance(module, nn.Linear): return not (name.endswith(".lora_a") or name.endswith(".lora_b")) - - if isinstance(module, (LoRALinear, DoRALinear)): - return not module._quantize_base # doesn't work with NF4 yet - return False # Don't apply INT8 mixed-precision training to LM head since end2end speedup - # will be slightly worse. There are also possible issues with tied word - # embeddings. + # will be slightly worse. There are also possible issues with tied word embeddings. if isinstance(model, TransformerDecoder): quantize_(model.layers, quantize_fn, filter_fn=filter_fn) else: From 06abd88092c950fba20add1b0c7473f5441646d1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 19:08:52 +0800 Subject: [PATCH 25/37] remove unused import --- torchtune/training/quantization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index a931a45326..538b9d6d4f 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -9,8 +9,8 @@ import torch import torchao -from torch import nn from packaging.version import Version +from torch import nn try: # torchao 0.7+ @@ -49,7 +49,6 @@ ) from torchtune.modules import TransformerDecoder -from torchtune.modules.peft import DoRALinear, LoRALinear from torchtune.utils._version import torch_version_ge From 0ff702e5eca32242f8b56397812cabb97ea4a8bd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 19:42:52 +0800 Subject: [PATCH 26/37] update docs. change to mixed_precision --- .../llama3_2/1B_full_single_device.yaml | 5 ++++ recipes/full_finetune_distributed.py | 9 ++++++ recipes/full_finetune_single_device.py | 12 ++++---- torchtune/training/quantization.py | 30 ++++++++++++------- 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml index fe641f3479..76f4b667cb 100644 --- a/recipes/configs/llama3_2/1B_full_single_device.yaml +++ b/recipes/configs/llama3_2/1B_full_single_device.yaml @@ -79,6 +79,11 @@ output_dir: /tmp/full-llama3.2-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 1ce0db98d8..8c2b40114f 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -275,6 +275,7 @@ def setup(self, cfg: DictConfig) -> None: model_state_dict=checkpoint_dict[training.MODEL_KEY], ac_mode=cfg.get("ac_mode", None), ac_option=cfg.get("ac_option", None), + mixed_precision_cfg=cfg.get("mixed_precision", None), ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -415,6 +416,7 @@ def _setup_model( custom_sharded_layers: Optional[List[str]] = None, ac_mode: Optional[str] = None, ac_option: Optional[int] = None, + mixed_precision_cfg: Optional[DictConfig] = None, ) -> nn.Module: """ Model initialization has some important considerations: @@ -455,6 +457,13 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) + if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False): + log.info(f"Preparing model with {mixed_precision_cfg._component_}") + cfg = dict(mixed_precision_cfg) # shallow copy + cfg.pop("enabled", None) + quantizer = config.instantiate(cfg) + model = quantizer.prepare(model) + # For FSDP sharding fsdp_shard_conditions = [ partial( diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index d8263e1967..bf687e06d0 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -263,7 +263,7 @@ def setup(self, cfg: DictConfig) -> None: enable_activation_offloading=self._enable_activation_offloading, compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], - quantizer_cfg=cfg.get("quantizer", None), + mixed_precision_cfg=cfg.get("mixed_precision", None), ) self._tokenizer = config.instantiate(cfg.tokenizer) log.info("Tokenizer is initialized from file.") @@ -407,7 +407,7 @@ def _setup_model( enable_activation_offloading: bool, compile_model: bool, model_state_dict: Dict[str, Any], - quantizer_cfg: Optional[DictConfig] = None, + mixed_precision_cfg: Optional[DictConfig] = None, ) -> nn.Module: """ Set up the model including enabling activation checkpointing. @@ -423,9 +423,11 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if quantizer_cfg is not None: - log.info(f"Preparing model with {quantizer_cfg._component_}") - quantizer = config.instantiate(quantizer_cfg) + if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False): + log.info(f"Preparing model with {mixed_precision_cfg._component_}") + cfg = dict(mixed_precision_cfg) # shallow copy + cfg.pop("enabled", None) + quantizer = config.instantiate(cfg) model = quantizer.prepare(model) model.load_state_dict(model_state_dict) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 538b9d6d4f..109663beb2 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -48,7 +48,6 @@ Int8DynActInt4WeightQATQuantizer, ) -from torchtune.modules import TransformerDecoder from torchtune.utils._version import torch_version_ge @@ -181,9 +180,9 @@ class Int8MixedPrecisionTrainingQuantizer: https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision Args: - output (bool): whether to apply INT8 mixed-precision for calculating output. - grad_input (bool): whether to apply INT8 mixed-precision for calculating grad_input. - grad_weight (bool): whether to apply INT8 mixed-precision for calculating grad_weight. + output (bool): whether to apply INT8 mixed-precision for calculating output. Default: True + grad_input (bool): whether to apply INT8 mixed-precision for calculating grad_input. Default: True + grad_weight (bool): whether to apply INT8 mixed-precision for calculating grad_weight. Default: True Raises: RuntimeError: If runtime requirements for INT8 mixed-precision training are not met. @@ -219,20 +218,29 @@ def __init__( ) def prepare(self, model: nn.Module) -> nn.Module: + # we use module-swap implementation so that the state_dict remains plain tensors, + # as well as better FSDP compatibility in torchtune. quantize_fn = int8_mixed_precision_training(self._config, module_swap=True) # custom filter_fn to work with torchtune's peft def filter_fn(module: nn.Module, name: str) -> bool: if isinstance(module, nn.Linear): - return not (name.endswith(".lora_a") or name.endswith(".lora_b")) + # skip LoRA adapters since they are too small, so the speedup will not + # outweight quantization overhead. + # also skip LM head since end2end speedup is slightly worse. + # there are also possible issues with tied word embeddings. + if ( + name.endswith(".lora_a") + or name.endswith(".lora_b") + or module.weight.shape[0] >= 32_000 + ): + return False + else: + return True + return False - # Don't apply INT8 mixed-precision training to LM head since end2end speedup - # will be slightly worse. There are also possible issues with tied word embeddings. - if isinstance(model, TransformerDecoder): - quantize_(model.layers, quantize_fn, filter_fn=filter_fn) - else: - quantize_(model, quantize_fn, filter_fn=filter_fn) + quantize_(model, quantize_fn, filter_fn=filter_fn) return model From 05563f2753a5e0ef574ce796dcc89e0d004dcc1c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 20:00:23 +0800 Subject: [PATCH 27/37] add test. small fixes --- recipes/full_finetune_distributed.py | 6 ++- recipes/full_finetune_single_device.py | 6 ++- tests/torchtune/training/test_quantization.py | 43 +++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 tests/torchtune/training/test_quantization.py diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 8c2b40114f..7506cd49ae 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -457,9 +457,11 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False): + if mixed_precision_cfg is not None and mixed_precision_cfg.get( + "enabled", False + ): log.info(f"Preparing model with {mixed_precision_cfg._component_}") - cfg = dict(mixed_precision_cfg) # shallow copy + cfg = mixed_precision_cfg.copy() cfg.pop("enabled", None) quantizer = config.instantiate(cfg) model = quantizer.prepare(model) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index bf687e06d0..f84fdd341f 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -423,9 +423,11 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if mixed_precision_cfg is not None and mixed_precision_cfg.get("enabled", False): + if mixed_precision_cfg is not None and mixed_precision_cfg.get( + "enabled", False + ): log.info(f"Preparing model with {mixed_precision_cfg._component_}") - cfg = dict(mixed_precision_cfg) # shallow copy + cfg = mixed_precision_cfg.copy() cfg.pop("enabled", None) quantizer = config.instantiate(cfg) model = quantizer.prepare(model) diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py new file mode 100644 index 0000000000..2b8a5afea3 --- /dev/null +++ b/tests/torchtune/training/test_quantization.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import gpu_test +from torch import nn +from torchtune.training.quantization import ( + _SUPPORTS_INT8_MIXED_PRECISION_TRAINING, + Int8MixedPrecisionTrainingQuantizer, +) + + +@gpu_test(gpu_count=1) +@pytest.mark.skipif( + not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING, + reason="INT8 mixed-precision training is not supported", +) +def test_int8_mixed_precision_training_quantizer(): + quantizer = Int8MixedPrecisionTrainingQuantizer() + model = nn.Sequential( + nn.Linear(32, 32), + nn.ReLU(), + nn.Linear(32, 32), + ).cuda() + quantizer.prepare(model) + + # make sure class is changed + assert model[0].__class__ != nn.Linear + assert model[2].__class__ != nn.Linear + + # smoke test forward and backward + model(torch.randn(2, 32).cuda()).sum().backward() + for p in model.parameters(): + assert p.grad is not None + + # state dict is plain tensor + state_dict = model.state_dict() + for v in state_dict.values(): + assert v.__class__ == torch.Tensor From 3050c32091e0ffdae83f3023e034d9cb4fdc18df Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 20:14:57 +0800 Subject: [PATCH 28/37] add config entries --- recipes/configs/llama3_1/8B_full.yaml | 5 +++++ recipes/configs/llama3_2_vision/11B_full.yaml | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index 71ab8eedeb..55d10afaaf 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -82,3 +82,8 @@ metric_logger: output_dir: /tmp/full-llama3.1-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml index f40cde9f90..0c49576fb6 100644 --- a/recipes/configs/llama3_2_vision/11B_full.yaml +++ b/recipes/configs/llama3_2_vision/11B_full.yaml @@ -82,3 +82,8 @@ metric_logger: log_dir: /tmp/Llama-3.2-11B-Vision-Instruct/logs log_every_n_steps: 1 log_peak_memory_stats: True + +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false From 864c6fb67189446690032ac2f503b5f756467ddf Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 20:33:41 +0800 Subject: [PATCH 29/37] remove extra compile --- recipes/configs/llama3_1/8B_full.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index 55d10afaaf..0b01ecac3a 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -70,7 +70,6 @@ device: cuda enable_activation_checkpointing: True enable_activation_offloading: False # True reduces memory custom_sharded_layers: ['tok_embeddings', 'output'] -compile: False # pytorch compile, set to true for perf/memory improvement # Reduced precision dtype: bf16 From 1fed8598f01bf2403301b1bbc8d2a3f53731bcd7 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 3 Nov 2024 22:52:01 +0800 Subject: [PATCH 30/37] fix lora finetune --- recipes/lora_finetune_single_device.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 7815b8b537..538aadb770 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -266,7 +266,7 @@ def setup(self, cfg: DictConfig) -> None: if self._resume_from_checkpoint else None ), - quantizer_cfg=cfg.get("quantizer", None), + mixed_precision_cfg=cfg.get("mixed_precision", None), ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -411,7 +411,7 @@ def _setup_model( compile_model: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, - quantizer_cfg: Optional[DictConfig] = None, + mixed_precision_cfg: Optional[DictConfig] = None, ) -> nn.Module: with training.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) @@ -433,9 +433,13 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if quantizer_cfg is not None: - log.info(f"Preparing model with {quantizer_cfg._component_}") - quantizer = config.instantiate(quantizer_cfg) + if mixed_precision_cfg is not None and mixed_precision_cfg.get( + "enabled", False + ): + log.info(f"Preparing model with {mixed_precision_cfg._component_}") + cfg = mixed_precision_cfg.copy() + cfg.pop("enabled", None) + quantizer = config.instantiate(cfg) model = quantizer.prepare(model) base_missing, base_unexpected = model.load_state_dict( From 0fecc26bfbfefa57fb3d20244949ec7d13738bca Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 12 Nov 2024 02:02:31 +0000 Subject: [PATCH 31/37] fix version check --- torchtune/training/quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 109663beb2..7c3805063b 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -53,7 +53,7 @@ _SUPPORTS_INT8_MIXED_PRECISION_TRAINING = ( torch_version_ge("2.4.0") - and Version(torchao.__version__) >= Version("0.7") + and Version(torchao.__version__) >= Version("0.7.0.dev") and torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) ) From 39e1fc1f963774b41e536bd6ac060f12fc73013b Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 12 Nov 2024 02:20:07 +0000 Subject: [PATCH 32/37] dont set inductor config --- torchtune/training/quantization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 7c3805063b..2bf31c8702 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -240,7 +240,9 @@ def filter_fn(module: nn.Module, name: str) -> bool: return False - quantize_(model, quantize_fn, filter_fn=filter_fn) + # don't set inductor config, otherwise compile will be very slow + # (it will affect global torch.compile() config) + quantize_(model, quantize_fn, filter_fn=filter_fn, set_inductor_config=False) return model From a3349865f64c860098b75b114a9b505044cda4d3 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 5 Dec 2024 15:31:28 +0800 Subject: [PATCH 33/37] remove LoRA --- recipes/lora_finetune_single_device.py | 11 ----------- torchtune/training/quantization.py | 19 +++---------------- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 4475f48983..fcdb3e4ea5 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -276,7 +276,6 @@ def setup(self, cfg: DictConfig) -> None: if self._resume_from_checkpoint else None ), - mixed_precision_cfg=cfg.get("mixed_precision", None), ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -421,7 +420,6 @@ def _setup_model( compile_model: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, - mixed_precision_cfg: Optional[DictConfig] = None, ) -> nn.Module: with training.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) @@ -443,15 +441,6 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if mixed_precision_cfg is not None and mixed_precision_cfg.get( - "enabled", False - ): - log.info(f"Preparing model with {mixed_precision_cfg._component_}") - cfg = mixed_precision_cfg.copy() - cfg.pop("enabled", None) - quantizer = config.instantiate(cfg) - model = quantizer.prepare(model) - base_missing, base_unexpected = model.load_state_dict( base_model_state_dict, strict=False ) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index e0094228ec..add1ec26e8 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -230,23 +230,10 @@ def prepare(self, model: nn.Module) -> nn.Module: # as well as better FSDP compatibility in torchtune. quantize_fn = int8_mixed_precision_training(self._config, module_swap=True) - # custom filter_fn to work with torchtune's peft def filter_fn(module: nn.Module, name: str) -> bool: - if isinstance(module, nn.Linear): - # skip LoRA adapters since they are too small, so the speedup will not - # outweight quantization overhead. - # also skip LM head since end2end speedup is slightly worse. - # there are also possible issues with tied word embeddings. - if ( - name.endswith(".lora_a") - or name.endswith(".lora_b") - or module.weight.shape[0] >= 32_000 - ): - return False - else: - return True - - return False + # skip LM head since end2end speedup is slightly worse. + # there are also possible issues with tied word embeddings. + return isinstance(module, nn.Linear) and module.out_features < 32_000 # don't set inductor config, otherwise compile will be very slow # (it will affect global torch.compile() config) From d149801d42ada0834f182ae7f8a75612521b3842 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 5 Dec 2024 15:33:49 +0800 Subject: [PATCH 34/37] remove PyTorch version check --- torchtune/training/quantization.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index add1ec26e8..2dc4fe989a 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -50,12 +50,9 @@ Int8DynActInt4WeightQATQuantizer, ) -from torchtune.utils._version import torch_version_ge - _SUPPORTS_INT8_MIXED_PRECISION_TRAINING = ( - torch_version_ge("2.4.0") - and Version(torchao.__version__) >= Version("0.7.0.dev") + Version(torchao.__version__) >= Version("0.7.0.dev") and torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0) ) From 03a19788a5ddba7d3ddea2a40f7ec952d3f4d558 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 5 Dec 2024 19:32:00 +0800 Subject: [PATCH 35/37] add checks in init. add entries to all applicable configs --- recipes/configs/code_llama2/7B_full_low_memory.yaml | 5 +++++ recipes/configs/dev/8B_full_experimental.yaml | 5 +++++ recipes/configs/gemma/2B_full.yaml | 5 +++++ recipes/configs/gemma/7B_full.yaml | 5 +++++ recipes/configs/gemma2/27B_full.yaml | 5 +++++ recipes/configs/gemma2/2B_full.yaml | 5 +++++ recipes/configs/gemma2/9B_full.yaml | 5 +++++ recipes/configs/llama2/13B_full.yaml | 5 +++++ recipes/configs/llama2/7B_full.yaml | 5 +++++ recipes/configs/llama2/7B_full_low_memory.yaml | 5 +++++ recipes/configs/llama3/70B_full.yaml | 5 +++++ recipes/configs/llama3/8B_full.yaml | 5 +++++ recipes/configs/llama3/8B_full_single_device.yaml | 5 +++++ recipes/configs/llama3_1/70B_full.yaml | 5 +++++ recipes/configs/llama3_1/8B_full_single_device.yaml | 5 +++++ recipes/configs/llama3_2/1B_full.yaml | 5 +++++ recipes/configs/llama3_2/3B_full.yaml | 5 +++++ recipes/configs/llama3_2/3B_full_single_device.yaml | 5 +++++ recipes/configs/llama3_2_vision/11B_full_single_device.yaml | 5 +++++ recipes/configs/llama3_2_vision/90B_full.yaml | 5 +++++ recipes/configs/mistral/7B_full.yaml | 5 +++++ recipes/configs/mistral/7B_full_low_memory.yaml | 5 +++++ recipes/configs/phi3/mini_full.yaml | 5 +++++ recipes/configs/phi3/mini_full_low_memory.yaml | 5 +++++ recipes/configs/qwen2/0.5B_full.yaml | 5 +++++ recipes/configs/qwen2/0.5B_full_single_device.yaml | 5 +++++ recipes/configs/qwen2/1.5B_full.yaml | 5 +++++ recipes/configs/qwen2/1.5B_full_single_device.yaml | 5 +++++ recipes/configs/qwen2/7B_full.yaml | 5 +++++ recipes/configs/qwen2/7B_full_single_device.yaml | 5 +++++ recipes/configs/qwen2_5/0.5B_full.yaml | 5 +++++ recipes/configs/qwen2_5/0.5B_full_single_device.yaml | 5 +++++ recipes/configs/qwen2_5/1.5B_full.yaml | 5 +++++ recipes/configs/qwen2_5/1.5B_full_single_device.yaml | 5 +++++ recipes/configs/qwen2_5/3B_full.yaml | 5 +++++ recipes/configs/qwen2_5/3B_full_single_device.yaml | 5 +++++ recipes/configs/qwen2_5/7B_full.yaml | 5 +++++ recipes/configs/qwen2_5/7B_full_single_device.yaml | 5 +++++ recipes/full_finetune_distributed.py | 6 ++++++ recipes/full_finetune_single_device.py | 6 ++++++ 40 files changed, 202 insertions(+) diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index b6586e8b5a..3a555b0508 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -80,6 +80,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml index c3fb212093..d6e2b3a10f 100644 --- a/recipes/configs/dev/8B_full_experimental.yaml +++ b/recipes/configs/dev/8B_full_experimental.yaml @@ -82,6 +82,11 @@ output_dir: /tmp/alpaca-llama3-finetune log_every_n_steps: null log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index 5bd0f05a02..1ff80b67bc 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -76,6 +76,11 @@ output_dir: /tmp/alpaca-gemma-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml index 4555235385..fbba3bb097 100644 --- a/recipes/configs/gemma/7B_full.yaml +++ b/recipes/configs/gemma/7B_full.yaml @@ -78,6 +78,11 @@ output_dir: /tmp/alpaca-gemma-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml index ddc89b38b2..d8d7edd652 100644 --- a/recipes/configs/gemma2/27B_full.yaml +++ b/recipes/configs/gemma2/27B_full.yaml @@ -72,3 +72,8 @@ metric_logger: output_dir: /tmp/alpaca-gemma2-27b-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml index b87cf1ccf9..07f177dcfd 100644 --- a/recipes/configs/gemma2/2B_full.yaml +++ b/recipes/configs/gemma2/2B_full.yaml @@ -74,3 +74,8 @@ metric_logger: output_dir: /tmp/alpaca-gemma2-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml index 0fc7e6e4e4..5d0a4ab88b 100644 --- a/recipes/configs/gemma2/9B_full.yaml +++ b/recipes/configs/gemma2/9B_full.yaml @@ -72,3 +72,8 @@ metric_logger: output_dir: /tmp/alpaca-gemma2-9b-finetune log_every_n_steps: 1 log_peak_memory_stats: True + +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index fd7b7421c1..d784fa9be1 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -80,6 +80,11 @@ output_dir: /tmp/alpaca-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index 7e69c8f5a6..391a801cac 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -79,6 +79,11 @@ output_dir: /tmp/alpaca-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml index d7ee50898e..55e628beb0 100644 --- a/recipes/configs/llama2/7B_full_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_low_memory.yaml @@ -83,6 +83,11 @@ output_dir: /tmp/alpaca-llama2-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index 5878b2fd95..8ee8b1cb94 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -110,6 +110,11 @@ output_dir: /tmp/full-llama3-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index a065fa9ece..7d07acabbb 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -72,6 +72,11 @@ custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separatel # Reduced precision dtype: bf16 +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Logging metric_logger: _component_: torchtune.training.metric_logging.DiskLogger diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml index a63845fc30..5396bdba51 100644 --- a/recipes/configs/llama3/8B_full_single_device.yaml +++ b/recipes/configs/llama3/8B_full_single_device.yaml @@ -82,6 +82,11 @@ output_dir: /tmp/full-llama3-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index d92fcef1f6..acf920c874 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -112,6 +112,11 @@ output_dir: /tmp/full-llama3_1-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index 66f397e1df..502eda782b 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -82,6 +82,11 @@ output_dir: /tmp/full-llama3.1-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml index 56fc968b0d..5d850dd8cb 100644 --- a/recipes/configs/llama3_2/1B_full.yaml +++ b/recipes/configs/llama3_2/1B_full.yaml @@ -80,6 +80,11 @@ output_dir: /tmp/full-llama3.2-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml index 4128bb58e7..eece82cea9 100644 --- a/recipes/configs/llama3_2/3B_full.yaml +++ b/recipes/configs/llama3_2/3B_full.yaml @@ -80,6 +80,11 @@ output_dir: /tmp/full-llama3.2-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml index ebc49ae1fb..aa2549957c 100644 --- a/recipes/configs/llama3_2/3B_full_single_device.yaml +++ b/recipes/configs/llama3_2/3B_full_single_device.yaml @@ -80,6 +80,11 @@ output_dir: /tmp/full-llama3.2-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml index d10afdcbfe..a8b0a68e45 100644 --- a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml +++ b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml @@ -83,6 +83,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (default is disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/llama3_2_vision/90B_full.yaml b/recipes/configs/llama3_2_vision/90B_full.yaml index 2ef3c271eb..6484380b16 100644 --- a/recipes/configs/llama3_2_vision/90B_full.yaml +++ b/recipes/configs/llama3_2_vision/90B_full.yaml @@ -80,6 +80,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml index 23c82e1d71..365d44e50d 100644 --- a/recipes/configs/mistral/7B_full.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -82,6 +82,11 @@ output_dir: /tmp/Mistral-7B-v0.1/ log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml index 01de2f11ea..dedc4a38fd 100644 --- a/recipes/configs/mistral/7B_full_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_low_memory.yaml @@ -85,6 +85,11 @@ output_dir: /tmp/Mistral-7B-v0.1/ log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml index 594ffdc916..4bcd593e4b 100644 --- a/recipes/configs/phi3/mini_full.yaml +++ b/recipes/configs/phi3/mini_full.yaml @@ -77,6 +77,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml index 05c1db379a..a2ac8c2739 100644 --- a/recipes/configs/phi3/mini_full_low_memory.yaml +++ b/recipes/configs/phi3/mini_full_low_memory.yaml @@ -78,6 +78,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml index 84336894be..bbb90d3d02 100644 --- a/recipes/configs/qwen2/0.5B_full.yaml +++ b/recipes/configs/qwen2/0.5B_full.yaml @@ -78,6 +78,11 @@ output_dir: /tmp/Qwen2-0.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml index 8b60a17090..aac740b0f5 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -78,6 +78,11 @@ output_dir: /tmp/Qwen2-0.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml index 37b5c0a926..0d652e61bc 100644 --- a/recipes/configs/qwen2/1.5B_full.yaml +++ b/recipes/configs/qwen2/1.5B_full.yaml @@ -78,6 +78,11 @@ output_dir: /tmp/Qwen2-1.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml index 2acdfb3810..fb168dd27d 100644 --- a/recipes/configs/qwen2/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -83,6 +83,11 @@ output_dir: /tmp/Qwen2-1.5B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml index 20d74346e1..eb3990e352 100644 --- a/recipes/configs/qwen2/7B_full.yaml +++ b/recipes/configs/qwen2/7B_full.yaml @@ -81,6 +81,11 @@ output_dir: /tmp/Qwen2-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml index cff3244b18..5ca1b9612b 100644 --- a/recipes/configs/qwen2/7B_full_single_device.yaml +++ b/recipes/configs/qwen2/7B_full_single_device.yaml @@ -82,6 +82,11 @@ output_dir: /tmp/Qwen2-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/0.5B_full.yaml b/recipes/configs/qwen2_5/0.5B_full.yaml index 1298c058e9..73dd4f64f5 100644 --- a/recipes/configs/qwen2_5/0.5B_full.yaml +++ b/recipes/configs/qwen2_5/0.5B_full.yaml @@ -71,6 +71,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml index 39dfb2f8a0..21172a0659 100644 --- a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml @@ -71,6 +71,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/1.5B_full.yaml b/recipes/configs/qwen2_5/1.5B_full.yaml index e0fb09c152..3361391823 100644 --- a/recipes/configs/qwen2_5/1.5B_full.yaml +++ b/recipes/configs/qwen2_5/1.5B_full.yaml @@ -71,6 +71,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml index 480249631d..1dbc231c78 100644 --- a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml @@ -74,6 +74,11 @@ metric_logger: log_every_n_steps: 1 log_peak_memory_stats: True +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/3B_full.yaml b/recipes/configs/qwen2_5/3B_full.yaml index 7267dd5efe..65ca2d89c6 100644 --- a/recipes/configs/qwen2_5/3B_full.yaml +++ b/recipes/configs/qwen2_5/3B_full.yaml @@ -79,6 +79,11 @@ output_dir: /tmp/Qwen2_5-3B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/3B_full_single_device.yaml b/recipes/configs/qwen2_5/3B_full_single_device.yaml index ef8d283098..150145c0f3 100644 --- a/recipes/configs/qwen2_5/3B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/3B_full_single_device.yaml @@ -80,6 +80,11 @@ output_dir: /tmp/Qwen2_5-3B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/7B_full.yaml b/recipes/configs/qwen2_5/7B_full.yaml index e1de8d5584..9a7696f756 100644 --- a/recipes/configs/qwen2_5/7B_full.yaml +++ b/recipes/configs/qwen2_5/7B_full.yaml @@ -81,6 +81,11 @@ output_dir: /tmp/Qwen2_5-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/qwen2_5/7B_full_single_device.yaml b/recipes/configs/qwen2_5/7B_full_single_device.yaml index 3bc3428410..49fbe3a189 100644 --- a/recipes/configs/qwen2_5/7B_full_single_device.yaml +++ b/recipes/configs/qwen2_5/7B_full_single_device.yaml @@ -82,6 +82,11 @@ output_dir: /tmp/Qwen2_5-7B-Instruct-finetune log_every_n_steps: 1 log_peak_memory_stats: False +# mixed precision (disabled) +mixed_precision: + _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer + enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 2a5f221ebc..9bf324ea40 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -181,6 +181,12 @@ def __init__(self, cfg: DictConfig) -> None: "Enabling activation offloading should reduce memory further.", ) + if cfg.mixed_precision.enabled: + if not cfg.compile or not cfg.dataset.packed: + raise ValueError( + "When mixed_precision.enabled is True, both compile and dataset.packed must be True." + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 6c84d1496c..99b4ee248a 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -182,6 +182,12 @@ def __init__(self, cfg: DictConfig) -> None: "Enabling activation offloading should reduce memory further.", ) + if cfg.mixed_precision.enabled: + if not cfg.compile or not cfg.dataset.packed: + raise ValueError( + "When mixed_precision.enabled is True, both compile and dataset.packed must be True." + ) + # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests self.seed = training.set_seed(seed=cfg.seed) From 0699aa3bbd1ac5a47ea3b0e4f84a84710041e9c2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 10 Dec 2024 15:20:47 +0800 Subject: [PATCH 36/37] add space --- recipes/configs/gemma2/27B_full.yaml | 1 + recipes/configs/gemma2/2B_full.yaml | 1 + recipes/configs/gemma2/9B_full.yaml | 1 + 3 files changed, 3 insertions(+) diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml index a4af7dfbcb..e6f4149ad5 100644 --- a/recipes/configs/gemma2/27B_full.yaml +++ b/recipes/configs/gemma2/27B_full.yaml @@ -80,6 +80,7 @@ log_peak_memory_stats: True mixed_precision: _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml index 11fcdf43af..1c144bec62 100644 --- a/recipes/configs/gemma2/2B_full.yaml +++ b/recipes/configs/gemma2/2B_full.yaml @@ -82,6 +82,7 @@ log_peak_memory_stats: True mixed_precision: _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml index bf9a286ddf..c5331433d7 100644 --- a/recipes/configs/gemma2/9B_full.yaml +++ b/recipes/configs/gemma2/9B_full.yaml @@ -80,6 +80,7 @@ log_peak_memory_stats: True mixed_precision: _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer enabled: false + # Profiler (disabled) profiler: _component_: torchtune.training.setup_torch_profiler From be9c0fb60bcbb6019c9e934abfe9cc0e52cf0f58 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 10 Dec 2024 16:00:39 +0800 Subject: [PATCH 37/37] consolidate checks --- recipes/full_finetune_distributed.py | 18 +++++++++++------- recipes/full_finetune_single_device.py | 18 +++++++++++------- torchtune/training/quantization.py | 21 +++++++++++++++++++++ 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 9bf324ea40..2c1ef6c7d6 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -27,6 +27,7 @@ from torchtune.training import DummyProfiler, PROFILER_KEY from torchtune.training.activations import apply_selective_activation_checkpointing from torchtune.training.lr_schedulers import get_lr +from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer from tqdm import tqdm @@ -182,9 +183,14 @@ def __init__(self, cfg: DictConfig) -> None: ) if cfg.mixed_precision.enabled: - if not cfg.compile or not cfg.dataset.packed: - raise ValueError( - "When mixed_precision.enabled is True, both compile and dataset.packed must be True." + if ( + cfg.mixed_precision._component_ + == "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer" + ): + Int8MixedPrecisionTrainingQuantizer.validate_config( + compile=cfg.compile, + dataset_packed=cfg.dataset.packed, + optimizer_path=cfg.optimizer._component_, ) # These are public properties which are updated by the checkpoint loader @@ -274,7 +280,7 @@ def setup(self, cfg: DictConfig) -> None: model_state_dict=checkpoint_dict[training.MODEL_KEY], ac_mode=cfg.get("ac_mode", None), ac_option=cfg.get("ac_option", None), - mixed_precision_cfg=cfg.get("mixed_precision", None), + mixed_precision_cfg=cfg.mixed_precision, ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -512,9 +518,7 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if mixed_precision_cfg is not None and mixed_precision_cfg.get( - "enabled", False - ): + if mixed_precision_cfg is not None and mixed_precision_cfg.enabled: log.info(f"Preparing model with {mixed_precision_cfg._component_}") cfg = mixed_precision_cfg.copy() cfg.pop("enabled", None) diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 99b4ee248a..546a2e218f 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -24,6 +24,7 @@ from torchtune.recipe_interfaces import FTRecipeInterface from torchtune.training import DummyProfiler, PROFILER_KEY from torchtune.training.lr_schedulers import get_lr +from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer from tqdm import tqdm @@ -183,9 +184,14 @@ def __init__(self, cfg: DictConfig) -> None: ) if cfg.mixed_precision.enabled: - if not cfg.compile or not cfg.dataset.packed: - raise ValueError( - "When mixed_precision.enabled is True, both compile and dataset.packed must be True." + if ( + cfg.mixed_precision._component_ + == "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer" + ): + Int8MixedPrecisionTrainingQuantizer.validate_config( + compile=cfg.compile, + dataset_packed=cfg.dataset.packed, + optimizer_path=cfg.optimizer._component_, ) # These are public properties which are updated by the checkpoint loader @@ -277,7 +283,7 @@ def setup(self, cfg: DictConfig) -> None: enable_activation_offloading=self._enable_activation_offloading, compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], - mixed_precision_cfg=cfg.get("mixed_precision", None), + mixed_precision_cfg=cfg.mixed_precision, ) self._tokenizer = config.instantiate(cfg.tokenizer) log.info("Tokenizer is initialized from file.") @@ -437,9 +443,7 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - if mixed_precision_cfg is not None and mixed_precision_cfg.get( - "enabled", False - ): + if mixed_precision_cfg is not None and mixed_precision_cfg.enabled: log.info(f"Preparing model with {mixed_precision_cfg._component_}") cfg = mixed_precision_cfg.copy() cfg.pop("enabled", None) diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 2dc4fe989a..9c88fd7a3c 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -222,6 +222,27 @@ def __init__( grad_weight=grad_weight, ) + @staticmethod + def validate_config( + *, compile: bool, dataset_packed: bool, optimizer_path: str + ) -> None: + if not (compile and dataset_packed): + raise ValueError( + "Both compile and dataset.packed must be True to use INT8 mixed-precision training." + ) + + if not optimizer_path.startswith("torch.optim."): + warn( + "Using low-bit optimizer might have convergence issues with INT8 mixed-precision training. " + "If you observe divergence, try again with the standard torch.optim.AdamW instead." + ) + + warn( + "INT8 mixed-precision might not speedup training if model and/or batch size is too small " + "for the current GPU(s). If it is the case, try increasing batch size or sequence length. " + "On A100, Llama-3-8B only has speedup for batch_size=4, max_seq_len=2048 and above." + ) + def prepare(self, model: nn.Module) -> nn.Module: # we use module-swap implementation so that the state_dict remains plain tensors, # as well as better FSDP compatibility in torchtune.