diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py
index 1aad6abcb..75f521097 100644
--- a/tuning/config/acceleration_configs/acceleration_framework_config.py
+++ b/tuning/config/acceleration_configs/acceleration_framework_config.py
@@ -92,7 +92,7 @@ class AccelerationFrameworkConfig:
     fast_kernels: Annotated[
         FastKernelsConfig,
         ConfigAnnotation(
-            path="peft.quantization",
+            path="training",
             key="fused_ops_and_kernels",
             experimental=True,
             required_packages=["foak"],
@@ -127,6 +127,15 @@ def _verify_configured_dataclasses(self):
                 self.padding_free is not None
             ), "`--multipack` is currently only supported with `--padding_free`"
 
+        # Check that fused lora must be activated with either auto_gptq or bitsandbytes
+        if self.fused_lora is not None:
+            assert (
+                self.bitsandbytes is not None or self.auto_gptq is not None
+            ), "`--fused_lora` must be accompanied by a quantized base layer"\
+                " `--auto_gptq` or `--bitsandbytes`."
+
+
+
     @staticmethod
     def from_dataclasses(*dataclasses: Type):
         "Convert one or many FMS config dataclasses to a monolithic AccelerationConfig"
diff --git a/tuning/config/acceleration_configs/fused_ops_and_kernels.py b/tuning/config/acceleration_configs/fused_ops_and_kernels.py
index ded51415e..4777394c8 100644
--- a/tuning/config/acceleration_configs/fused_ops_and_kernels.py
+++ b/tuning/config/acceleration_configs/fused_ops_and_kernels.py
@@ -54,18 +54,11 @@ class FastKernelsConfig(List):
     fast_loss: bool = False
 
     # fast rms norm triton kernels
-    fast_rsm_layernorm: bool = False
+    fast_rms_layernorm: bool = False
 
     # fast RoPE embedding triton kernels
     fast_rope_embeddings: bool = False
 
-    def __post_init__(self):
-
-        if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings:
-            raise ValueError(
-                "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled "
-                "together. This restriction may be relaxed in the future."
-            )
 
 
 @dataclass
@@ -77,14 +70,6 @@ class FusedOpsAndKernelsConfig:
     # fast kernels
     fast_kernels: FastKernelsConfig = None
-
     def __post_init__(self):
-        if (self.fused_lora is not None and self.fast_kernels is None) or (
-            self.fused_lora is None and self.fast_kernels is not None
-        ):
-            raise ValueError(
-                "fused lora and fast_kernels must be used together. "
-                "This restriction may be relaxed in the future."
-            )
 
         # ensure nested dataclasses initialized
         ensure_nested_dataclasses_initialized(self)
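
For illustration, here is a minimal, self-contained sketch of how the new `_verify_configured_dataclasses` guard behaves. The `Cfg` dataclass below is a hypothetical stand-in for `AccelerationFrameworkConfig` (whose real fields are nested config dataclasses), but the assertion mirrors the check added in the first hunk: `fused_lora` is only accepted when a quantized base layer (`auto_gptq` or `bitsandbytes`) is also configured.

```python
# Minimal sketch, not the real AccelerationFrameworkConfig: `Cfg` and its
# object-typed fields are stand-ins used only to demonstrate the new guard.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Cfg:
    fused_lora: Optional[object] = None
    bitsandbytes: Optional[object] = None
    auto_gptq: Optional[object] = None

    def _verify_configured_dataclasses(self):
        # Same condition as the diff: fused lora requires a quantized base layer.
        if self.fused_lora is not None:
            assert (
                self.bitsandbytes is not None or self.auto_gptq is not None
            ), (
                "`--fused_lora` must be accompanied by a quantized base layer"
                " `--auto_gptq` or `--bitsandbytes`."
            )


# Passes: fused lora together with a bitsandbytes quantization config.
Cfg(fused_lora=object(), bitsandbytes=object())._verify_configured_dataclasses()

# Rejected: fused lora without any quantized base layer.
try:
    Cfg(fused_lora=object())._verify_configured_dataclasses()
except AssertionError as err:
    print(f"rejected: {err}")
```

Since the guard is a bare `assert`, it is skipped when Python runs with optimizations enabled (`-O`); the restrictive `__post_init__` checks removed in the second file were `ValueError`-based and unconditional, so this is a deliberate loosening rather than a direct replacement.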