diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml
index ad941803bb..c32fc2406f 100644
--- a/recipes/configs/code_llama2/7B_full_low_memory.yaml
+++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml
@@ -81,6 +81,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml
index d8f5e8956f..1218f9f01a 100644
--- a/recipes/configs/dev/8B_full_experimental.yaml
+++ b/recipes/configs/dev/8B_full_experimental.yaml
@@ -83,6 +83,10 @@ metric_logger:
 log_every_n_steps: null
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml
index fa692e0f0d..a23ab1efea 100644
--- a/recipes/configs/gemma/2B_full.yaml
+++ b/recipes/configs/gemma/2B_full.yaml
@@ -77,6 +77,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml
index 47206ed291..07d1ea7c2f 100644
--- a/recipes/configs/gemma/7B_full.yaml
+++ b/recipes/configs/gemma/7B_full.yaml
@@ -79,6 +79,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/gemma2/27B_full.yaml b/recipes/configs/gemma2/27B_full.yaml
index 46a31b6821..e6f4149ad5 100644
--- a/recipes/configs/gemma2/27B_full.yaml
+++ b/recipes/configs/gemma2/27B_full.yaml
@@ -76,6 +76,11 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Profiler (disabled)
 profiler:
   _component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/gemma2/2B_full.yaml b/recipes/configs/gemma2/2B_full.yaml
index 42b034fa2c..1c144bec62 100644
--- a/recipes/configs/gemma2/2B_full.yaml
+++ b/recipes/configs/gemma2/2B_full.yaml
@@ -78,6 +78,11 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Profiler (disabled)
 profiler:
   _component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/gemma2/9B_full.yaml b/recipes/configs/gemma2/9B_full.yaml
index bbb31fb268..c5331433d7 100644
--- a/recipes/configs/gemma2/9B_full.yaml
+++ b/recipes/configs/gemma2/9B_full.yaml
@@ -76,6 +76,11 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Profiler (disabled)
 profiler:
   _component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml
index 67932bbb1b..2d693869a5 100644
--- a/recipes/configs/llama2/13B_full.yaml
+++ b/recipes/configs/llama2/13B_full.yaml
@@ -81,6 +81,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml
index 40fb804035..f6a7d799ad 100644
--- a/recipes/configs/llama2/7B_full.yaml
+++ b/recipes/configs/llama2/7B_full.yaml
@@ -80,6 +80,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml
index 29d157dbf6..05f498a473 100644
--- a/recipes/configs/llama2/7B_full_low_memory.yaml
+++ b/recipes/configs/llama2/7B_full_low_memory.yaml
@@ -84,6 +84,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml
index 5491ae093d..f7d062b1f3 100644
--- a/recipes/configs/llama3/70B_full.yaml
+++ b/recipes/configs/llama3/70B_full.yaml
@@ -82,6 +82,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml
index 2723d08c90..827c2632f3 100644
--- a/recipes/configs/llama3/8B_full.yaml
+++ b/recipes/configs/llama3/8B_full.yaml
@@ -74,6 +74,11 @@ custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separatel
 # Reduced precision
 dtype: bf16

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Logging
 metric_logger:
   _component_: torchtune.training.metric_logging.DiskLogger
diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml
index ad534c62b9..75703870a4 100644
--- a/recipes/configs/llama3/8B_full_single_device.yaml
+++ b/recipes/configs/llama3/8B_full_single_device.yaml
@@ -83,6 +83,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml
index 1ecf130e1a..fa838ef4c5 100644
--- a/recipes/configs/llama3_1/70B_full.yaml
+++ b/recipes/configs/llama3_1/70B_full.yaml
@@ -84,6 +84,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml
index 357d20356d..ea8f0274a8 100644
--- a/recipes/configs/llama3_1/8B_full.yaml
+++ b/recipes/configs/llama3_1/8B_full.yaml
@@ -84,6 +84,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml
index 1429b9cc2b..d394635b61 100644
--- a/recipes/configs/llama3_1/8B_full_single_device.yaml
+++ b/recipes/configs/llama3_1/8B_full_single_device.yaml
@@ -83,6 +83,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2/1B_full.yaml b/recipes/configs/llama3_2/1B_full.yaml
index 25c7de45c1..8084eb1700 100644
--- a/recipes/configs/llama3_2/1B_full.yaml
+++ b/recipes/configs/llama3_2/1B_full.yaml
@@ -81,6 +81,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2/1B_full_single_device.yaml b/recipes/configs/llama3_2/1B_full_single_device.yaml
index e24fc56219..e07b69a50b 100644
--- a/recipes/configs/llama3_2/1B_full_single_device.yaml
+++ b/recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -79,6 +79,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2/3B_full.yaml b/recipes/configs/llama3_2/3B_full.yaml
index 0703437596..2f88b6fde3 100644
--- a/recipes/configs/llama3_2/3B_full.yaml
+++ b/recipes/configs/llama3_2/3B_full.yaml
@@ -81,6 +81,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2/3B_full_single_device.yaml b/recipes/configs/llama3_2/3B_full_single_device.yaml
index 052c524019..b9ced8c932 100644
--- a/recipes/configs/llama3_2/3B_full_single_device.yaml
+++ b/recipes/configs/llama3_2/3B_full_single_device.yaml
@@ -81,6 +81,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2_vision/11B_full.yaml b/recipes/configs/llama3_2_vision/11B_full.yaml
index 5f0e970a66..fa95a95189 100644
--- a/recipes/configs/llama3_2_vision/11B_full.yaml
+++ b/recipes/configs/llama3_2_vision/11B_full.yaml
@@ -84,6 +84,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml
index daa678d0e5..a7241fd580 100644
--- a/recipes/configs/llama3_2_vision/11B_full_single_device.yaml
+++ b/recipes/configs/llama3_2_vision/11B_full_single_device.yaml
@@ -84,6 +84,11 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false
+
 # Profiler (default is disabled)
 profiler:
   _component_: torchtune.training.setup_torch_profiler
diff --git a/recipes/configs/llama3_2_vision/90B_full.yaml b/recipes/configs/llama3_2_vision/90B_full.yaml
index 9d96b966cd..63e3790eca 100644
--- a/recipes/configs/llama3_2_vision/90B_full.yaml
+++ b/recipes/configs/llama3_2_vision/90B_full.yaml
@@ -81,6 +81,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml
index 15a6ec7b89..51b0fdaf8f 100644
--- a/recipes/configs/mistral/7B_full.yaml
+++ b/recipes/configs/mistral/7B_full.yaml
@@ -83,6 +83,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml
index 287a66dbd0..9dc9d2ab6d 100644
--- a/recipes/configs/mistral/7B_full_low_memory.yaml
+++ b/recipes/configs/mistral/7B_full_low_memory.yaml
@@ -86,6 +86,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml
index 7dc954576d..74a43733d2 100644
--- a/recipes/configs/phi3/mini_full.yaml
+++ b/recipes/configs/phi3/mini_full.yaml
@@ -78,6 +78,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml
index 8162e73c18..8b0019df93 100644
--- a/recipes/configs/phi3/mini_full_low_memory.yaml
+++ b/recipes/configs/phi3/mini_full_low_memory.yaml
@@ -79,6 +79,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml
index 093887fb59..fb49ae0e06 100644
--- a/recipes/configs/qwen2/0.5B_full.yaml
+++ b/recipes/configs/qwen2/0.5B_full.yaml
@@ -79,6 +79,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml
index 4f670695ca..36270ff8ae 100644
--- a/recipes/configs/qwen2/0.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml
@@ -79,6 +79,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml
index 04017db7ec..011ae618dd 100644
--- a/recipes/configs/qwen2/1.5B_full.yaml
+++ b/recipes/configs/qwen2/1.5B_full.yaml
@@ -79,6 +79,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml
index d529629823..b57689ddbe 100644
--- a/recipes/configs/qwen2/1.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml
@@ -84,6 +84,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml
index ec82a0d701..5d7583f6d0 100644
--- a/recipes/configs/qwen2/7B_full.yaml
+++ b/recipes/configs/qwen2/7B_full.yaml
@@ -82,6 +82,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml
index 0b01526ba4..fe6ac94b72 100644
--- a/recipes/configs/qwen2/7B_full_single_device.yaml
+++ b/recipes/configs/qwen2/7B_full_single_device.yaml
@@ -83,6 +83,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/0.5B_full.yaml b/recipes/configs/qwen2_5/0.5B_full.yaml
index c415425d5b..45262c819d 100644
--- a/recipes/configs/qwen2_5/0.5B_full.yaml
+++ b/recipes/configs/qwen2_5/0.5B_full.yaml
@@ -72,6 +72,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml
index 2ac3a79f00..f6c4f23d21 100644
--- a/recipes/configs/qwen2_5/0.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/0.5B_full_single_device.yaml
@@ -72,6 +72,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/1.5B_full.yaml b/recipes/configs/qwen2_5/1.5B_full.yaml
index 431c1b519a..5a2b1060ab 100644
--- a/recipes/configs/qwen2_5/1.5B_full.yaml
+++ b/recipes/configs/qwen2_5/1.5B_full.yaml
@@ -72,6 +72,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml
index d48176616d..f828818d05 100644
--- a/recipes/configs/qwen2_5/1.5B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/1.5B_full_single_device.yaml
@@ -75,6 +75,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: True

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/3B_full.yaml b/recipes/configs/qwen2_5/3B_full.yaml
index 217769ad8c..c545d51798 100644
--- a/recipes/configs/qwen2_5/3B_full.yaml
+++ b/recipes/configs/qwen2_5/3B_full.yaml
@@ -80,6 +80,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: False

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/3B_full_single_device.yaml b/recipes/configs/qwen2_5/3B_full_single_device.yaml
index 38b1645817..8142798e43 100644
--- a/recipes/configs/qwen2_5/3B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/3B_full_single_device.yaml
@@ -81,6 +81,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: False

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/7B_full.yaml b/recipes/configs/qwen2_5/7B_full.yaml
index d071687103..ccc16d3053 100644
--- a/recipes/configs/qwen2_5/7B_full.yaml
+++ b/recipes/configs/qwen2_5/7B_full.yaml
@@ -82,6 +82,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: False

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/configs/qwen2_5/7B_full_single_device.yaml b/recipes/configs/qwen2_5/7B_full_single_device.yaml
index e6ebbcb8cf..242a661ec2 100644
--- a/recipes/configs/qwen2_5/7B_full_single_device.yaml
+++ b/recipes/configs/qwen2_5/7B_full_single_device.yaml
@@ -83,6 +83,10 @@ metric_logger:
 log_every_n_steps: 1
 log_peak_memory_stats: False

+# mixed precision (disabled)
+mixed_precision:
+  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
+  enabled: false

 # Profiler (disabled)
 profiler:
diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 01c0607bbf..04a7c8fe84 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -31,6 +31,7 @@
     TrainingProgress,
 )
 from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

 from tqdm import tqdm


@@ -187,6 +188,17 @@ def __init__(self, cfg: DictConfig) -> None:
                 "Enabling activation offloading should reduce memory further.",
             )

+        if cfg.mixed_precision.enabled:
+            if (
+                cfg.mixed_precision._component_
+                == "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer"
+            ):
+                Int8MixedPrecisionTrainingQuantizer.validate_config(
+                    compile=cfg.compile,
+                    dataset_packed=cfg.dataset.packed,
+                    optimizer_path=cfg.optimizer._component_,
+                )
+
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
         self.seed = training.set_seed(seed=cfg.seed)
@@ -260,6 +272,7 @@ def setup(self, cfg: DictConfig) -> None:
             model_state_dict=checkpoint_dict[training.MODEL_KEY],
             ac_mode=cfg.get("ac_mode", None),
             ac_option=cfg.get("ac_option", None),
+            mixed_precision_cfg=cfg.mixed_precision,
         )

         self._tokenizer = config.instantiate(cfg.tokenizer)
@@ -481,6 +494,7 @@ def _setup_model(
         custom_sharded_layers: Optional[List[str]] = None,
         ac_mode: Optional[str] = None,
         ac_option: Optional[int] = None,
+        mixed_precision_cfg: Optional[DictConfig] = None,
     ) -> nn.Module:
         """
         Model initialization has some important considerations:
@@ -521,6 +535,13 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

+        if mixed_precision_cfg is not None and mixed_precision_cfg.enabled:
+            log.info(f"Preparing model with {mixed_precision_cfg._component_}")
+            cfg = mixed_precision_cfg.copy()
+            cfg.pop("enabled", None)
+            quantizer = config.instantiate(cfg)
+            model = quantizer.prepare(model)
+
         # For FSDP sharding
         fsdp_shard_conditions = [
             partial(
diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py
index 946e970206..dc03585e54 100644
--- a/recipes/full_finetune_single_device.py
+++ b/recipes/full_finetune_single_device.py
@@ -24,6 +24,7 @@
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.lr_schedulers import get_lr
+from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

 from tqdm import tqdm


@@ -182,6 +183,17 @@ def __init__(self, cfg: DictConfig) -> None:
                 "Enabling activation offloading should reduce memory further.",
             )

+        if cfg.mixed_precision.enabled:
+            if (
+                cfg.mixed_precision._component_
+                == "torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer"
+            ):
+                Int8MixedPrecisionTrainingQuantizer.validate_config(
+                    compile=cfg.compile,
+                    dataset_packed=cfg.dataset.packed,
+                    optimizer_path=cfg.optimizer._component_,
+                )
+
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
         self.seed = training.set_seed(seed=cfg.seed)
@@ -271,6 +283,7 @@ def setup(self, cfg: DictConfig) -> None:
             enable_activation_offloading=self._enable_activation_offloading,
             compile_model=self._compile,
             model_state_dict=ckpt_dict[training.MODEL_KEY],
+            mixed_precision_cfg=cfg.mixed_precision,
         )
         self._tokenizer = config.instantiate(cfg.tokenizer)
         log.info("Tokenizer is initialized from file.")
@@ -414,6 +427,7 @@ def _setup_model(
         enable_activation_offloading: bool,
         compile_model: bool,
         model_state_dict: Dict[str, Any],
+        mixed_precision_cfg: Optional[DictConfig] = None,
     ) -> nn.Module:
         """
         Set up the model including enabling activation checkpointing.
@@ -429,6 +443,13 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

+        if mixed_precision_cfg is not None and mixed_precision_cfg.enabled:
+            log.info(f"Preparing model with {mixed_precision_cfg._component_}")
+            cfg = mixed_precision_cfg.copy()
+            cfg.pop("enabled", None)
+            quantizer = config.instantiate(cfg)
+            model = quantizer.prepare(model)
+
         model.load_state_dict(model_state_dict)

         # Validate model was loaded in with the expected dtype.
diff --git a/tests/torchtune/training/test_quantization.py b/tests/torchtune/training/test_quantization.py
new file mode 100644
index 0000000000..2b8a5afea3
--- /dev/null
+++ b/tests/torchtune/training/test_quantization.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+from tests.test_utils import gpu_test
+from torch import nn
+from torchtune.training.quantization import (
+    _SUPPORTS_INT8_MIXED_PRECISION_TRAINING,
+    Int8MixedPrecisionTrainingQuantizer,
+)
+
+
+@gpu_test(gpu_count=1)
+@pytest.mark.skipif(
+    not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING,
+    reason="INT8 mixed-precision training is not supported",
+)
+def test_int8_mixed_precision_training_quantizer():
+    quantizer = Int8MixedPrecisionTrainingQuantizer()
+    model = nn.Sequential(
+        nn.Linear(32, 32),
+        nn.ReLU(),
+        nn.Linear(32, 32),
+    ).cuda()
+    quantizer.prepare(model)
+
+    # make sure the Linear classes were swapped
+    assert model[0].__class__ != nn.Linear
+    assert model[2].__class__ != nn.Linear
+
+    # smoke test forward and backward
+    model(torch.randn(2, 32).cuda()).sum().backward()
+    for p in model.parameters():
+        assert p.grad is not None
+
+    # the state dict contains plain tensors
+    state_dict = model.state_dict()
+    for v in state_dict.values():
+        assert v.__class__ == torch.Tensor
diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index b158d4b9a3..cc10432984 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -7,6 +7,9 @@
 from typing import Callable, Optional
 from warnings import warn

+import torch
+import torchao
+from packaging.version import Version
 from torch import nn

 from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear
@@ -48,6 +51,19 @@
 )


+_SUPPORTS_INT8_MIXED_PRECISION_TRAINING = (
+    Version(torchao.__version__) >= Version("0.7.0.dev")
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability() >= (8, 0)
+)
+
+if _SUPPORTS_INT8_MIXED_PRECISION_TRAINING:
+    from torchao.prototype.quantized_training import (
+        int8_mixed_precision_training,
+        Int8MixedPrecisionTrainingConfig,
+    )
+
+
 __all__ = [
     "get_quantizer_mode",
     "Int4WeightOnlyQuantizer",
@@ -56,6 +72,7 @@
     "Int8DynActInt4WeightQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     "Int8DynActInt4WeightQATQuantizerModuleSwap",
+    "Int8MixedPrecisionTrainingQuantizer",
 ]


@@ -160,6 +177,94 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantize
 ] = enable_8da4w_fake_quant_module_swap


+class Int8MixedPrecisionTrainingQuantizer:
+    """Apply INT8 mixed-precision training. This only affects weights of ``nn.Linear``
+    modules. During training, weights and activations are dynamically quantized to INT8
+    to utilize fast matrix multiplication with INT8 tensor cores. This is also done in
+    the backward pass.
+
+    The expected end-to-end speedup is 40% on a single A100 and 70% on a single 4090, with
+    minimal accuracy loss. If convergence is an issue, please refer to the torchao
+    documentation linked below.
+
+    For more details, as well as details about the arguments of this quantizer, please refer to
+    https://github.com/pytorch/ao/tree/main/torchao/prototype/quantized_training#int8-mixed-precision
+
+    Args:
+        output (bool): whether to apply INT8 mixed-precision for calculating output. Default: True
+        grad_input (bool): whether to apply INT8 mixed-precision for calculating grad_input. Default: True
+        grad_weight (bool): whether to apply INT8 mixed-precision for calculating grad_weight. Default: True
+
+    Raises:
+        RuntimeError: If runtime requirements for INT8 mixed-precision training are not met.
+
+    NOTE: Due to the limitations of the current implementation, the following
+    requirements must be satisfied to enjoy the expected speedup:
+
+    1. Must use ``torch.compile()`` (set ``compile=True``).
+    2. Inputs to the model must not be too dynamic. For example, when the input token
+       length changes for every batch, you won't see the expected speedup.
+
+    To satisfy (2), you can use :class:`~torchtune.datasets.PackedDataset` (set
+    ``dataset.packed=True`` and ``tokenizer.max_seq_len`` to a desired value), which
+    ensures input tokens always have a fixed length.
+    """
+
+    def __init__(
+        self,
+        output: bool = True,
+        grad_input: bool = True,
+        grad_weight: bool = True,
+    ) -> None:
+        if not _SUPPORTS_INT8_MIXED_PRECISION_TRAINING:
+            raise RuntimeError(
+                "INT8 mixed-precision training requires torch>=2.4, torchao>=0.7, and"
+                " a CUDA-capable device with compute capability >= 8.0"
+            )
+
+        self._config = Int8MixedPrecisionTrainingConfig(
+            output=output,
+            grad_input=grad_input,
+            grad_weight=grad_weight,
+        )
+
+    @staticmethod
+    def validate_config(
+        *, compile: bool, dataset_packed: bool, optimizer_path: str
+    ) -> None:
+        if not (compile and dataset_packed):
+            raise ValueError(
+                "Both compile and dataset.packed must be True to use INT8 mixed-precision training."
+            )
+
+        if not optimizer_path.startswith("torch.optim."):
+            warn(
+                "Using a low-bit optimizer might cause convergence issues with INT8 mixed-precision training. "
+                "If you observe divergence, try again with the standard torch.optim.AdamW instead."
+            )
+
+        warn(
+            "INT8 mixed-precision might not speed up training if the model and/or batch size is too small "
+            "for the current GPU(s). If that is the case, try increasing the batch size or sequence length. "
+            "On A100, Llama-3-8B only sees a speedup for batch_size=4, max_seq_len=2048 and above."
+        )
+
+    def prepare(self, model: nn.Module) -> nn.Module:
+        # we use the module-swap implementation so that the state_dict remains plain tensors,
+        # as well as for better FSDP compatibility in torchtune.
+        quantize_fn = int8_mixed_precision_training(self._config, module_swap=True)
+
+        def filter_fn(module: nn.Module, name: str) -> bool:
+            # skip the LM head since the end-to-end speedup is slightly worse.
+            # there are also possible issues with tied word embeddings.
+            return isinstance(module, nn.Linear) and module.out_features < 32_000
+
+        # don't set the inductor config, otherwise compile will be very slow
+        # (it would affect the global torch.compile() config)
+        quantize_(model, quantize_fn, filter_fn=filter_fn, set_inductor_config=False)
+        return model
+
+
 def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
     """Given a quantizer object, returns a string that specifies the type of quantization.
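
Usage sketch (not part of the patch above): the configs gate the feature behind ``mixed_precision.enabled``, and ``validate_config`` additionally requires ``compile=True`` and ``dataset.packed=True`` in a recipe run. The snippet below mirrors the smoke test added in this diff to show how the new quantizer is meant to be used on its own, assuming the patch is applied, torchao >= 0.7 is installed, and an SM80+ GPU is available; the toy model, shapes, and optimizer choice are illustrative assumptions, not values taken from the patch.

# Minimal standalone sketch of Int8MixedPrecisionTrainingQuantizer (illustrative only).
import torch
from torch import nn
from torchtune.training.quantization import Int8MixedPrecisionTrainingQuantizer

# Toy model; real recipes pass the full torchtune model here.
model = nn.Sequential(
    nn.Linear(256, 256),
    nn.ReLU(),
    nn.Linear(256, 256),
).to(device="cuda", dtype=torch.bfloat16)

# Swap nn.Linear modules for their INT8 mixed-precision training counterparts.
model = Int8MixedPrecisionTrainingQuantizer().prepare(model)

# Per the docstring, torch.compile() and fixed-shape inputs are needed for the speedup.
model = torch.compile(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for _ in range(3):
    x = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)  # same shape every step
    loss = model(x).float().sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()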