modifications to dataclasses to support fast kernels on full finetuning
achew010 committed Sep 11, 2024
1 parent b15a9c7 commit 0303497
Showing 2 changed files with 11 additions and 17 deletions.
@@ -92,7 +92,7 @@ class AccelerationFrameworkConfig:
     fast_kernels: Annotated[
         FastKernelsConfig,
         ConfigAnnotation(
-            path="peft.quantization",
+            path="training",
             key="fused_ops_and_kernels",
             experimental=True,
             required_packages=["foak"],
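
For orientation, the path/key pair in ConfigAnnotation appears to control where this block lands in the generated acceleration framework config, so the change moves fused_ops_and_kernels from the peft.quantization section to the training section. A rough sketch of that mapping; the nest helper below is illustrative only, not the real implementation:

# Illustrative only: shows how a dotted path plus a key could address a
# nested config section; after this commit the block sits under "training"
# instead of "peft.quantization".
def nest(path: str, key: str, payload: dict) -> dict:
    out = {}
    node = out
    for part in path.split("."):
        node = node.setdefault(part, {})
    node[key] = payload
    return out

print(nest("training", "fused_ops_and_kernels", {"fast_kernels": {"fast_loss": True}}))
# -> {'training': {'fused_ops_and_kernels': {'fast_kernels': {'fast_loss': True}}}}
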
@@ -127,6 +127,15 @@ def _verify_configured_dataclasses(self):
                 self.padding_free is not None
             ), "`--multipack` is currently only supported with `--padding_free`"
 
+        # Check that fused lora must be activated with either auto_gptq or bitsandbytes
+        if self.fused_lora is not None:
+            assert (
+                self.bitsandbytes is not None or self.auto_gptq is not None
+            ), "`--fused_lora` must be accompanied by a quantized base layer"\
+                " `--auto_gptq` or `--bitsandbytes`."
+
+
+
     @staticmethod
     def from_dataclasses(*dataclasses: Type):
         "Convert one or many FMS config dataclasses to a monolithic AccelerationConfig"
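The new check in _verify_configured_dataclasses rejects `--fused_lora` unless a quantized base layer (`--auto_gptq` or `--bitsandbytes`) is also configured. A minimal sketch of that behaviour; SketchConfig below is a hypothetical stand-in, not the real AccelerationFrameworkConfig, and only the attribute names match the diff:

# Hypothetical stand-in that mirrors only the assert added above.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class SketchConfig:
    fused_lora: Optional[Any] = None
    auto_gptq: Optional[Any] = None
    bitsandbytes: Optional[Any] = None

    def verify(self):
        # fused LoRA kernels need a quantized base layer to fuse into
        if self.fused_lora is not None:
            assert self.bitsandbytes is not None or self.auto_gptq is not None, (
                "`--fused_lora` must be accompanied by a quantized base layer "
                "`--auto_gptq` or `--bitsandbytes`."
            )

SketchConfig(fused_lora=object(), auto_gptq=object()).verify()  # passes
# SketchConfig(fused_lora=object()).verify()  # raises AssertionError
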
17 changes: 1 addition & 16 deletions tuning/config/acceleration_configs/fused_ops_and_kernels.py
@@ -54,18 +54,11 @@ class FastKernelsConfig(List):
     fast_loss: bool = False
 
     # fast rms norm triton kernels
-    fast_rsm_layernorm: bool = False
+    fast_rms_layernorm: bool = False
 
     # fast RoPE embedding triton kernels
     fast_rope_embeddings: bool = False
 
-    def __post_init__(self):
-
-        if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings:
-            raise ValueError(
-                "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled "
-                "together. This restriction may be relaxed in the future."
-            )
 
 
 @dataclass
@@ -77,14 +70,6 @@ class FusedOpsAndKernelsConfig:
     # fast kernels
     fast_kernels: FastKernelsConfig = None
 
     def __post_init__(self):
-        if (self.fused_lora is not None and self.fast_kernels is None) or (
-            self.fused_lora is None and self.fast_kernels is not None
-        ):
-            raise ValueError(
-                "fused lora and fast_kernels must be used together. "
-                "This restriction may be relaxed in the future."
-            )
-
         # ensure nested dataclasses initialized
         ensure_nested_dataclasses_initialized(self)

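Taken together, the two removed __post_init__ checks are what enable fast kernels for full finetuning: the kernel flags no longer have to be switched on as a block, and fast_kernels no longer has to be paired with fused_lora. A usage sketch, under the assumption that the post-commit dataclasses can be imported from this repo and constructed directly with keyword arguments:

# Sketch only: assumes the tuning package from this repository is installed.
from tuning.config.acceleration_configs.fused_ops_and_kernels import (
    FastKernelsConfig,
    FusedOpsAndKernelsConfig,
)

# The flags may now differ; the removed check required fast_loss,
# fast_rms_layernorm and fast_rope_embeddings to be enabled together.
kernels = FastKernelsConfig(
    fast_loss=True, fast_rms_layernorm=True, fast_rope_embeddings=False
)

# fast_kernels without fused_lora is now accepted, i.e. fast kernels on a
# full finetuning run with no quantized or LoRA base layer.
cfg = FusedOpsAndKernelsConfig(fast_kernels=kernels)
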