raise compile error (#2188)
Co-authored-by: Felipe Mello <[email protected]>
felipemello1 and Felipe Mello authored Dec 20, 2024
1 parent 46a1ef0 · commit de8b57c
Showing 6 changed files with 29 additions and 21 deletions.
recipes/configs/llama3/8B_qat_lora.yaml (4 additions, 5 deletions)
@@ -83,6 +83,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory
 
+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
 
 # Profiler (disabled)
 profiler:
@@ -108,8 +112,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
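
For context, torchtune recipes resolve this quantizer block through the config system rather than by its position in the YAML; below is a minimal sketch of how the values above would be instantiated (assuming torchtune.config.instantiate and the quantizer class path shown in the diff are available in the installed version):

from omegaconf import OmegaConf
from torchtune import config

# Mirrors the relocated `quantizer` block; values are the ones from the diff above.
quantizer_cfg = OmegaConf.create(
    {
        "_component_": "torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer",
        "groupsize": 256,
    }
)
quantizer = config.instantiate(quantizer_cfg)  # builds the QAT quantizer with groupsize=256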
recipes/configs/llama3_1/8B_qat_lora.yaml (4 additions, 5 deletions)
@@ -86,6 +86,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory
 
+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
 
 # Profiler (disabled)
 profiler:
@@ -111,8 +115,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
recipes/configs/llama3_2/1B_qat_lora.yaml (4 additions, 5 deletions)
@@ -82,6 +82,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory
 
+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
 
 # Profiler (disabled)
 profiler:
@@ -107,8 +111,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
recipes/configs/llama3_2/3B_qat_lora.yaml (4 additions, 5 deletions)
@@ -83,6 +83,10 @@ dtype: bf16
 enable_activation_checkpointing: False # True reduces memory
 enable_activation_offloading: False # True reduces memory
 
+# QAT arguments
+quantizer:
+  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
+  groupsize: 256
 
 # Profiler (disabled)
 profiler:
@@ -108,8 +112,3 @@ profiler:
   warmup_steps: 3
   active_steps: 2
   num_cycles: 1
-
-# QAT arguments
-quantizer:
-  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
-  groupsize: 256
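
The four config diffs above are the same relocation of the QAT block from the bottom of each file to just after the activation flags; because the recipe reads keys by name, the parsed values are unchanged. A quick hypothetical check (file path as in this commit):

from omegaconf import OmegaConf

cfg = OmegaConf.load("recipes/configs/llama3_2/3B_qat_lora.yaml")
assert cfg.quantizer.groupsize == 256
assert cfg.quantizer["_component_"].endswith("Int8DynActInt4WeightQATQuantizer")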
recipes/qat_distributed.py (6 additions, 0 deletions)
@@ -118,6 +118,7 @@ class QATRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
+        ValueError: If ``compile`` is set to True.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
         RuntimeError: If ``left_pad_sequence`` is set as the data collator.
         RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
@@ -133,6 +134,11 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )
 
+        if cfg.get("compile", False):
+            raise ValueError(
+                "Compile is not yet supported for QAT. Please set compile=False."
+            )
+
         # logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
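
The new guard mirrors the existing fp16 check and fails fast at recipe construction. A minimal standalone sketch of the behavior (the helper function below is hypothetical and only illustrates the pattern; it is not part of the recipe):

from omegaconf import OmegaConf


def check_compile_supported(cfg) -> None:
    # QAT recipes do not yet support torch.compile, so reject the flag up front.
    if cfg.get("compile", False):
        raise ValueError(
            "Compile is not yet supported for QAT. Please set compile=False."
        )


check_compile_supported(OmegaConf.create({"compile": False}))  # passes silently
check_compile_supported(OmegaConf.create({"compile": True}))   # raises ValueError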
recipes/qat_lora_finetune_distributed.py (7 additions, 1 deletion)
@@ -126,7 +126,8 @@ class QATLoRAFinetuneRecipeDistributed(FTRecipeInterface):
     Raises:
         ValueError: If ``dtype`` is set to fp16.
-        ValueError: If world_size is 1
+        ValueError: If world_size is 1.
+        ValueError: If ``compile`` is set to True.
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
         RuntimeError: If ``left_pad_sequence`` is set as the data collator.
         RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
@@ -149,6 +150,11 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )
 
+        if cfg.get("compile", False):
+            raise ValueError(
+                "Compile is not yet supported for QAT. Please set compile=False."
+            )
+
         _, rank = utils.get_world_size_and_rank()
 
         # _is_rank_zero is used primarily for logging. In the future, the logger
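
Note: with this guard, a QAT LoRA run that sets compile=True fails immediately at recipe construction with a clear message. Setting compile=False in the YAML, or overriding it at launch (for example tune run --nproc_per_node 2 qat_lora_finetune_distributed --config llama3/8B_qat_lora compile=False, assuming the standard torchtune CLI and key=value override syntax), avoids the error.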
