Integrate INT8 mixed-precision from torchao 0.7 #1552

Open

gau-nernst wants to merge 49 commits into pytorch:main from gau-nernst:int8mp. The diff below shows changes from 45 of the 49 commits.

Commits (49)
cf3355e  add int8mp (gau-nernst, Sep 12, 2024)
d3bbaeb  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 13, 2024)
5a61d3e  add a flag (gau-nernst, Sep 13, 2024)
560039d  create a quantizer (gau-nernst, Sep 13, 2024)
2b6e066  add notes on when speedup can be expected (gau-nernst, Sep 13, 2024)
d32f5b8  clarify doc message (gau-nernst, Sep 13, 2024)
60dad97  update docs (gau-nernst, Sep 13, 2024)
8395070  add tiny log (gau-nernst, Sep 13, 2024)
b7b8a7d  update comment (gau-nernst, Sep 13, 2024)
2829b03  add guard on torch version and CUDA sm (gau-nernst, Sep 13, 2024)
688a1c8  add integration test (gau-nernst, Sep 13, 2024)
21391ad  update test (gau-nernst, Sep 13, 2024)
f885d56  use dummy alpaca (gau-nernst, Sep 13, 2024)
7db782c  fix typo (gau-nernst, Sep 14, 2024)
8306f9a  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 14, 2024)
25a2451  convert speed test to smoke test (gau-nernst, Sep 14, 2024)
86d5f04  Merge branch 'int8mp' of github.com:gau-nernst/torchtune into int8mp (gau-nernst, Sep 14, 2024)
6094cdb  fix test (gau-nernst, Sep 14, 2024)
19a2d3e  add ao version guard (gau-nernst, Sep 14, 2024)
faec18d  fix (gau-nernst, Sep 14, 2024)
f4f1945  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 14, 2024)
8fc2826  attempt LoRA (gau-nernst, Sep 14, 2024)
911df57  fix lora (gau-nernst, Sep 15, 2024)
51bbeac  skip LoRA (gau-nernst, Sep 15, 2024)
1e5ae92  skip NF4 (gau-nernst, Sep 15, 2024)
1e4eaf6  Merge branch 'pytorch:main' into int8mp (gau-nernst, Sep 15, 2024)
30585c2  Merge branch 'main' into int8mp (felipemello1, Oct 3, 2024)
45b4365  typo (felipemello1, Oct 3, 2024)
3e5b040  Merge branch 'main' into int8mp (gau-nernst, Nov 3, 2024)
1ac836a  remove unwanted chnages (gau-nernst, Nov 3, 2024)
5d94cb3  use module swap (gau-nernst, Nov 3, 2024)
06abd88  remove unused import (gau-nernst, Nov 3, 2024)
0ff702e  update docs. change to mixed_precision (gau-nernst, Nov 3, 2024)
05563f2  add test. small fixes (gau-nernst, Nov 3, 2024)
3050c32  add config entries (gau-nernst, Nov 3, 2024)
864c6fb  remove extra compile (gau-nernst, Nov 3, 2024)
1fed859  fix lora finetune (gau-nernst, Nov 3, 2024)
66e8cdd  Merge branch 'main' into int8mp (gau-nernst, Nov 8, 2024)
207308b  Merge branch 'main' into int8mp (gau-nernst, Nov 12, 2024)
0fecc26  fix version check (gau-nernst, Nov 12, 2024)
39e1fc1  dont set inductor config (gau-nernst, Nov 12, 2024)
b2bc5ef  Merge branch 'main' into int8mp (gau-nernst, Dec 5, 2024)
a334986  remove LoRA (gau-nernst, Dec 5, 2024)
d149801  remove PyTorch version check (gau-nernst, Dec 5, 2024)
03a1978  add checks in init. add entries to all applicable configs (gau-nernst, Dec 5, 2024)
35ca06a  Merge branch 'main' into int8mp (gau-nernst, Dec 10, 2024)
0699aa3  add space (gau-nernst, Dec 10, 2024)
be9c0fb  consolidate checks (gau-nernst, Dec 10, 2024)
ca29866  Merge branch 'pytorch:main' into int8mp (gau-nernst, Dec 25, 2024)
5 changes: 5 additions & 0 deletions recipes/configs/code_llama2/7B_full_low_memory.yaml
@@ -80,6 +80,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
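
Note that this new section is inert by default: every full-finetune config gains the quantizer component with `enabled: false`, so existing runs are unaffected. To try INT8 mixed-precision training you would flip the flag, either in the YAML or, assuming torchtune's usual dotted-key config overrides (the recipe wiring itself is not shown in this hunk), from the CLI:

tune run full_finetune_single_device \
  --config llama3_2/1B_full_single_device \
  mixed_precision.enabled=True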
5 changes: 5 additions & 0 deletions recipes/configs/dev/8B_full_experimental.yaml
@@ -82,6 +82,11 @@ output_dir: /tmp/alpaca-llama3-finetune
log_every_n_steps: null
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/gemma/2B_full.yaml
@@ -76,6 +76,11 @@ output_dir: /tmp/alpaca-gemma-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/gemma/7B_full.yaml
@@ -78,6 +78,11 @@ output_dir: /tmp/alpaca-gemma-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/gemma2/27B_full.yaml
@@ -72,3 +72,8 @@ metric_logger:
output_dir: /tmp/alpaca-gemma2-27b-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false
5 changes: 5 additions & 0 deletions recipes/configs/gemma2/2B_full.yaml
@@ -74,3 +74,8 @@ metric_logger:
output_dir: /tmp/alpaca-gemma2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false
5 changes: 5 additions & 0 deletions recipes/configs/gemma2/9B_full.yaml
@@ -72,3 +72,8 @@ metric_logger:
output_dir: /tmp/alpaca-gemma2-9b-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false
5 changes: 5 additions & 0 deletions recipes/configs/llama2/13B_full.yaml
@@ -80,6 +80,11 @@ output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama2/7B_full.yaml
@@ -79,6 +79,11 @@ output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama2/7B_full_low_memory.yaml
@@ -83,6 +83,11 @@ output_dir: /tmp/alpaca-llama2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3/70B_full.yaml
@@ -110,6 +110,11 @@ output_dir: /tmp/full-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3/8B_full.yaml
@@ -72,6 +72,11 @@ custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separatel
# Reduced precision
dtype: bf16

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
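
Note the placement in this config, directly after `dtype: bf16`: the quantizer composes with bf16 training rather than replacing it, since master weights and optimizer state stay in the training dtype while selected matmuls run in INT8. An enabled variant of the section (same component, flag flipped) would read:

# mixed precision (enabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: true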
5 changes: 5 additions & 0 deletions recipes/configs/llama3/8B_full_single_device.yaml
@@ -82,6 +82,11 @@ output_dir: /tmp/full-llama3-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_1/70B_full.yaml
@@ -112,6 +112,11 @@ output_dir: /tmp/full-llama3_1-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_1/8B_full.yaml
@@ -83,6 +83,11 @@ output_dir: /tmp/full-llama3.1-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_1/8B_full_single_device.yaml
@@ -82,6 +82,11 @@ output_dir: /tmp/full-llama3.1-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2/1B_full.yaml
@@ -80,6 +80,11 @@ output_dir: /tmp/full-llama3.2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2/1B_full_single_device.yaml
@@ -79,6 +79,11 @@ output_dir: /tmp/full-llama3.2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2/3B_full.yaml
@@ -80,6 +80,11 @@ output_dir: /tmp/full-llama3.2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2/3B_full_single_device.yaml
@@ -80,6 +80,11 @@ output_dir: /tmp/full-llama3.2-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2_vision/11B_full.yaml
@@ -83,6 +83,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2_vision/11B_full_single_device.yaml
@@ -83,6 +83,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (default is disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/llama3_2_vision/90B_full.yaml
@@ -80,6 +80,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/mistral/7B_full.yaml
@@ -82,6 +82,11 @@ output_dir: /tmp/Mistral-7B-v0.1/
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/mistral/7B_full_low_memory.yaml
@@ -85,6 +85,11 @@ output_dir: /tmp/Mistral-7B-v0.1/
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/phi3/mini_full.yaml
@@ -77,6 +77,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/phi3/mini_full_low_memory.yaml
@@ -78,6 +78,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/qwen2/0.5B_full.yaml
@@ -78,6 +78,11 @@ output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/qwen2/0.5B_full_single_device.yaml
@@ -78,6 +78,11 @@ output_dir: /tmp/Qwen2-0.5B-Instruct-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/qwen2/1.5B_full.yaml
@@ -78,6 +78,11 @@ output_dir: /tmp/Qwen2-1.5B-Instruct-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/qwen2/1.5B_full_single_device.yaml
@@ -83,6 +83,11 @@ output_dir: /tmp/Qwen2-1.5B-Instruct-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/qwen2/7B_full.yaml
@@ -81,6 +81,11 @@ output_dir: /tmp/Qwen2-7B-Instruct-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/qwen2/7B_full_single_device.yaml
@@ -82,6 +82,11 @@ output_dir: /tmp/Qwen2-7B-Instruct-finetune
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
5 changes: 5 additions & 0 deletions recipes/configs/qwen2_5/0.5B_full.yaml
@@ -71,6 +71,11 @@ metric_logger:
log_every_n_steps: 1
log_peak_memory_stats: True

# mixed precision (disabled)
mixed_precision:
  _component_: torchtune.training.quantization.Int8MixedPrecisionTrainingQuantizer
  enabled: false

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
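
The hunks above only add config entries; the quantizer itself lives in torchtune.training.quantization and is not shown in this diff. As a rough sketch of what applying it amounts to, assuming torchao 0.7's prototype quantized-training API (the exact integration in this PR may differ, and apply_int8_mixed_precision is a hypothetical helper, not the PR's code):

# Illustrative sketch only, not the PR's actual implementation.
import torch
from torchao import quantize_
from torchao.prototype.quantized_training import (
    Int8MixedPrecisionTrainingConfig,
    int8_mixed_precision_training,
)

def apply_int8_mixed_precision(model: torch.nn.Module) -> torch.nn.Module:
    # Run the output, grad_input, and grad_weight matmuls in INT8;
    # master weights and optimizer state keep their original dtype.
    config = Int8MixedPrecisionTrainingConfig(
        output=True, grad_input=True, grad_weight=True
    )
    # Swap only nn.Linear modules; everything else is left untouched.
    quantize_(
        model,
        int8_mixed_precision_training(config),
        filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )
    return model

Consistent with the commit history ("add guard on torch version and CUDA sm", "add notes on when speedup can be expected"), the PR guards on torchao version and CUDA compute capability; speedups should only be expected on GPUs with INT8 tensor cores (compute capability 8.0 or newer) and with sufficiently large matmul shapes.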