feat: add liger kernel with fused cross entropy loss #93

Merged
@@ -206,7 +206,7 @@ def _check_config_and_maybe_check_values(
t = list(t.keys())[0] # otherwise take the first value

if t not in values:
- if default is None:
+ if t is not None or default is None:
raise AccelerationPluginConfigError(
f"{self.__class__.__name__}: Value at '{key}' was '{t}'. "
f"Not found in expected set '{values}'."
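To make the intent of this one-line change concrete, here is a minimal, self-contained sketch of the validation behaviour (the helper name `check_value` and the surrounding control flow are assumptions for illustration, not the plugin's actual code): a value explicitly set to something outside the allowed set still raises, but a missing value (`None`) now falls back to the default whenever one exists instead of raising.

```python
# Hypothetical sketch of the check changed above -- not the plugin's real code.
class AccelerationPluginConfigError(Exception):
    pass

def check_value(t, values, default=None):
    if t not in values:
        # old check: `if default is None` raised even when the key was simply
        # absent (t is None) and a default existed.
        # new check: raise only if the value was explicitly set to something
        # unexpected, or if there is no default to fall back to.
        if t is not None or default is None:
            raise AccelerationPluginConfigError(
                f"Value '{t}' not found in expected set '{values}'."
            )
        return default
    return t

# a missing `fast_loss` entry now falls back to the default ...
assert check_value(None, [False, True, "fused_ce_liger"], default=False) is False
# ... while an unexpected explicit value still raises
try:
    check_value("bogus", [False, True, "fused_ce_liger"], default=False)
except AccelerationPluginConfigError:
    pass
```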
3 changes: 2 additions & 1 deletion plugins/fused-ops-and-kernels/.isort.cfg
@@ -10,4 +10,5 @@ known_firstparty=
known_localfolder=tuning

# skip code imported from unsloth
- skip_glob=**/unsloth*/**
+ skip_glob=**/unsloth*/**,
+     **/liger*/**
4 changes: 3 additions & 1 deletion plugins/fused-ops-and-kernels/.pylintrc
@@ -53,7 +53,9 @@ ignore=CVS,protobufs
# format. Because '\\' represents the directory delimiter on Windows systems,
# it can't be used as an escape character.
# NOTE: do not lint code imported from unsloth
- ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth*
+ ignore-paths=.*fused_ops/unsloth_lora.*,
+     .*fused_ops/liger_ce.*,
+     .*kernels/unsloth*,

# Files or directories matching the regular expression patterns are skipped.
# The regex matches against base names, not paths. The default value ignores
15 changes: 14 additions & 1 deletion plugins/fused-ops-and-kernels/README.md
@@ -79,10 +79,23 @@ It is relatively easy by following an existing template, in what follows we use
)
```

### Running Liger Kernel Benchmarks

Using [scenarios-liger.yaml](../../scripts/benchmarks/scenarios-liger.yaml), this runs full fine tuning, LoRA PEFT, AutoGPTQ LoRA PEFT, and bits-and-bytes LoRA PEFT with the Triton kernels (fast RMS, RoPE, cross-entropy) as a baseline, and then runs again with the Liger `LigerFusedLinearCrossEntropy` kernel together with fast RMS and RoPE, so the results can be compared. It only runs against Mistral and Llama models.

The benchmarks were run separately for each `num_gpu` entry; they could be run together in a single command, but running them separately is more efficient.

```sh
tox -e run-benches -- 1 "4 8 16 32" benchmark_outputs_1 scenarios-liger.yaml none
tox -e run-benches -- 2 "8 16 32 64" benchmark_outputs_2 scenarios-liger.yaml none
tox -e run-benches -- 4 "16 32 64 128" benchmark_outputs_3 scenarios-liger.yaml none
```


## Known Issues

- MixedPrecision `--fp16` or `--bf16` should be used with `fast_lora`.
- `fast_lora` has issues with FSDP V1 when used with the `peft` style of FSDP wrapping.
  * This is because the adapter's forward functions are bypassed in the fused ops.
  * For AutoGPTQ/QLoRA this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops.
- - `fast_rope_embeddings` does not work with position_ids. Currently `position_ids` are ignored and could give wrong results.
+ - `fast_rope_embeddings` does not work with `position_ids`; it seems that HF has deprecated passing these ids into the rope embedding methods.
2 changes: 1 addition & 1 deletion plugins/fused-ops-and-kernels/configs/fast_kernels.yaml
@@ -22,4 +22,4 @@ training:
fast_rms_layernorm: True

# fast RoPE embedding triton kernels
- fast_rope_embeddings: True
+ fast_rope_embeddings: True
25 changes: 25 additions & 0 deletions plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml
@@ -0,0 +1,25 @@
training:

fused_ops_and_kernels:

# if under training stanza, then putting
# base_layer and fused_lora will be a misnomer
# - this should be in peft.quantized
# However, if it is specified, it will still
# be read. This is useful in use cases where
# the yaml is system generated and not shown
# to a user.

# activate various unsloth optimizations
# there are two versions of the plugin
# - the FastKernel version supports individual kernels
# - the FastQuantized version is all-or-nothing

# fast loss triton kernels
fast_loss: fused_ce_liger

# fast rms norm triton kernels
fast_rms_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True
@@ -0,0 +1,30 @@
# PEFT-related acceleration
peft:

# quantization-related acceleration
# e.g., kernels for quantized base weights
quantization:

fused_ops_and_kernels:

# load unsloth optimizations for these 4bit base layer weights.
# currently only support "auto_gptq" and "bitsandbytes"
base_layer: auto_gptq

# activate various unsloth optimizations
# there are two versions of the plugin
# - the FastKernel version supports individual kernels
# - the FastQuantized version is all-or-nothing


# fused kernels for lora linear layers
fused_lora: True

# fast loss triton kernels
fast_loss: fused_ce_liger

# fast rms norm triton kernels
fast_rms_layernorm: True

# fast RoPE embedding triton kernels
fast_rope_embeddings: True
8 changes: 8 additions & 0 deletions plugins/fused-ops-and-kernels/pyproject.toml
@@ -29,3 +29,11 @@ only-include = ["src/fms_acceleration_foak"]

[tool.hatch.build.targets.wheel.sources]
"src" = ""

[tool.black]
force-exclude = '''
/(
.*unsloth.*
| .*liger.*
)/
'''
@@ -23,8 +23,8 @@
import torch

# Local
- from .utils import lora_adapters_switch_ddp_from_fsdp
from .models.utils import filter_mp_rules
+ from .utils import lora_adapters_switch_ddp_from_fsdp


# consider rewriting register_foak_model_patch_rules into something
@@ -73,7 +73,10 @@ def register_foak_model_patch_rules(
# maybe this we should define envvars
FILTER_MAP = {
"fused_lora": {"qkvo", "mlp"},
"fast_loss": "cross-ent",
"fast_loss": {
True: "cross-ent",
"fused_ce_liger": "fused-lce",
},
"fast_rms_layernorm": "rms",
"fast_rope_embeddings": "rope",
}
@@ -109,19 +112,19 @@ def __init__(self, configurations: Dict[str, Dict]):
key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq"
)
self.configurations["fused_lora"] = self._check_config_and_maybe_check_values(
key="fused_lora", values=[False, True], default=True
key="fused_lora", values=[False, True], default=False
)
self.configurations["fast_loss"] = self._check_config_and_maybe_check_values(
key="fast_loss", values=[False, True], default=True
key="fast_loss", values=[False, True, "fused_ce_liger"], default=False
)
self.configurations["fast_rms_layernorm"] = (
self._check_config_and_maybe_check_values(
key="fast_rms_layernorm", values=[False, True], default=True
key="fast_rms_layernorm", values=[False, True], default=False
)
)
self.configurations["fast_rope_embeddings"] = (
self._check_config_and_maybe_check_values(
key="fast_rope_embeddings", values=[False, True], default=True
key="fast_rope_embeddings", values=[False, True], default=False
)
)

@@ -162,6 +165,8 @@ def augmentation(

if k in FILTER_MAP and k not in omitted:
ts = FILTER_MAP[k]
+ if isinstance(ts, dict) and v in ts:
+     ts = ts[v]
if isinstance(ts, str):
ts = {ts}

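Taken together, the hunks above turn `fast_loss` from a boolean into a three-valued setting (`False`, `True`, or `"fused_ce_liger"`) and teach the rule filtering in `augmentation` to map each value to a different set of patch rules. A small illustrative sketch of that resolution step follows; the function name `resolve_filter_terms` is made up for illustration, while the `FILTER_MAP` contents and the `isinstance` handling mirror the diff:

```python
# Illustrative only: how a FILTER_MAP entry is reduced to a set of filter
# terms used to select model patch rules.
FILTER_MAP = {
    "fused_lora": {"qkvo", "mlp"},
    "fast_loss": {True: "cross-ent", "fused_ce_liger": "fused-lce"},
    "fast_rms_layernorm": "rms",
    "fast_rope_embeddings": "rope",
}

def resolve_filter_terms(key, value):
    ts = FILTER_MAP[key]
    if isinstance(ts, dict) and value in ts:
        ts = ts[value]  # dict-valued entries are indexed by the configured value
    if isinstance(ts, str):
        ts = {ts}       # normalize a single term to a set
    return ts

assert resolve_filter_terms("fast_loss", True) == {"cross-ent"}
assert resolve_filter_terms("fast_loss", "fused_ce_liger") == {"fused-lce"}
assert resolve_filter_terms("fast_rms_layernorm", True) == {"rms"}
```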
@@ -0,0 +1,27 @@
# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved.
#
# BSD 2-CLAUSE LICENSE
# Copyright 2024 LinkedIn Corporation
# All Rights Reserved.
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from .fused_linear_cross_entropy_loss import lce_forward
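For context, the `lce_forward` re-exported here is, going by the PR title and module name, the forward path that applies Liger's fused linear cross-entropy. Conceptually it replaces something like the unfused reference below, which first materializes the full `[num_tokens, vocab_size]` logits tensor and then takes the loss; the Liger kernel fuses the `lm_head` projection with the cross-entropy so that tensor never has to be held in memory in full. This snippet is only an illustrative reference, not the kernel's code.

```python
import torch
import torch.nn.functional as F

def naive_linear_cross_entropy(hidden, lm_head_weight, labels):
    """Unfused reference: project hidden states to vocab logits, then CE loss.

    hidden:         [num_tokens, hidden_dim]
    lm_head_weight: [vocab_size, hidden_dim]
    labels:         [num_tokens], with -100 marking ignored positions
    """
    logits = hidden @ lm_head_weight.t()  # [num_tokens, vocab_size] -- the large tensor
    return F.cross_entropy(logits.float(), labels, ignore_index=-100)

# toy shapes, just to show the call
h = torch.randn(8, 16)
w = torch.randn(32, 16)
y = torch.randint(0, 32, (8,))
loss = naive_linear_cross_entropy(h, w, y)
```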