From fe14d28f622bf25644648371eacf7a79fba756a8 Mon Sep 17 00:00:00 2001 From: 1000850000 user Date: Thu, 19 Sep 2024 05:48:20 +0000 Subject: [PATCH 01/14] initial implementation of fused-linear-loss on llama Signed-off-by: 1000850000 user Signed-off-by: Anh Uong --- .../configs/fast_kernels.yaml | 5 +- .../framework_plugin_fast_kernels.py | 15 + .../kernels/liger/cross_entropy.py | 341 ++++++++++++++ .../liger/fused_linear_cross_entropy_loss.py | 417 ++++++++++++++++++ .../src/fms_acceleration_foak/models/llama.py | 10 +- ...oak-fast-kernels-sample-configuration.yaml | 5 +- 6 files changed, 789 insertions(+), 4 deletions(-) create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py diff --git a/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml b/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml index 476daa91..823af26f 100644 --- a/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml +++ b/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml @@ -22,4 +22,7 @@ training: fast_rms_layernorm: True # fast RoPE embedding triton kernels - fast_rope_embeddings: True \ No newline at end of file + fast_rope_embeddings: True + + # fused linear cross entropy loss + fused_linear_loss: False \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py index cb39d4e6..1265aec7 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py @@ -26,6 +26,12 @@ from .framework_plugin_fast_quantized_peft import lora_adapters_switch_ddp_from_fsdp +def validate_plugin_args(configurations): + # Consider making this a more graceful fallback? + assert ( + configurations["fused_linear_loss"] != configurations["fast_loss"] + ), "If using `fused_linear_loss`, `fast_loss` must be set to False" + # consider rewriting register_foak_model_patch_rules into something # like this also def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = None): @@ -68,6 +74,7 @@ def register_foak_model_patch_rules2(base_type: str, filter_endswith: Set[str] = "fast_loss": "cross-ent", "fast_rms_layernorm": "rms", "fast_rope_embeddings": "rope", + "fused_linear_loss": "fused-lce", } @@ -115,6 +122,14 @@ def __init__(self, configurations: Dict[str, Dict]): key="fast_rope_embeddings", values=[False, True], default=True ) ) + self.configurations["fast_linear_cross_entropy"] = ( + self._check_config_and_maybe_check_values( + key="fast_linear_cross_entropy", values=[False, True], default=False + ) + ) + + validate_plugin_args(self.configurations) + @property def requires_agumentation(self): diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py new file mode 100644 index 00000000..5a9a9d07 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py @@ -0,0 +1,341 @@ +# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved. +# +# BSD 2-CLAUSE LICENSE +# Copyright 2024 LinkedIn Corporation +# All Rights Reserved. 
+# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import torch +import triton +import triton.language as tl + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + ignore_index, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The string for the reduction to apply + BLOCK_SIZE (int): The block size for Triton operations. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. 
use the notation from the paper
+    d = 0.0  # d is the sum. use the notation from the paper
+    ori_X_y = tl.load(
+        X_ptr + y
+    )  # we need to store the original value of X_y for the loss calculation
+
+    # Label smoothing is a general case of normal cross entropy
+    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+    scaled_x_sum = 0.0
+    eps = label_smoothing / n_cols
+
+    for i in range(0, n_cols, BLOCK_SIZE):
+        X_offsets = i + tl.arange(0, BLOCK_SIZE)
+        X_block = tl.load(
+            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
+        )
+        block_max = tl.max(X_block)
+        if label_smoothing > 0:
+            # scale X beforehand to avoid overflow
+            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
+        m_new = tl.maximum(m, block_max)
+        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
+        m = m_new
+
+    # 4. [Online Softmax] Second pass: compute gradients
+    # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N)
+    # dx_y = (softmax(x_y) - 1) / N
+    # dx_i = softmax(x_i) / N, i != y
+    # For label smoothing:
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
+    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+    #      = dx_i - (1 - label_smoothing) / N
+    #
+    # For 'sum' reduction, no normalization is applied:
+    # dx_y = softmax(x_y) - 1
+    # dx_i = softmax(x_i), for i ≠ y
+    # For label smoothing:
+    # dx_i = (softmax(x_i) - label_smoothing / V), V = n_cols, i != y
+    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing))
+    #      = dx_i - (1 - label_smoothing)
+
+    for i in range(0, n_cols, BLOCK_SIZE):
+        X_offsets = i + tl.arange(0, BLOCK_SIZE)
+        X_block = tl.load(
+            X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
+        )
+        if reduction == "mean":
+            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
+        else:
+            X_block = tl.exp(X_block - m) / d - eps
+
+        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+
+    # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
+    # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
+    tl.debug_barrier()
+
+    # 5. Calculate the loss
+
+    # loss = log (softmax(X_y)) = log (e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
+    #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
+    # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
+    # So we can safely calculate log (softmax(X_y)) without overflow
+    loss = -(ori_X_y - m - tl.log(d))
+
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+    if label_smoothing > 0:
+        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+        loss = loss * (1 - label_smoothing) + smooth_loss
+
+    # Normalize the loss by the number of non-ignored elements if reduction is "mean"
+    if reduction == "mean":
+        loss = loss / n_non_ignore
+
+    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
+    X_y = tl.load(X_ptr + y)
+    if reduction == "mean":
+        X_y += -(1 - label_smoothing) / (n_non_ignore)
+    else:
+        X_y += -(1 - label_smoothing)
+
+    tl.store(loss_ptr, loss)
+    tl.store(X_ptr + y, X_y)
+
+
+# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
+# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
+# The optimal maximum block size depends on your hardware, your kernel, and your dtype
+MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+
+
+@triton.jit
+def element_mul_kernel(
+    X_ptr,
+    X_stride,
+    grad_output_ptr,
+    n_cols,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
+    The multiplication is performed in-place on the tensor pointed by X_ptr.
+
+    Parameters:
+    X_ptr: Pointer to the input tensor.
+    X_stride (int): The stride of the input tensor.
+    grad_output_ptr: Pointer to the gradient output value.
+    n_cols (int): The number of columns in the input tensor.
+    BLOCK_SIZE (int): The block size for Triton operations.
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) + + +def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction): + BT, V = _input.shape + n_rows = BT + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) + + n_non_ignore = (target != ignore_index).sum().item() + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory + liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + loss_ptr=loss_1d, + loss_stride=loss_1d.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction=reduction, + BLOCK_SIZE=BLOCK_SIZE, + # TODO: 32 seems to give the best performance + # Performance is quite sensitive to num_warps + num_warps=32, + ) + + loss = torch.sum(loss_1d) + return loss, _input + + +def cross_entropy_backward(_input, grad_output): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. + else: + BT, V = _input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + return _input + + +class LigerCrossEntropyFunction(torch.autograd.Function): + """ + This class implements a custom autograd function for the Liger Cross Entropy loss. + It overrides the forward and backward methods of the torch.autograd.Function class. + """ + + @staticmethod + def forward( + ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean" + ): + """ + The forward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object. + _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. + target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduction (str): The reduction to apply to the output: "none" | "mean | "sum". + + Returns: + tensor: The computed loss. 
+ """ + loss, _input = cross_entropy_forward( + _input, target, ignore_index, label_smoothing, reduction + ) + # TODO: investigation + # If we don't detach the _input tensor, the memory will double + # Not sure why but seems that there will be a time both grad and value exist but in different location + ctx.save_for_backward(_input.detach()) + return loss + + @staticmethod + def backward(ctx, grad_output): + """ + The backward pass of the Liger Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. + """ + (_input,) = ctx.saved_tensors + _input = cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py new file mode 100644 index 00000000..29cf5729 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py @@ -0,0 +1,417 @@ +# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved. +# +# BSD 2-CLAUSE LICENSE +# Copyright 2024 LinkedIn Corporation +# All Rights Reserved. +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +from typing import List, Optional, Tuple, Union +import torch +import triton +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import ( + _CONFIG_FOR_DOC, + LLAMA_INPUTS_DOCSTRING, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + +from .cross_entropy import ( + element_mul_kernel, + liger_cross_entropy_kernel, +) + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +def fused_linear_cross_entropy_forward( + _input, + weight, + target, + bias=None, + ignore_index=-100, + label_smoothing=0.0, + reduction="mean", +): + dtype = ( + torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype + ) + device = _input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. + # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = _input.shape + V = weight.shape[0] + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + chunk_size = triton.next_power_of_2( + triton.cdiv(BT, inc_factor) + ) # (BT + inc_factor - 1) // inc_factor + num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + grad_weight = ( + torch.zeros_like(weight, device=device) if weight.requires_grad else None + ) + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + + # NOTE: skip .item() here to avoid CUDA synchronization + total_n_non_ignore = (target != ignore_index).sum() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + n_non_ignore = (target_chunk != ignore_index).sum().item() + + # when doing CE, use the upcasted precision + logits_chunk = logits_chunk.float() + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + # Here we calculate the gradient of logits_chunk in place so we can save memory. 
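+        # The kernel below writes the per-token losses into loss_1d_slice and
+        # overwrites logits_chunk in place with d(loss)/d(logits) for this
+        # chunk, so no separate chunk_size x V gradient buffer is allocated.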
+ liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=logits_chunk, + X_stride=logits_chunk.stride(-2), + Y_ptr=target_chunk, + Y_stride=target_chunk.stride(-1), # always 1 + loss_ptr=loss_1d_slice, + loss_stride=loss_1d_slice.stride(-1), # always 1 + n_cols=V, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + reduction=reduction, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + # gradient of logits_chunk is computed in-place by the above triton kernel. + # Following HuggingFace model source code, we do the forward and backward + # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge. + # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194) + # Propagating to lm_head's backward, we'll switch back to the original dtype. + logits_chunk = logits_chunk.to(dtype) + + # gradient of logits_chunk is computed in-place by the above triton kernel and is of shape: chunk_size x V + # thus grad_input[start_idx: end_idx] should be of shape: chunk_size x H + # additionally, since we are chunking the inputs, observe that the loss and gradients are calculated only + # on `n_non_ignore` tokens. However, the gradient of the input should be calculated for all tokens. + # Thus, we need an additional scaling factor of (n_non_ignore/total_n_non_ignore) to scale the gradients. + + if reduction == "mean": + alpha = n_non_ignore / total_n_non_ignore if total_n_non_ignore > 0 else 0.0 + else: + alpha = 1.0 + + loss_1d[start_idx:end_idx] = loss_1d_slice * alpha + grad_logits_chunk = logits_chunk * alpha # chunk_size x V + + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=logits_chunk.t(), + mat2=_input_chunk, + out=grad_weight, + alpha=alpha, + beta=1.0, + ) + + if bias is not None: + torch.add( + input=grad_bias, + other=logits_chunk.sum(dim=0), + out=grad_bias, + alpha=alpha, + ) + + loss = torch.sum(loss_1d) + return loss, grad_input, grad_weight, grad_bias + + +def fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias +): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. 
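+        # NOTE: grad_output is the scalar upstream gradient of the loss; it is
+        # not 1.0 when, e.g., the loss is scaled before calling backward (such
+        # as dividing by gradient accumulation steps), so each saved gradient
+        # is rescaled by it below.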
+ BT, H = grad_input.shape + n_rows = BT + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + element_mul_kernel[(n_rows,)]( + grad_input, + grad_input.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_weight, + grad_weight.stride(-2), + grad_output, + H, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + element_mul_kernel[(n_rows,)]( + grad_bias, + grad_bias.stride(-1), + grad_output, + 1, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + return grad_input, grad_weight, grad_bias + +class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ignore_index=-100, + label_smoothing=0.0, + reduction="mean", + ): + """ + Fusing the last linear layer with cross-entropy loss + Reference: https://github.com/mgmalek/efficient_cross_entropy + + Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding + the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can + compute the gradient at the forward pass. By doing so, we don't have to store the _input and target + for the backward pass. + + _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension. + target: (B*T) where each value is in [0, V-1] + weight: (V, H) where V is the number of classes + bias: (V) where V is the number of classes + ignore_index: the index to ignore in the target + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
+ reduction: reduction to apply + """ + loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( + _input, weight, target, bias, ignore_index, label_smoothing, reduction + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + grad_bias.detach() if bias is not None else None, + ) + return loss + + @staticmethod + def backward(ctx, grad_output): + (grad_input, grad_weight, grad_bias) = ctx.saved_tensors + grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias + ) + return (grad_input, grad_weight, None, grad_bias, None, None, None) + +class LigerFusedLinearCrossEntropyLoss(CrossEntropyLoss): + def __init__(self, *args, **kwargs): + super(LigerFusedLinearCrossEntropyLoss, self).__init__(*args, **kwargs) + + def forward(self, lin_weight, _input, target, bias=None): + return LigerFusedLinearCrossEntropyFunction.apply( + _input, + lin_weight, + target, + bias, + self.ignore_index, + self.label_smoothing, + self.reduction, + ) + +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy + + + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + if self.training and (labels is not None): + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + + else: + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split( + self.vocab_size // self.config.pretraining_tp, dim=0 + ) + logits = [ + F.linear(hidden_states, lm_head_slices[i]) + for i in range(self.config.pretraining_tp) + ] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py index 58bb456f..4226b6a0 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -23,6 +23,7 @@ combine_triggers, ) from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, LlamaAttention, LlamaMLP, LlamaRMSNorm, @@ -34,6 +35,7 @@ from ..kernels.unsloth.rope_embedding import fast_rope_embedding from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops +from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward def get_mp_rules(base_type: str): """ @@ -42,6 +44,7 @@ def get_mp_rules(base_type: str): its forward builder argument, wrap the forward_builder function as a partial function with the base_type argument """ + return [ # TODO: have a generic version of this rule # - do regex on RMSNorm class name @@ -105,8 +108,11 @@ def get_mp_rules(base_type: str): base_type=base_type, ), ), - # TODO: have a generic 
version of this rule - # - get the module_name and reload on that + ModelPatcherRule( + rule_id="llama-fused-lce", + trigger=ModelPatcherTrigger(check=LlamaForCausalLM), + forward=lce_forward, + ), ModelPatcherRule( rule_id="llama-cross-ent", import_and_maybe_reload=( diff --git a/sample-configurations/foak-fast-kernels-sample-configuration.yaml b/sample-configurations/foak-fast-kernels-sample-configuration.yaml index 4f2e3692..369cfba8 100644 --- a/sample-configurations/foak-fast-kernels-sample-configuration.yaml +++ b/sample-configurations/foak-fast-kernels-sample-configuration.yaml @@ -22,10 +22,13 @@ plugins: # - the FastQuantized version is all-or-nothing # fast loss triton kernels - fast_loss: True + fast_loss: False # fast rms norm triton kernels fast_rsm_layernorm: True # fast RoPE embedding triton kernels fast_rope_embeddings: True + + # fused linear cross entropy loss + fused_linear_loss: True From a554ac88dc5f264abfd24e71f6de8043b985a2fc Mon Sep 17 00:00:00 2001 From: 1000850000 user Date: Fri, 20 Sep 2024 03:17:57 +0000 Subject: [PATCH 02/14] syntax fixes and remove unused code Signed-off-by: Anh Uong --- .../framework_plugin_fast_kernels.py | 4 +- .../kernels/liger/cross_entropy.py | 119 ------------------ 2 files changed, 2 insertions(+), 121 deletions(-) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py index 1265aec7..e72c91e7 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py @@ -122,9 +122,9 @@ def __init__(self, configurations: Dict[str, Dict]): key="fast_rope_embeddings", values=[False, True], default=True ) ) - self.configurations["fast_linear_cross_entropy"] = ( + self.configurations["fused_linear_loss"] = ( self._check_config_and_maybe_check_values( - key="fast_linear_cross_entropy", values=[False, True], default=False + key="fused_linear_loss", values=[False, True], default=False ) ) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py index 5a9a9d07..bbd5a05f 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py @@ -220,122 +220,3 @@ def element_mul_kernel( X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) - -def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction): - BT, V = _input.shape - n_rows = BT - - BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) - - # unreduced loss - loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) - - n_non_ignore = (target != ignore_index).sum().item() - - # ensure _input and target are contiguous in the last dimension - if _input.stride(-1) != 1: - _input = _input.contiguous() - if target.stride(-1) != 1: - target = target.contiguous() - - # Here we use a trick to store X_ptr gradient in X_ptr so we can save memory - liger_cross_entropy_kernel[(n_rows,)]( - X_ptr=_input, - X_stride=_input.stride(-2), - Y_ptr=target, - Y_stride=target.stride(-1), # always 1 - loss_ptr=loss_1d, - loss_stride=loss_1d.stride(-1), # always 1 - n_cols=V, - 
n_non_ignore=n_non_ignore,
-        ignore_index=ignore_index,
-        label_smoothing=label_smoothing,
-        reduction=reduction,
-        BLOCK_SIZE=BLOCK_SIZE,
-        # TODO: 32 seems to give the best performance
-        # Performance is quite sensitive to num_warps
-        num_warps=32,
-    )
-
-    loss = torch.sum(loss_1d)
-    return loss, _input
-
-
-def cross_entropy_backward(_input, grad_output):
-    # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
-    if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
-        pass
-
-    # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
-    # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
-    else:
-        BT, V = _input.shape
-        n_rows = BT
-        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))
-
-        element_mul_kernel[(n_rows,)](
-            _input,
-            _input.stride(-2),
-            grad_output,
-            V,
-            BLOCK_SIZE=BLOCK_SIZE,
-            num_warps=32,
-        )
-
-    return _input
-
-
-class LigerCrossEntropyFunction(torch.autograd.Function):
-    """
-    This class implements a custom autograd function for the Liger Cross Entropy loss.
-    It overrides the forward and backward methods of the torch.autograd.Function class.
-    """
-
-    @staticmethod
-    def forward(
-        ctx, _input, target, ignore_index=-100, label_smoothing=0.0, reduction="mean"
-    ):
-        """
-        The forward pass of the Liger Cross Entropy loss.
-
-        Parameters:
-        ctx : The context object.
-        _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
-        target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
-        ignore_index (int): The index to ignore in the target.
-        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
-        reduction (str): The reduction to apply to the output: "none" | "mean" | "sum".
-
-        Returns:
-        tensor: The computed loss.
-        """
-        loss, _input = cross_entropy_forward(
-            _input, target, ignore_index, label_smoothing, reduction
-        )
-        # TODO: investigation
-        # If we don't detach the _input tensor, the memory will double
-        # Not sure why but seems that there will be a time both grad and value exist but in different location
-        ctx.save_for_backward(_input.detach())
-        return loss
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        """
-        The backward pass of the Liger Cross Entropy loss.
-
-        Parameters:
-        ctx : The context object with saved tensors.
-        grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.
-
-        Returns:
-        tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
- """ - (_input,) = ctx.saved_tensors - _input = cross_entropy_backward(_input, grad_output) - return ( - _input, - None, - None, - None, - None, - ) \ No newline at end of file From 33ee02ad264cc17cfd116ae8516682829b07fa6f Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Wed, 16 Oct 2024 10:53:26 -0600 Subject: [PATCH 03/14] add new num_logits_to_keep arg for llama.forward() Signed-off-by: Anh Uong --- .../liger/fused_linear_cross_entropy_loss.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py index 29cf5729..5ab9dc9e 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py @@ -135,7 +135,7 @@ def fused_linear_cross_entropy_forward( # gradient of logits_chunk is computed in-place by the above triton kernel. # Following HuggingFace model source code, we do the forward and backward - # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) os huge. + # w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge. # (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194) # Propagating to lm_head's backward, we'll switch back to the original dtype. logits_chunk = logits_chunk.to(dtype) @@ -306,6 +306,7 @@ def lce_forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Copy paste llama forward but replace torch cross entropy with liger fused linear cross entropy @@ -317,6 +318,11 @@ def lce_forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. 
+ Returns: Example: @@ -390,9 +396,14 @@ def lce_forward( ] logits = torch.cat(logits, dim=-1) else: - logits = self.lm_head(hidden_states) - logits = logits.float() + # TODO: differing line below in granite models compared to llama/mistral model type + # logits = logits / self.config.logits_scaling + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() From 4322843e873e4a2cac51ab1943dd76328e4bedf1 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Wed, 16 Oct 2024 10:54:31 -0600 Subject: [PATCH 04/14] add mixtral model patch Signed-off-by: Anh Uong --- .../liger/fused_linear_cross_entropy_loss.py | 155 +++++++++++++++++- .../fms_acceleration_foak/models/mixtral.py | 7 + 2 files changed, 159 insertions(+), 3 deletions(-) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py index 5ab9dc9e..91bbc4cc 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py @@ -34,6 +34,14 @@ _CONFIG_FOR_DOC, LLAMA_INPUTS_DOCSTRING, ) +from transformers.models.mixtral.modeling_mixtral import ( + _CONFIG_FOR_DOC, + MIXTRAL_INPUTS_DOCSTRING, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) from transformers.utils import ( add_start_docstrings_to_model_forward, replace_return_docstrings, @@ -289,7 +297,8 @@ def forward(self, lin_weight, _input, target, bias=None): self.reduction, ) -@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models? +# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings( output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC ) @@ -328,9 +337,9 @@ def lce_forward( Example: ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> from transformers import AutoTokenizer, AutoModelForCausalLM - >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" @@ -374,6 +383,7 @@ def lce_forward( loss = None logits = None + # patch change if self.training and (labels is not None): shift_hidden_states = hidden_states[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() @@ -425,4 +435,143 @@ def lce_forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + ) + +# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward? 
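+# NOTE: lce_forward_mixtral below calls load_balancing_loss_func, which is not
+# imported above; it would need to be imported from
+# transformers.models.mixtral.modeling_mixtral alongside MIXTRAL_INPUTS_DOCSTRING.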
+@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) +@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) +# Ignore copy +def lce_forward_mixtral( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, +) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + loss = None + logits = None + + # patch change + if self.training and (labels is not None): + shift_hidden_states = hidden_states[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # flatten tokens + shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) + shift_labels = shift_labels.view(-1) + + lce = LigerFusedLinearCrossEntropyLoss() + loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) + else: + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + # TODO: unique differing part to mixtral model forward + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + # TODO: should this loss manipulation be indented in?? or should it be added to even the liger loss? 
+ if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, ) \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py index 67eada1c..fe832aea 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py @@ -23,6 +23,7 @@ combine_triggers, ) from transformers.models.mixtral.modeling_mixtral import ( + MixtralForCausalLM, MixtralAttention, MixtralRMSNorm, ) @@ -31,6 +32,7 @@ from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward_mixtral from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops @@ -93,6 +95,11 @@ def get_mp_rules(base_type): "transformers.models.mixtral.modeling_mixtral", ), ), + ModelPatcherRule( + rule_id="mixtral-fused-lce", + trigger=ModelPatcherTrigger(check=MixtralForCausalLM), + forward=lce_forward_mixtral, + ), ModelPatcherRule( rule_id="mixtral-rope", import_and_maybe_reload=( From 05cdbe6626edd4ac50704a1accba8b337d475fee Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Wed, 16 Oct 2024 10:54:58 -0600 Subject: [PATCH 05/14] add mistral and granite model patch Signed-off-by: Anh Uong --- .../src/fms_acceleration_foak/models/granite.py | 7 +++++++ .../src/fms_acceleration_foak/models/llama.py | 3 +-- .../src/fms_acceleration_foak/models/mistral.py | 8 +++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py index a2be13ab..d40b7e1f 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py @@ -27,6 +27,7 @@ from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops @@ -40,6 +41,7 @@ def get_mp_rules(base_type: str): try: # Third Party from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel + GraniteForCausalLM, GraniteAttention, GraniteMLP, GraniteRMSNorm, @@ -120,6 +122,11 @@ def get_mp_rules(base_type: str): "transformers.models.granite.modeling_granite", ), ), + ModelPatcherRule( + rule_id="granite-fused-lce", + trigger=ModelPatcherTrigger(check=GraniteForCausalLM), + forward=lce_forward, + ), # TODO: have a generic version of this rule # - get the module name # - check if "apply_rotary_pos_emb" exists diff --git 
a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py index 4226b6a0..be668119 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -33,9 +33,8 @@ from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops - from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward +from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops def get_mp_rules(base_type: str): """ diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py index 8e773a24..0ea886d8 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py @@ -23,6 +23,7 @@ combine_triggers, ) from transformers.models.mistral.modeling_mistral import ( + MistralForCausalLM, MistralAttention, MistralMLP, MistralRMSNorm, @@ -32,9 +33,9 @@ from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding +from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward from .utils import KEY_MLP, KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops - def get_mp_rules(base_type): """ Function to access all patch rules in this module. 
@@ -110,6 +111,11 @@ def get_mp_rules(base_type): "transformers.models.mistral.modeling_mistral", ), ), + ModelPatcherRule( + rule_id="mistral-fused-lce", + trigger=ModelPatcherTrigger(check=MistralForCausalLM), + forward=lce_forward, + ), ModelPatcherRule( rule_id="mistral-rope", import_and_maybe_reload=( From b89896893abb6d998c24899265990b18a2505a56 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Sat, 9 Nov 2024 22:43:40 -0700 Subject: [PATCH 06/14] add benchmark Signed-off-by: Anh Uong --- scripts/benchmarks/refs/a100_80gb_liger.csv | 121 ++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 scripts/benchmarks/refs/a100_80gb_liger.csv diff --git a/scripts/benchmarks/refs/a100_80gb_liger.csv b/scripts/benchmarks/refs/a100_80gb_liger.csv new file mode 100644 index 00000000..6aa2f10b --- /dev/null +++ b/scripts/benchmarks/refs/a100_80gb_liger.csv @@ -0,0 +1,121 @@ +bf16,epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second +True,0.07,,none,2e-5,,,16519.0,13632690688.0,6770300416.0,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.3393232345581056,51.8099,7.721,1.93,15811.649 +True,0.07,,none,2e-5,,,17401.0,11311659520.0,9063590400.0,bigcode/gpt_bigcode-santacoder,2,,2,,,bfloat16,2.199138298034668,35.7321,11.194,2.799,11463.097 +True,0.14,,none,2e-5,,,26739.0,20492466688.0,6769448448.0,bigcode/gpt_bigcode-santacoder,1,,8,,,bfloat16,2.3271564292907714,96.9916,8.248,1.031,16892.182 +True,0.14,,none,2e-5,,,20603.0,13862509056.0,9063707136.0,bigcode/gpt_bigcode-santacoder,2,,4,,,bfloat16,2.181814079284668,57.0793,14.016,1.752,14351.974 +True,0.07,,foak-fast-kernels,2e-5,,,15809.0,12021062144.0,6769251840.0,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.338859519958496,52.4698,7.623,1.906,15612.801 +True,0.07,,foak-fast-kernels,2e-5,,,16332.0,11311631872.0,9063562752.0,bigcode/gpt_bigcode-santacoder,2,,2,,,bfloat16,2.1992162322998046,35.3999,11.299,2.825,11570.652 +True,0.14,,foak-fast-kernels,2e-5,,,20597.0,17273076224.0,6769448448.0,bigcode/gpt_bigcode-santacoder,1,,8,,,bfloat16,2.327177867889404,96.0192,8.332,1.041,17063.257 +True,0.14,,foak-fast-kernels,2e-5,,,19285.0,12251984384.0,9063762432.0,bigcode/gpt_bigcode-santacoder,2,,4,,,bfloat16,2.1819879150390626,56.8156,14.081,1.76,14418.571 +True,0.07,,foak-fast-kernels-liger,2e-5,,,16521.0,13632690688.0,6770300416.0,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.338957748413086,51.9579,7.699,1.925,15766.612 +True,0.07,,foak-fast-kernels-liger,2e-5,,,17419.0,11311631872.0,9063562752.0,bigcode/gpt_bigcode-santacoder,2,,2,,,bfloat16,2.1991508483886717,35.1166,11.391,2.848,11664.015 +True,0.14,,foak-fast-kernels-liger,2e-5,,,26741.0,20492466688.0,6769448448.0,bigcode/gpt_bigcode-santacoder,1,,8,,,bfloat16,2.327241439819336,96.978,8.249,1.031,16894.556 +True,0.14,,foak-fast-kernels-liger,2e-5,,,20601.0,13863576576.0,9064765440.0,bigcode/gpt_bigcode-santacoder,2,,4,,,bfloat16,2.1819076919555664,57.0911,14.013,1.752,14349.008 +True,0.15,,none,2e-5,,,77207.0,72434853376.0,43467892224.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.8358560228347778,546.7736,0.732,0.183,2996.487 +True,0.15,,none,2e-5,,,78874.0,72434657280.0,57951176704.0,mistralai/Mistral-7B-v0.1,2,,2,,,bfloat16,0.833277006149292,311.0566,1.286,0.321,2633.604 
+True,0.29,,none,2e-5,,,79883.0,72435246592.0,43468285440.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.833172254562378,1065.355,0.751,0.094,3075.782 +True,0.29,,none,2e-5,,,78420.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8249223232269287,567.6429,1.409,0.176,2886.322 +True,0.15,,foak-fast-kernels,2e-5,,,77233.0,72432723456.0,43466827264.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.8359725856781006,487.7703,0.82,0.205,3358.958 +True,0.15,,foak-fast-kernels,2e-5,,,78896.0,72434657280.0,57951176704.0,mistralai/Mistral-7B-v0.1,2,,2,,,bfloat16,0.8332040405273438,281.0146,1.423,0.356,2915.151 +True,0.29,,foak-fast-kernels,2e-5,,,71197.0,72433116672.0,43467220480.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.8336040306091309,946.2985,0.845,0.106,3462.755 +True,0.29,,foak-fast-kernels,2e-5,,,76683.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8249501895904541,508.9631,1.572,0.196,3219.094 +True,0.15,,foak-fast-kernels-liger,2e-5,,,71447.0,72432723456.0,43466827264.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.8359153127670288,487.509,0.82,0.205,3360.758 +True,0.15,,foak-fast-kernels-liger,2e-5,,,75779.0,72434657280.0,57951176704.0,mistralai/Mistral-7B-v0.1,2,,2,,,bfloat16,0.8328942394256592,281.5454,1.421,0.355,2909.655 +True,0.29,,foak-fast-kernels-liger,2e-5,,,75991.0,72433116672.0,43467220480.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.8338063526153564,946.5099,0.845,0.106,3461.982 +True,0.29,,foak-fast-kernels-liger,2e-5,,,79411.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8249048280715943,507.7695,1.576,0.197,3226.661 +True,,,none,2e-5,,,81177.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,bfloat16,,,,, +True,,,none,2e-5,,,79126.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,bfloat16,,,,, +True,,,none,2e-5,,,80729.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,bfloat16,,,,, +True,,,none,2e-5,,,80182.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,79128.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,79185.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-5,,,80127.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-5,,,79185.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,bfloat16,,,,, +True,,,none,2e-5,,,80873.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,bfloat16,,,,, +True,,,none,2e-5,,,80296.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,bfloat16,,,,, +True,,,none,2e-5,,,80873.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,bfloat16,,,,, +True,,,none,2e-5,,,80296.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,bfloat16,,,,, 
+True,,,foak-fast-kernels-liger,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,bfloat16,,,,, +True,0.15,,none,2e-4,16,0.1,29931.0,25681144320.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.857630443572998,485.2882,0.824,0.206,3376.138 +True,0.15,,none,2e-4,16,0.1,18457.0,14975803392.0,7368046592.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8570447063446045,281.9272,1.419,0.355,2905.715 +True,0.29,,none,2e-4,16,0.1,43971.0,36670876160.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,0.8569988822937011,961.8276,0.832,0.104,3406.848 +True,0.29,,none,2e-4,16,0.1,26155.0,21621940224.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8571900749206542,503.0123,1.59,0.199,3257.177 +True,0.15,,foak-fast-kernels,2e-4,16,0.1,28673.0,23530188288.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8572746562957764,426.2826,0.938,0.235,3843.459 +True,0.15,,foak-fast-kernels,2e-4,16,0.1,18123.0,14774476800.0,7368046592.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8571001052856445,252.0162,1.587,0.397,3250.585 +True,0.29,,foak-fast-kernels,2e-4,16,0.1,41433.0,32393276928.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,0.8570582962036133,842.5631,0.949,0.119,3889.086 +True,0.29,,foak-fast-kernels,2e-4,16,0.1,25005.0,21219287040.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8568509960174561,445.2057,1.797,0.225,3680.097 +True,0.15,,foak-fast-kernels-liger,2e-4,16,0.1,24995.0,23530188288.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8571897888183594,425.7566,0.94,0.235,3848.208 +True,0.15,,foak-fast-kernels-liger,2e-4,16,0.1,18495.0,14774476800.0,7368046592.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8576602077484131,260.1896,1.537,0.384,3148.473 +True,0.29,,foak-fast-kernels-liger,2e-4,16,0.1,34083.0,32393276928.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,0.8570835971832276,843.3957,0.949,0.119,3885.246 +True,0.29,,foak-fast-kernels-liger,2e-4,16,0.1,25551.0,21219287040.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8565159416198731,444.1643,1.801,0.225,3688.725 +True,,,none,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.15,,none,2e-4,16,0.1,62756.0,57925768704.0,47365978112.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.891134901046753,529.9427,0.755,0.189,1545.827 +True,,,none,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.29,,none,2e-4,16,0.1,70146.0,65050678784.0,47366174720.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8917711734771728,880.3987,0.909,0.114,1860.975 +True,,,foak-fast-kernels,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj 
o_proj,bfloat16,,,,, +True,0.15,,foak-fast-kernels,2e-4,16,0.1,62813.0,57699328000.0,47365978112.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8909786415100097,501.8188,0.797,0.199,1632.462 +True,,,foak-fast-kernels,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.29,,foak-fast-kernels,2e-4,16,0.1,69736.0,64608276992.0,47366174720.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8914951801300048,828.2126,0.966,0.121,1978.236 +True,,,foak-fast-kernels-liger,2e-4,16,0.1,81001.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.15,,foak-fast-kernels-liger,2e-4,16,0.1,62588.0,57699239424.0,47365978112.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8928797817230225,503.4879,0.794,0.199,1627.05 +True,,,foak-fast-kernels-liger,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.29,,foak-fast-kernels-liger,2e-4,16,0.1,70752.0,64592406528.0,47366174720.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8918977546691894,827.1362,0.967,0.121,1980.81 +True,,,none,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,none,2e-4,16,0.1,80929.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,none,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,none,2e-4,16,0.1,80423.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels,2e-4,16,0.1,80956.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels,2e-4,16,0.1,81116.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-4,16,0.1,81076.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-4,16,0.1,81116.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, +,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,19945.0,15353458176.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0182268142700195,485.6646,0.824,0.206,3373.521 +,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,14776.0,9542673920.0,2261220352.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9975294589996337,289.7083,1.381,0.345,2827.672 +,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,33739.0,26343190016.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.97410005569458,955.3479,0.837,0.105,3429.955 
+,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,22294.0,16188810752.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9795886325836182,509.3915,1.571,0.196,3216.387 +,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,19715.0,13095119872.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0186691761016846,417.6284,0.958,0.239,3923.104 +,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,14563.0,9326863872.0,2261220352.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0090518665313721,223.4905,1.79,0.447,3665.48 +,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,33499.0,21853776896.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9725016212463379,818.2182,0.978,0.122,4004.8 +,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,21486.0,15703516672.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9953651046752929,422.34,1.894,0.237,3879.339 +,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,15397.0,13065335808.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0356037425994873,416.0081,0.962,0.24,3938.385 +,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,14574.0,9326863872.0,2261220352.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.023795919418335,223.8253,1.787,0.447,3659.997 +,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,24869.0,21792109568.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9770747470855713,817.6695,0.978,0.122,4007.487 +,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,21779.0,15703516672.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.012446279525757,421.8896,1.896,0.237,3883.48 +,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,37599.0,35528093184.0,24511457792.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9050130844116211,838.4913,0.477,0.119,1953.986 +,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,25997.0,21070198272.0,12581256192.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9056115531921387,510.6631,0.783,0.196,1604.189 +,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,50101.0,46517825024.0,24511851008.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9009766864776612,1599.6588,0.5,0.063,2048.437 +,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,33101.0,28182882304.0,12581452800.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9007492160797119,874.2594,0.915,0.114,1874.043 +,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,37887.0,34183875584.0,24511457792.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9058037376403809,769.7911,0.52,0.13,2128.37 +,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,26189.0,20783975424.0,12581256192.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9081956386566162,438.3654,0.912,0.228,1868.761 +,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,51057.0,43775222784.0,24511851008.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9022124576568603,1463.3011,0.547,0.068,2239.32 
+,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,32711.0,27564131840.0,12581452800.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9040882682800293,782.8028,1.022,0.128,2092.992 +,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,42377.0,34324272128.0,24511457792.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9190836048126221,767.9232,0.521,0.13,2133.547 +,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,27540.0,20785280512.0,12581256192.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9187229442596435,435.3256,0.919,0.23,1881.81 +,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,73227.0,44132304896.0,24511851008.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9090401840209961,1463.3652,0.547,0.068,2239.222 +,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,34619.0,27563463680.0,12581452800.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9175021457672119,784.8748,1.019,0.127,2087.467 +,0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,71685.0,67069752832.0,36122373120.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9918170833587646,3617.6879,0.111,0.028,452.886 +,0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,53040.0,45637770240.0,18219970048.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9916643810272217,1935.4609,0.207,0.052,423.258 +,,True,accelerated-peft-autogptq,2e-4,16,0.1,81055.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,True,accelerated-peft-autogptq,2e-4,16,0.1,80982.0,,,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +,0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,71605.0,65992275456.0,36122373120.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9917643451690674,3300.024,0.121,0.03,496.481 +,0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,53438.0,45360356352.0,18219970048.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9924971103668213,1696.7048,0.236,0.059,482.818 +,,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,80447.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,80976.0,,,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, +,0.14,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,70231.0,65992275456.0,36122373120.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9904376316070557,3297.5962,0.121,0.03,496.847 +,0.14,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,53315.0,45360356352.0,18219970048.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9912145042419434,1697.5248,0.236,0.059,482.585 +,,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,80447.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, +,0.28,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,81102.0,70763420672.0,18220166656.0,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9919774532318115,3328.1658,0.24,0.03,492.283 From 94549272571d73b25792e16b0d5d048dc66e08f5 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Thu, 7 Nov 2024 15:37:28 -0700 Subject: [PATCH 07/14] add new liger benchmarks Signed-off-by: Anh 
Uong --- .../configs/fast_kernels_liger.yaml | 28 ++++++ .../configs/fast_quantized_peft_liger.yaml | 33 +++++++ sample-configurations/CONTENTS.yaml | 17 ++++ ...ogptq-foak-liger-sample-configuration.yaml | 55 ++++++++++++ ...oak-padding-free-sample-configuration.yaml | 5 +- ...ft-autogptq-foak-sample-configuration.yaml | 5 +- ...b-nf4-foak-liger-sample-configuration.yaml | 50 +++++++++++ ...oak-padding-free-sample-configuration.yaml | 5 +- ...eft-bnb-nf4-foak-sample-configuration.yaml | 5 +- ...st-kernels-liger-sample-configuration.yaml | 33 +++++++ ...oak-fast-kernels-sample-configuration.yaml | 11 ++- scripts/benchmarks/scenarios-liger.yaml | 90 +++++++++++++++++++ scripts/generate_sample_configurations.py | 17 +++- 13 files changed, 343 insertions(+), 11 deletions(-) create mode 100644 plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml create mode 100644 plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml create mode 100644 sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml create mode 100644 sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml create mode 100644 sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml create mode 100644 scripts/benchmarks/scenarios-liger.yaml diff --git a/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml b/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml new file mode 100644 index 00000000..8011db78 --- /dev/null +++ b/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml @@ -0,0 +1,28 @@ +training: + + fused_ops_and_kernels: + + # if under training stanza, then putting + # base_layer and fused_lora will be a misnomer + # - this should be in peft.quantized + # However, if it is specified, it will still + # be read. This is useful in use cases where + # the yaml is system generated and not shown + # to a user. + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + # fast loss triton kernels + fast_loss: False + + # fast rms norm triton kernels + fast_rms_layernorm: True + + # fast RoPE embedding triton kernels + fast_rope_embeddings: True + + # fused linear cross entropy loss + fused_linear_loss: True \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml new file mode 100644 index 00000000..7f239849 --- /dev/null +++ b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml @@ -0,0 +1,33 @@ +# PEFT-related acceleration +peft: + + # quantization-related acceleration + # e.g., kernels for quantized base weights + quantization: + + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights.
+ # currently only support "auto_gptq" and "bitsandbytes" + base_layer: auto_gptq + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + + # fused kernels for lora linear layers + fused_lora: True + + # fast loss triton kernels + fast_loss: False + + # fast rms norm triton kernels + fast_rms_layernorm: True + + # fast RoPE embedding triton kernels + fast_rope_embeddings: True + + # fused linear cross entropy loss + fused_linear_loss: True \ No newline at end of file diff --git a/sample-configurations/CONTENTS.yaml b/sample-configurations/CONTENTS.yaml index 6781b3bd..33bc66f0 100644 --- a/sample-configurations/CONTENTS.yaml +++ b/sample-configurations/CONTENTS.yaml @@ -27,12 +27,24 @@ framework_configs: - fused-ops-and-kernels filename: accelerated-peft-autogptq-foak-sample-configuration.yaml + - shortname: accelerated-peft-autogptq-foak-liger + plugins: + - accelerated-peft + - fused-ops-and-kernels + filename: accelerated-peft-autogptq-foak-liger-sample-configuration.yaml + - shortname: accelerated-peft-bnb-foak plugins: - accelerated-peft - fused-ops-and-kernels filename: accelerated-peft-bnb-nf4-foak-sample-configuration.yaml + - shortname: accelerated-peft-bnb-foak-liger + plugins: + - accelerated-peft + - fused-ops-and-kernels + filename: accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml + - shortname: aadp-padding-free plugins: - attention-and-distributed-packing @@ -73,3 +85,8 @@ framework_configs: plugins: - fused-ops-and-kernels filename: foak-fast-kernels-sample-configuration.yaml + + - shortname: foak-fast-kernels-liger + plugins: + - fused-ops-and-kernels + filename: foak-fast-kernels-liger-sample-configuration.yaml \ No newline at end of file diff --git a/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml new file mode 100644 index 00000000..1abc5a11 --- /dev/null +++ b/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml @@ -0,0 +1,55 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # PEFT-related acceleration + peft: + + # quantization-related acceleration + # e.g., kernels for quantized base weights + quantization: + + # AutoGPTQ quantized base weights. + auto_gptq: + + # Kernel to be used for GPTQ linear layer + # NOTE: Not all kernels are suitable for PEFT training; need to use + # kernels that support autograd forward / backward. The best + # recommendation at the moment is "triton_v2". + kernel: triton_v2 + + # If true, then will already expect quantized checkpoint + # passed into TrainingArguments.model_name_or_path + from_quantized: true + + # Setting to false will create GPTQ-LORA using the local autogptq package. + # If true, will create legacy implementation of GPTQ-LORA using external + # `auto_gptq`. Refer to README for installation instructions + use_external_lib: false + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights.
+ # currently only support "auto_gptq" and "bitsandbytes" + base_layer: auto_gptq + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + + # fused kernels for lora linear layers + fused_lora: true + + # fast loss triton kernels + fast_loss: false + + # fast rms norm triton kernels + fast_rms_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true + + # fused linear cross entropy loss + fused_linear_loss: true diff --git a/sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml index a331154e..5639842d 100644 --- a/sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-autogptq-foak-padding-free-sample-configuration.yaml @@ -43,7 +43,10 @@ plugins: base_layer: auto_gptq # activate various unsloth optimizations - # NOTE: currently supports only all-or-nothing. + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + # fused kernels for lora linear layers fused_lora: true diff --git a/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml index 3ca13131..78c07cc9 100644 --- a/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-autogptq-foak-sample-configuration.yaml @@ -34,7 +34,10 @@ plugins: base_layer: auto_gptq # activate various unsloth optimizations - # NOTE: currently supports only all-or-nothing. + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + # fused kernels for lora linear layers fused_lora: true diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml new file mode 100644 index 00000000..4376182e --- /dev/null +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml @@ -0,0 +1,50 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + # PEFT-related acceleration + peft: + + # quantization-related acceleration + # e.g., kernels for quantized base weights + quantization: + + # For loading BitsAndBytes quantized layers + # to serve as 4bit base-weights for LoRA PEFT-tuning. + # NOTE: currently AutoGPTQ is not properly integrated into huggingface / + # bitsandbytes, thus recommended quant_type to be either "nf4" + # or "fp4". + # bitsandbytes: + bitsandbytes: + quant_type: nf4 + + # If True, then no get_peft_model and prepare_model_for_kbit_training + # will be called. + no_peft_model: false + fused_ops_and_kernels: + + # load unsloth optimizations for these 4bit base layer weights.
+ # currently only support "auto_gptq" and "bitsandbytes" + base_layer: bitsandbytes + + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + + # fused kernels for lora linear layers + fused_lora: true + + # fast loss triton kernels + fast_loss: false + + # fast rms norm triton kernels + fast_rms_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true + + # fused linear cross entropy loss + fused_linear_loss: true diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml index 32d077ae..b5752a3b 100644 --- a/sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-padding-free-sample-configuration.yaml @@ -38,7 +38,10 @@ plugins: base_layer: bitsandbytes # activate various unsloth optimizations - # NOTE: currently supports only all-or-nothing. + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + # fused kernels for lora linear layers fused_lora: true diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml index f3f8741a..75fd3037 100644 --- a/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-sample-configuration.yaml @@ -29,7 +29,10 @@ plugins: base_layer: bitsandbytes # activate various unsloth optimizations - # NOTE: currently supports only all-or-nothing. + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + # fused kernels for lora linear layers fused_lora: true diff --git a/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml b/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml new file mode 100644 index 00000000..7002026a --- /dev/null +++ b/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml @@ -0,0 +1,33 @@ +# FMS Acceleration Plugin Configuration. +# +# Each stanza incorporates various configurations for +# different fine-tuning / training tasks. +plugins: + training: + + fused_ops_and_kernels: + + # if under training stanza, then putting + # base_layer and fused_lora will be a misnomer + # - this should be in peft.quantized + # However, if it is specified, it will still + # be read. This is useful in use cases where + # the yaml is system generated and not shown + # to a user.
+ + # activate various unsloth optimizations + # there are two versions of the plugin + # - the FastKernel version supports individual kernels + # - the FastQuantized version is all-or-nothing + + # fast loss triton kernels + fast_loss: false + + # fast rms norm triton kernels + fast_rms_layernorm: true + + # fast RoPE embedding triton kernels + fast_rope_embeddings: true + + # fused linear cross entropy loss + fused_linear_loss: true diff --git a/sample-configurations/foak-fast-kernels-sample-configuration.yaml b/sample-configurations/foak-fast-kernels-sample-configuration.yaml index 369cfba8..ba7669aa 100644 --- a/sample-configurations/foak-fast-kernels-sample-configuration.yaml +++ b/sample-configurations/foak-fast-kernels-sample-configuration.yaml @@ -3,10 +3,9 @@ # Each stanza incorporates various configurations for # different fine-tuning / training tasks. plugins: - # Configurations to accelerate data packing/padding in training training: - fused_ops_and_kernels: + fused_ops_and_kernels: # if under training stanza, then putting # base_layer and fused_lora will be a misnomer # - this should be in peft.quantized # However, if it is specified, it will still # be read. This is useful in use cases where # the yaml is system generated and not shown # to a user. # activate various unsloth optimizations # there are two versions of the plugin # - the FastKernel version supports individual kernels # - the FastQuantized version is all-or-nothing # fast loss triton kernels - fast_loss: False + fast_loss: true # fast rms norm triton kernels - fast_rsm_layernorm: True + fast_rms_layernorm: true # fast RoPE embedding triton kernels - fast_rope_embeddings: True + fast_rope_embeddings: true # fused linear cross entropy loss - fused_linear_loss: True + fused_linear_loss: false diff --git a/scripts/benchmarks/scenarios-liger.yaml b/scripts/benchmarks/scenarios-liger.yaml new file mode 100644 index 00000000..81a212fa --- /dev/null +++ b/scripts/benchmarks/scenarios-liger.yaml @@ -0,0 +1,90 @@ +# This file holds a list of scenarios that may be run. +# - to limit to a number of scenarios, use the --run-only-scenarios flag. +# - Each scenario will be run against a particular acceleration framework +# config, if the framework_config: key is specified. +# * a particular framework configuration +# - the arguments tag will hold arguments to be passed to sft_trainer +# * the arguments are singular except for model_name_or_path which can handle +# multiple arguments. +# - So anything that is critical for the scenario MUST be specified here +# and not in the defaults, e.g. fp16 + +# This stanza will be used in future to replace the custom processing functions in data_processing.py +# data_processing: +# dataset_name: yahma/alpaca-cleaned +# chat_template: | +# {%- for message in messages %} +# {% if message['input'] != '' %} +# Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. + +# {% else %} +# Below is an instruction that describes a task. Write a response that appropriately completes the request.
+ +# {% endif %} +# ### Instruction: +# {{ message['instruction'] }} + +# {% if message['input'] != '' %} +# ### Input: +# {{ message['input'] }} + +# {% endif %} +# ### Response: +# {{ message['output'] + eos_token }} +# {% endfor %} +# tokenize: True + + +scenarios: + - name: full-finetuning + framework_config: + - + - foak-fast-kernels + - foak-fast-kernels-liger + arguments: + learning_rate: 2e-5 + model_name_or_path: + - 'bigcode/gpt_bigcode-santacoder' + - 'mistralai/Mistral-7B-v0.1' + - 'mistralai/Mixtral-8x7B-Instruct-v0.1' + - 'NousResearch/Llama-2-70b-hf' + torch_dtype: bfloat16 + bf16: True + + - name: standard-peft + framework_config: + - + - foak-fast-kernels + - foak-fast-kernels-liger + arguments: + bf16: True + learning_rate: 2e-4 + torch_dtype: bfloat16 + peft_method: lora + r: 16 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + model_name_or_path: + - 'mistralai/Mistral-7B-v0.1' + - 'mistralai/Mixtral-8x7B-Instruct-v0.1' + - 'NousResearch/Llama-2-70b-hf' + + - name: accelerated-peft-gptq + framework_config: + - accelerated-peft-autogptq + - accelerated-peft-autogptq-foak + - accelerated-peft-autogptq-foak-liger + arguments: + learning_rate: 2e-4 + fp16: True # running gptq-lora in float16 is more performant, see issue + torch_dtype: float16 # https://github.com/foundation-model-stack/fms-acceleration/issues/84 + peft_method: lora + r: 16 + lora_alpha: 16 + lora_dropout: 0.1 + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + model_name_or_path: + - 'TheBloke/Mistral-7B-v0.1-GPTQ' + - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ' + - 'TheBloke/Llama-2-70B-GPTQ' diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py index 11619106..86c53dfa 100644 --- a/scripts/generate_sample_configurations.py +++ b/scripts/generate_sample_configurations.py @@ -144,10 +144,13 @@ def read_configuration(path: str) -> Dict: KEY_BNB_NF4 = "bnb-nf4" KEY_BNB_NF4_BASELINE = "baseline-bnb-nf4" KEY_AUTO_GPTQ_FOAK = "auto-gptq-foak" +KEY_AUTO_GPTQ_FOAK_LIGER = "auto-gptq-foak-liger" KEY_BNB_NF4_FOAK = "bnb-nf4-foak" +KEY_BNB_NF4_FOAK_LIGER = "bnb-nf4-foak-liger" KEY_AADP_PADDING_FREE = "aadp-padding-free" KEY_AADP_MULTIPACK = "aadp-multipack" KEY_FAST_KERNELS = "foak-fast-kernels" +KEY_FAST_KERNELS_LIGER = "foak-fast-kernels-liger" CONFIGURATIONS = { KEY_AUTO_GPTQ: "plugins/accelerated-peft/configs/autogptq.yaml", @@ -166,13 +169,22 @@ def read_configuration(path: str) -> Dict: "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", [("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")], ), + KEY_AUTO_GPTQ_FOAK_LIGER: ( + "plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml", + [("peft.quantization.fused_ops_and_kernels.base_layer", "auto_gptq")], + ), KEY_BNB_NF4_FOAK: ( "plugins/fused-ops-and-kernels/configs/fast_quantized_peft.yaml", [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")], ), + KEY_BNB_NF4_FOAK_LIGER: ( + "plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml", + [("peft.quantization.fused_ops_and_kernels.base_layer", "bitsandbytes")], + ), KEY_AADP_PADDING_FREE: "plugins/attention-and-distributed-packing/configs/padding_free.yaml", KEY_AADP_MULTIPACK: "plugins/attention-and-distributed-packing/configs/multipack.yaml", KEY_FAST_KERNELS: "plugins/fused-ops-and-kernels/configs/fast_kernels.yaml", + KEY_FAST_KERNELS_LIGER: "plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml", } # list of (tag, combi) 
tuples @@ -186,13 +198,16 @@ def read_configuration(path: str) -> Dict: ("baseline-peft-bnb-nf4", (KEY_BNB_NF4_BASELINE,)), ("accelerated-peft-autogptq-foak", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ("accelerated-peft-bnb-nf4-foak", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), + ("accelerated-peft-autogptq-foak-liger", (KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK_LIGER)), + ("accelerated-peft-bnb-nf4-foak-liger", (KEY_BNB_NF4, KEY_BNB_NF4_FOAK_LIGER)), ("aadp-padding-free", (KEY_AADP_PADDING_FREE,)), ("accelerated-peft-autogptq-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ)), ("accelerated-peft-bnb-nf4-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4)), ("accelerated-peft-autogptq-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_AUTO_GPTQ, KEY_AUTO_GPTQ_FOAK)), ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)), ("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)), - ("foak-fast-kernels", (KEY_FAST_KERNELS)) + ("foak-fast-kernels", (KEY_FAST_KERNELS,)), + ("foak-fast-kernels-liger", (KEY_FAST_KERNELS_LIGER,)), ] From 2c202ef409fee602c0e3c0af92c6b3c6841b7c31 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 14 Nov 2024 08:51:39 +0000 Subject: [PATCH 08/14] some fixes Signed-off-by: Yu Chin Fabian Lim --- .../liger/fused_linear_cross_entropy_loss.py | 160 ------------------ .../fms_acceleration_foak/models/mixtral.py | 7 - scripts/benchmarks/benchmark.py | 2 +- scripts/benchmarks/scenarios-liger.yaml | 37 ++-- 4 files changed, 25 insertions(+), 181 deletions(-) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py index 91bbc4cc..edc655f6 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py @@ -30,22 +30,6 @@ import torch.nn.functional as F from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama.modeling_llama import ( - _CONFIG_FOR_DOC, - LLAMA_INPUTS_DOCSTRING, -) -from transformers.models.mixtral.modeling_mixtral import ( - _CONFIG_FOR_DOC, - MIXTRAL_INPUTS_DOCSTRING, -) -from transformers.modeling_outputs import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, -) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from .cross_entropy import ( element_mul_kernel, @@ -297,11 +281,6 @@ def forward(self, lin_weight, _input, target, bias=None): self.reduction, ) -# TODO: how to add diff docstrings for diff model types? what if the loss functions aren't the same across models? -# @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def lce_forward( self, input_ids: torch.LongTensor = None, @@ -435,143 +414,4 @@ def lce_forward( past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - ) - -# TODO: is adding a separate copy of lce_forward() the right path or should the additional logic for Moe models be in the single lce_forward? 
-@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) -@replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) -# Ignore copy -def lce_forward_mixtral( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, -) -> Union[Tuple, MoeCausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MixtralForCausalLM - - >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits - ) - - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - - loss = None - logits = None - - # patch change - if self.training and (labels is not None): - shift_hidden_states = hidden_states[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # flatten tokens - shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size) - shift_labels = shift_labels.view(-1) - - lce = LigerFusedLinearCrossEntropyLoss() - loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels) - else: - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - # TODO: unique differing part to mixtral model forward - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits if return_dict else outputs[-1], - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - # TODO: should this loss manipulation be indented in?? or should it be added to even the liger loss? 
- if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - - if not return_dict: - output = (logits,) + outputs[1:] - if output_router_logits: - output = (aux_loss,) + output - return (loss,) + output if loss is not None else output - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, ) \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py index fe832aea..67eada1c 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mixtral.py @@ -23,7 +23,6 @@ combine_triggers, ) from transformers.models.mixtral.modeling_mixtral import ( - MixtralForCausalLM, MixtralAttention, MixtralRMSNorm, ) @@ -32,7 +31,6 @@ from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward_mixtral from .utils import KEY_O, KEY_QKV, build_lora_fused_ops, trigger_fused_ops @@ -95,11 +93,6 @@ def get_mp_rules(base_type): "transformers.models.mixtral.modeling_mixtral", ), ), - ModelPatcherRule( - rule_id="mixtral-fused-lce", - trigger=ModelPatcherTrigger(check=MixtralForCausalLM), - forward=lce_forward_mixtral, - ), ModelPatcherRule( rule_id="mixtral-rope", import_and_maybe_reload=( diff --git a/scripts/benchmarks/benchmark.py b/scripts/benchmarks/benchmark.py index 38fe6679..df314868 100644 --- a/scripts/benchmarks/benchmark.py +++ b/scripts/benchmarks/benchmark.py @@ -723,7 +723,7 @@ def prepare_arguments(args, benchmark_dataset: BenchmarkDataset): if ( not args.run_only_scenarios - and scenarios.slow + and scenario.slow ): # unfiltered runs omit all "slow" marked scenarios print(f"Skipping slow scenario '{_scn_name}' beacuse run_only_scenarios=None.") diff --git a/scripts/benchmarks/scenarios-liger.yaml b/scripts/benchmarks/scenarios-liger.yaml index 81a212fa..cdd026d2 100644 --- a/scripts/benchmarks/scenarios-liger.yaml +++ b/scripts/benchmarks/scenarios-liger.yaml @@ -38,22 +38,18 @@ scenarios: - name: full-finetuning framework_config: - - - foak-fast-kernels - foak-fast-kernels-liger arguments: learning_rate: 2e-5 model_name_or_path: - - 'bigcode/gpt_bigcode-santacoder' - - 'mistralai/Mistral-7B-v0.1' - - 'mistralai/Mixtral-8x7B-Instruct-v0.1' - - 'NousResearch/Llama-2-70b-hf' + # - 'mistralai/Mistral-7B-v0.1' + - 'meta-llama/Meta-Llama-3-8B' torch_dtype: bfloat16 bf16: True - name: standard-peft framework_config: - - - foak-fast-kernels - foak-fast-kernels-liger arguments: @@ -66,13 +62,29 @@ scenarios: lora_dropout: 0.1 target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] model_name_or_path: - - 'mistralai/Mistral-7B-v0.1' - - 'mistralai/Mixtral-8x7B-Instruct-v0.1' - - 'NousResearch/Llama-2-70b-hf' + # - 'mistralai/Mistral-7B-v0.1' + - 'meta-llama/Meta-Llama-3-8B' + + - name: accelerated-peft-bnb + framework_config: + - accelerated-peft-bnb-foak + - accelerated-peft-bnb-foak-liger + arguments: + bf16: True + learning_rate: 2e-4 + torch_dtype: bfloat16 + peft_method: lora + r: 16 + lora_alpha: 16 + lora_dropout: 0.1 + 
per_device_train_batch_size: + target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] + model_name_or_path: + # - 'mistralai/Mistral-7B-v0.1' + - 'meta-llama/Meta-Llama-3-8B' - name: accelerated-peft-gptq framework_config: - - accelerated-peft-autogptq - accelerated-peft-autogptq-foak - accelerated-peft-autogptq-foak-liger arguments: @@ -85,6 +97,5 @@ scenarios: lora_dropout: 0.1 target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] model_name_or_path: - - 'TheBloke/Mistral-7B-v0.1-GPTQ' - - 'TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ' - - 'TheBloke/Llama-2-70B-GPTQ' + # - 'TheBloke/Mistral-7B-v0.1-GPTQ' + - 'TechxGenus/Meta-Llama-3-8B-GPTQ' From 12ebdb907b39f18993783e8fece6b513c43a6975 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Fri, 15 Nov 2024 15:41:51 +0000 Subject: [PATCH 09/14] revise benches Signed-off-by: Yu Chin Fabian Lim --- scripts/benchmarks/refs/a100_80gb_liger.csv | 216 ++++++++---------- .../benchmarks/refs/requirements_liger.txt | 87 +++++++ 2 files changed, 183 insertions(+), 120 deletions(-) create mode 100644 scripts/benchmarks/refs/requirements_liger.txt diff --git a/scripts/benchmarks/refs/a100_80gb_liger.csv b/scripts/benchmarks/refs/a100_80gb_liger.csv index 6aa2f10b..871df7fb 100644 --- a/scripts/benchmarks/refs/a100_80gb_liger.csv +++ b/scripts/benchmarks/refs/a100_80gb_liger.csv @@ -1,121 +1,97 @@ bf16,epoch,fp16,framework_config,learning_rate,lora_alpha,lora_dropout,mem_nvidia_mem_reserved,mem_peak_torch_mem_alloc_in_bytes,mem_torch_mem_alloc_in_bytes,model_name_or_path,num_gpus,peft_method,per_device_train_batch_size,r,target_modules,torch_dtype,train_loss,train_runtime,train_samples_per_second,train_steps_per_second,train_tokens_per_second -True,0.07,,none,2e-5,,,16519.0,13632690688.0,6770300416.0,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.3393232345581056,51.8099,7.721,1.93,15811.649 -True,0.07,,none,2e-5,,,17401.0,11311659520.0,9063590400.0,bigcode/gpt_bigcode-santacoder,2,,2,,,bfloat16,2.199138298034668,35.7321,11.194,2.799,11463.097 -True,0.14,,none,2e-5,,,26739.0,20492466688.0,6769448448.0,bigcode/gpt_bigcode-santacoder,1,,8,,,bfloat16,2.3271564292907714,96.9916,8.248,1.031,16892.182 -True,0.14,,none,2e-5,,,20603.0,13862509056.0,9063707136.0,bigcode/gpt_bigcode-santacoder,2,,4,,,bfloat16,2.181814079284668,57.0793,14.016,1.752,14351.974 -True,0.07,,foak-fast-kernels,2e-5,,,15809.0,12021062144.0,6769251840.0,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.338859519958496,52.4698,7.623,1.906,15612.801 -True,0.07,,foak-fast-kernels,2e-5,,,16332.0,11311631872.0,9063562752.0,bigcode/gpt_bigcode-santacoder,2,,2,,,bfloat16,2.1992162322998046,35.3999,11.299,2.825,11570.652 -True,0.14,,foak-fast-kernels,2e-5,,,20597.0,17273076224.0,6769448448.0,bigcode/gpt_bigcode-santacoder,1,,8,,,bfloat16,2.327177867889404,96.0192,8.332,1.041,17063.257 -True,0.14,,foak-fast-kernels,2e-5,,,19285.0,12251984384.0,9063762432.0,bigcode/gpt_bigcode-santacoder,2,,4,,,bfloat16,2.1819879150390626,56.8156,14.081,1.76,14418.571 -True,0.07,,foak-fast-kernels-liger,2e-5,,,16521.0,13632690688.0,6770300416.0,bigcode/gpt_bigcode-santacoder,1,,4,,,bfloat16,2.338957748413086,51.9579,7.699,1.925,15766.612 -True,0.07,,foak-fast-kernels-liger,2e-5,,,17419.0,11311631872.0,9063562752.0,bigcode/gpt_bigcode-santacoder,2,,2,,,bfloat16,2.1991508483886717,35.1166,11.391,2.848,11664.015 -True,0.14,,foak-fast-kernels-liger,2e-5,,,26741.0,20492466688.0,6769448448.0,bigcode/gpt_bigcode-santacoder,1,,8,,,bfloat16,2.327241439819336,96.978,8.249,1.031,16894.556 
-True,0.14,,foak-fast-kernels-liger,2e-5,,,20601.0,13863576576.0,9064765440.0,bigcode/gpt_bigcode-santacoder,2,,4,,,bfloat16,2.1819076919555664,57.0911,14.013,1.752,14349.008 -True,0.15,,none,2e-5,,,77207.0,72434853376.0,43467892224.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.8358560228347778,546.7736,0.732,0.183,2996.487 -True,0.15,,none,2e-5,,,78874.0,72434657280.0,57951176704.0,mistralai/Mistral-7B-v0.1,2,,2,,,bfloat16,0.833277006149292,311.0566,1.286,0.321,2633.604 -True,0.29,,none,2e-5,,,79883.0,72435246592.0,43468285440.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.833172254562378,1065.355,0.751,0.094,3075.782 -True,0.29,,none,2e-5,,,78420.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8249223232269287,567.6429,1.409,0.176,2886.322 -True,0.15,,foak-fast-kernels,2e-5,,,77233.0,72432723456.0,43466827264.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.8359725856781006,487.7703,0.82,0.205,3358.958 -True,0.15,,foak-fast-kernels,2e-5,,,78896.0,72434657280.0,57951176704.0,mistralai/Mistral-7B-v0.1,2,,2,,,bfloat16,0.8332040405273438,281.0146,1.423,0.356,2915.151 -True,0.29,,foak-fast-kernels,2e-5,,,71197.0,72433116672.0,43467220480.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.8336040306091309,946.2985,0.845,0.106,3462.755 -True,0.29,,foak-fast-kernels,2e-5,,,76683.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8249501895904541,508.9631,1.572,0.196,3219.094 -True,0.15,,foak-fast-kernels-liger,2e-5,,,71447.0,72432723456.0,43466827264.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.8359153127670288,487.509,0.82,0.205,3360.758 -True,0.15,,foak-fast-kernels-liger,2e-5,,,75779.0,72434657280.0,57951176704.0,mistralai/Mistral-7B-v0.1,2,,2,,,bfloat16,0.8328942394256592,281.5454,1.421,0.355,2909.655 -True,0.29,,foak-fast-kernels-liger,2e-5,,,75991.0,72433116672.0,43467220480.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.8338063526153564,946.5099,0.845,0.106,3461.982 -True,0.29,,foak-fast-kernels-liger,2e-5,,,79411.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8249048280715943,507.7695,1.576,0.197,3226.661 -True,,,none,2e-5,,,81177.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,bfloat16,,,,, -True,,,none,2e-5,,,79126.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,bfloat16,,,,, -True,,,none,2e-5,,,80729.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,bfloat16,,,,, -True,,,none,2e-5,,,80182.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,79128.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,79185.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,80127.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,2,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,81179.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,,8,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,79185.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,2,,4,,,bfloat16,,,,, -True,,,none,2e-5,,,80873.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,bfloat16,,,,, -True,,,none,2e-5,,,80296.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,bfloat16,,,,, -True,,,none,2e-5,,,80873.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,bfloat16,,,,, 
-True,,,none,2e-5,,,80296.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,bfloat16,,,,, -True,,,foak-fast-kernels,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,4,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,2,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,80875.0,,,NousResearch/Llama-2-70b-hf,1,,8,,,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-5,,,80298.0,,,NousResearch/Llama-2-70b-hf,2,,4,,,bfloat16,,,,, -True,0.15,,none,2e-4,16,0.1,29931.0,25681144320.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.857630443572998,485.2882,0.824,0.206,3376.138 -True,0.15,,none,2e-4,16,0.1,18457.0,14975803392.0,7368046592.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8570447063446045,281.9272,1.419,0.355,2905.715 -True,0.29,,none,2e-4,16,0.1,43971.0,36670876160.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,0.8569988822937011,961.8276,0.832,0.104,3406.848 -True,0.29,,none,2e-4,16,0.1,26155.0,21621940224.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8571900749206542,503.0123,1.59,0.199,3257.177 -True,0.15,,foak-fast-kernels,2e-4,16,0.1,28673.0,23530188288.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8572746562957764,426.2826,0.938,0.235,3843.459 -True,0.15,,foak-fast-kernels,2e-4,16,0.1,18123.0,14774476800.0,7368046592.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8571001052856445,252.0162,1.587,0.397,3250.585 -True,0.29,,foak-fast-kernels,2e-4,16,0.1,41433.0,32393276928.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,0.8570582962036133,842.5631,0.949,0.119,3889.086 -True,0.29,,foak-fast-kernels,2e-4,16,0.1,25005.0,21219287040.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8568509960174561,445.2057,1.797,0.225,3680.097 -True,0.15,,foak-fast-kernels-liger,2e-4,16,0.1,24995.0,23530188288.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8571897888183594,425.7566,0.94,0.235,3848.208 -True,0.15,,foak-fast-kernels-liger,2e-4,16,0.1,18495.0,14774476800.0,7368046592.0,mistralai/Mistral-7B-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8576602077484131,260.1896,1.537,0.384,3148.473 -True,0.29,,foak-fast-kernels-liger,2e-4,16,0.1,34083.0,32393276928.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,0.8570835971832276,843.3957,0.949,0.119,3885.246 -True,0.29,,foak-fast-kernels-liger,2e-4,16,0.1,25551.0,21219287040.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8565159416198731,444.1643,1.801,0.225,3688.725 -True,,,none,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,0.15,,none,2e-4,16,0.1,62756.0,57925768704.0,47365978112.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.891134901046753,529.9427,0.755,0.189,1545.827 
-True,,,none,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,0.29,,none,2e-4,16,0.1,70146.0,65050678784.0,47366174720.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8917711734771728,880.3987,0.909,0.114,1860.975 -True,,,foak-fast-kernels,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,0.15,,foak-fast-kernels,2e-4,16,0.1,62813.0,57699328000.0,47365978112.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8909786415100097,501.8188,0.797,0.199,1632.462 -True,,,foak-fast-kernels,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,0.29,,foak-fast-kernels,2e-4,16,0.1,69736.0,64608276992.0,47366174720.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8914951801300048,828.2126,0.966,0.121,1978.236 -True,,,foak-fast-kernels-liger,2e-4,16,0.1,81001.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,0.15,,foak-fast-kernels-liger,2e-4,16,0.1,62588.0,57699239424.0,47365978112.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,0.8928797817230225,503.4879,0.794,0.199,1627.05 -True,,,foak-fast-kernels-liger,2e-4,16,0.1,81225.0,,,mistralai/Mixtral-8x7B-Instruct-v0.1,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,0.29,,foak-fast-kernels-liger,2e-4,16,0.1,70752.0,64592406528.0,47366174720.0,mistralai/Mixtral-8x7B-Instruct-v0.1,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,0.8918977546691894,827.1362,0.967,0.121,1980.81 -True,,,none,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,none,2e-4,16,0.1,80929.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,none,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,none,2e-4,16,0.1,80423.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels,2e-4,16,0.1,80956.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels,2e-4,16,0.1,81116.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-4,16,0.1,81076.0,,,NousResearch/Llama-2-70b-hf,2,lora,2,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-4,16,0.1,81029.0,,,NousResearch/Llama-2-70b-hf,1,lora,8,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -True,,,foak-fast-kernels-liger,2e-4,16,0.1,81116.0,,,NousResearch/Llama-2-70b-hf,2,lora,4,16,q_proj k_proj v_proj o_proj,bfloat16,,,,, -,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,19945.0,15353458176.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0182268142700195,485.6646,0.824,0.206,3373.521 
-,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,14776.0,9542673920.0,2261220352.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9975294589996337,289.7083,1.381,0.345,2827.672 -,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,33739.0,26343190016.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.97410005569458,955.3479,0.837,0.105,3429.955 -,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,22294.0,16188810752.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9795886325836182,509.3915,1.571,0.196,3216.387 -,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,19715.0,13095119872.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0186691761016846,417.6284,0.958,0.239,3923.104 -,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,14563.0,9326863872.0,2261220352.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.0090518665313721,223.4905,1.79,0.447,3665.48 -,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,33499.0,21853776896.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9725016212463379,818.2182,0.978,0.122,4004.8 -,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,21486.0,15703516672.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9953651046752929,422.34,1.894,0.237,3879.339 -,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,15397.0,13065335808.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.0356037425994873,416.0081,0.962,0.24,3938.385 -,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,14574.0,9326863872.0,2261220352.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,1.023795919418335,223.8253,1.787,0.447,3659.997 -,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,24869.0,21792109568.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9770747470855713,817.6695,0.978,0.122,4007.487 -,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,21779.0,15703516672.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,1.012446279525757,421.8896,1.896,0.237,3883.48 -,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,37599.0,35528093184.0,24511457792.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9050130844116211,838.4913,0.477,0.119,1953.986 -,0.15,True,accelerated-peft-autogptq,2e-4,16,0.1,25997.0,21070198272.0,12581256192.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9056115531921387,510.6631,0.783,0.196,1604.189 -,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,50101.0,46517825024.0,24511851008.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9009766864776612,1599.6588,0.5,0.063,2048.437 -,0.29,True,accelerated-peft-autogptq,2e-4,16,0.1,33101.0,28182882304.0,12581452800.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9007492160797119,874.2594,0.915,0.114,1874.043 -,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,37887.0,34183875584.0,24511457792.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9058037376403809,769.7911,0.52,0.13,2128.37 
-,0.15,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,26189.0,20783975424.0,12581256192.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9081956386566162,438.3654,0.912,0.228,1868.761 -,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,51057.0,43775222784.0,24511851008.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9022124576568603,1463.3011,0.547,0.068,2239.32 -,0.29,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,32711.0,27564131840.0,12581452800.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9040882682800293,782.8028,1.022,0.128,2092.992 -,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,42377.0,34324272128.0,24511457792.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9190836048126221,767.9232,0.521,0.13,2133.547 -,0.15,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,27540.0,20785280512.0,12581256192.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9187229442596435,435.3256,0.919,0.23,1881.81 -,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,73227.0,44132304896.0,24511851008.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,0.9090401840209961,1463.3652,0.547,0.068,2239.222 -,0.29,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,34619.0,27563463680.0,12581452800.0,TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9175021457672119,784.8748,1.019,0.127,2087.467 -,0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,71685.0,67069752832.0,36122373120.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9918170833587646,3617.6879,0.111,0.028,452.886 -,0.14,True,accelerated-peft-autogptq,2e-4,16,0.1,53040.0,45637770240.0,18219970048.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9916643810272217,1935.4609,0.207,0.052,423.258 -,,True,accelerated-peft-autogptq,2e-4,16,0.1,81055.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,,True,accelerated-peft-autogptq,2e-4,16,0.1,80982.0,,,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -,0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,71605.0,65992275456.0,36122373120.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9917643451690674,3300.024,0.121,0.03,496.481 -,0.14,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,53438.0,45360356352.0,18219970048.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9924971103668213,1696.7048,0.236,0.059,482.818 -,,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,80447.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,,True,accelerated-peft-autogptq-foak,2e-4,16,0.1,80976.0,,,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,,,,, -,0.14,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,70231.0,65992275456.0,36122373120.0,TheBloke/Llama-2-70B-GPTQ,1,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9904376316070557,3297.5962,0.121,0.03,496.847 -,0.14,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,53315.0,45360356352.0,18219970048.0,TheBloke/Llama-2-70B-GPTQ,2,lora,2,16,q_proj k_proj v_proj o_proj,float16,0.9912145042419434,1697.5248,0.236,0.059,482.585 
-,,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,80447.0,,,TheBloke/Llama-2-70B-GPTQ,1,lora,8,16,q_proj k_proj v_proj o_proj,float16,,,,, -,0.28,True,accelerated-peft-autogptq-foak-liger,2e-4,16,0.1,81102.0,70763420672.0,18220166656.0,TheBloke/Llama-2-70B-GPTQ,2,lora,4,16,q_proj k_proj v_proj o_proj,float16,0.9919774532318115,3328.1658,0.24,0.03,492.283 +True,0.17,,foak-fast-kernels,2e-05,,,77695.0,80318097408.0,48198051840.0,meta-llama/Meta-Llama-3-8B,1,,4,,,bfloat16,0.991794786453247,496.5063,0.806,0.201,3299.857 +True,,,foak-fast-kernels,2e-05,,,77437.0,,,meta-llama/Meta-Llama-3-8B,1,,8,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-05,,,52497.0,,,meta-llama/Meta-Llama-3-8B,1,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-05,,,68225.0,,,meta-llama/Meta-Llama-3-8B,1,,32,,,bfloat16,,,,, +True,0.17,,foak-fast-kernels-liger,2e-05,,,77537.0,80318097408.0,48198051840.0,meta-llama/Meta-Llama-3-8B,1,,4,,,bfloat16,0.9917966079711914,498.2461,0.803,0.201,3288.335 +True,0.34,,foak-fast-kernels-liger,2e-05,,,79881.0,80318490624.0,48198445056.0,meta-llama/Meta-Llama-3-8B,1,,8,,,bfloat16,0.9880468559265136,961.3168,0.832,0.104,3408.658 +True,,,foak-fast-kernels-liger,2e-05,,,79925.0,,,meta-llama/Meta-Llama-3-8B,1,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-05,,,78987.0,,,meta-llama/Meta-Llama-3-8B,1,,32,,,bfloat16,,,,, +True,0.17,,foak-fast-kernels,0.0002,16.0,0.1,50711.0,37644015616.0,16241584128.0,meta-llama/Meta-Llama-3-8B,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0148053932189942,427.8517,0.935,0.234,3829.365 +True,0.34,,foak-fast-kernels,0.0002,16.0,0.1,58301.0,59017447424.0,16241977344.0,meta-llama/Meta-Llama-3-8B,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0162366390228272,849.3577,0.942,0.118,3857.974 +True,,,foak-fast-kernels,0.0002,16.0,0.1,57695.0,,,meta-llama/Meta-Llama-3-8B,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels,0.0002,16.0,0.1,67261.0,,,meta-llama/Meta-Llama-3-8B,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.17,,foak-fast-kernels-liger,0.0002,16.0,0.1,25465.0,24905936896.0,16241584128.0,meta-llama/Meta-Llama-3-8B,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0148304843902587,426.2687,0.938,0.235,3843.585 +True,0.34,,foak-fast-kernels-liger,0.0002,16.0,0.1,34681.0,33567698944.0,16241977344.0,meta-llama/Meta-Llama-3-8B,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.01575608253479,844.8273,0.947,0.118,3878.662 +True,0.68,,foak-fast-kernels-liger,0.0002,16.0,0.1,53115.0,50891223040.0,16242763776.0,meta-llama/Meta-Llama-3-8B,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.014188413619995,1684.0095,0.95,0.059,3891.665 +True,,,foak-fast-kernels-liger,0.0002,16.0,0.1,79051.0,,,meta-llama/Meta-Llama-3-8B,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.17,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,41221.0,27259280384.0,5884111872.0,meta-llama/Meta-Llama-3-8B,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0356837940216064,410.0877,0.975,0.244,3995.243 +True,0.34,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,75995.0,48632712192.0,5884505088.0,meta-llama/Meta-Llama-3-8B,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0364081382751464,803.4691,0.996,0.124,4078.315 +True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,57877.0,,,meta-llama/Meta-Llama-3-8B,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,77591.0,,,meta-llama/Meta-Llama-3-8B,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, 
+True,0.17,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,15527.0,14411297792.0,5884111872.0,meta-llama/Meta-Llama-3-8B,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0357236671447754,408.6648,0.979,0.245,4009.154 +True,0.34,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,24359.0,22936744960.0,5884505088.0,meta-llama/Meta-Llama-3-8B,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0366782855987549,798.2473,1.002,0.125,4104.994 +True,0.68,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,42409.0,39987639296.0,5885291520.0,meta-llama/Meta-Llama-3-8B,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0348288249969482,1582.9171,1.011,0.063,4140.204 +True,1.35,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,78505.0,74089427968.0,5886864384.0,meta-llama/Meta-Llama-3-8B,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0354842090606688,3150.7966,1.016,0.032,4159.964 +,0.17,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,40823.0,27289066496.0,5913897984.0,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.0433240985870362,429.8265,0.931,0.233,3811.771 +,0.34,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,75341.0,48662498304.0,5914291200.0,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.0526717376708985,840.1997,0.952,0.119,3900.025 +,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,73875.0,,,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,77553.0,,,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,0.17,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,15449.0,14441084416.0,5913897984.0,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.0416254806518557,424.7732,0.942,0.235,3857.117 +,0.34,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,24601.0,22966531584.0,5914291200.0,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.05357590675354,835.1457,0.958,0.12,3923.627 +,0.68,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,42905.0,40017425920.0,5915077632.0,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,1.0603761863708496,1657.3815,0.965,0.06,3954.189 +,1.35,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,78747.0,74119214592.0,5916650496.0,TechxGenus/Meta-Llama-3-8B-GPTQ,1,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,1.0691538333892825,3291.3639,0.972,0.03,3982.301 +True,0.34,,foak-fast-kernels,2e-05,,,81143.0,80320219648.0,64259672576.0,meta-llama/Meta-Llama-3-8B,2,,4,,,bfloat16,0.9677656173706056,690.8416,1.158,0.145,2371.6 +True,,,foak-fast-kernels,2e-05,,,68775.0,,,meta-llama/Meta-Llama-3-8B,2,,8,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-05,,,74233.0,,,meta-llama/Meta-Llama-3-8B,2,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-05,,,80729.0,,,meta-llama/Meta-Llama-3-8B,2,,32,,,bfloat16,,,,, +True,0.34,,foak-fast-kernels-liger,2e-05,,,80809.0,80320219648.0,64259672576.0,meta-llama/Meta-Llama-3-8B,2,,4,,,bfloat16,0.9677164840698242,623.5916,1.283,0.16,2627.361 +True,0.68,,foak-fast-kernels-liger,2e-05,,,81135.0,80320612864.0,64260065792.0,meta-llama/Meta-Llama-3-8B,2,,8,,,bfloat16,0.9606661891937256,1128.4084,1.418,0.089,2903.913 +True,,,foak-fast-kernels-liger,2e-05,,,79908.0,,,meta-llama/Meta-Llama-3-8B,2,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-05,,,80729.0,,,meta-llama/Meta-Llama-3-8B,2,,32,,,bfloat16,,,,, 
+True,0.34,,foak-fast-kernels,0.0002,16.0,0.1,47323.0,31606073856.0,8156781056.0,meta-llama/Meta-Llama-3-8B,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0190155601501465,455.9939,1.754,0.219,3593.031 +True,0.68,,foak-fast-kernels,0.0002,16.0,0.1,60685.0,52979505664.0,8157174272.0,meta-llama/Meta-Llama-3-8B,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0157581710815429,894.78,1.788,0.112,3662.129 +True,,,foak-fast-kernels,0.0002,16.0,0.1,72428.0,,,meta-llama/Meta-Llama-3-8B,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels,0.0002,16.0,0.1,66893.0,,,meta-llama/Meta-Llama-3-8B,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.34,,foak-fast-kernels-liger,0.0002,16.0,0.1,26877.0,23383704064.0,8156781056.0,meta-llama/Meta-Llama-3-8B,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0187576484680176,453.693,1.763,0.22,3611.252 +True,0.68,,foak-fast-kernels-liger,0.0002,16.0,0.1,41607.0,36071997952.0,8157174272.0,meta-llama/Meta-Llama-3-8B,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0161932468414308,888.8348,1.8,0.113,3686.624 +True,1.35,,foak-fast-kernels-liger,0.0002,16.0,0.1,70589.0,61448585728.0,8157960704.0,meta-llama/Meta-Llama-3-8B,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0166991901397706,1754.3381,1.824,0.057,3735.654 +True,,,foak-fast-kernels-liger,0.0002,16.0,0.1,78800.0,,,meta-llama/Meta-Llama-3-8B,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.34,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,43054.0,26509831680.0,3033275904.0,meta-llama/Meta-Llama-3-8B,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.042497215270996,422.108,1.895,0.237,3881.471 +True,0.68,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,78267.0,47883263488.0,3033669120.0,meta-llama/Meta-Llama-3-8B,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0393894767761231,815.2901,1.962,0.123,4019.183 +True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,67822.0,,,meta-llama/Meta-Llama-3-8B,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,78756.0,,,meta-llama/Meta-Llama-3-8B,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.34,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,22164.0,17850271232.0,3033275904.0,meta-llama/Meta-Llama-3-8B,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0428221797943116,420.5608,1.902,0.238,3895.75 +True,0.68,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,35872.0,30402250240.0,3033669120.0,meta-llama/Meta-Llama-3-8B,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0388935470581055,811.0503,1.973,0.123,4040.193 +True,1.35,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,63878.0,55506208256.0,3034455552.0,meta-llama/Meta-Llama-3-8B,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0399011611938476,1601.8839,1.998,0.062,4091.183 +True,,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,79215.0,,,meta-llama/Meta-Llama-3-8B,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +,0.34,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,43748.0,26526510592.0,3049954816.0,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.072213077545166,438.1959,1.826,0.228,3738.967 +,0.68,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,79025.0,47899942400.0,3050348032.0,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.0689481258392337,852.041,1.878,0.117,3845.824 
+,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,68244.0,,,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,78983.0,,,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,0.34,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,22853.0,17867802624.0,3049954816.0,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.0653237342834472,435.5815,1.837,0.23,3761.409 +,0.68,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,36549.0,30419781632.0,3050348032.0,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.0656037616729737,847.3836,1.888,0.118,3866.962 +,1.35,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,64583.0,55523739648.0,3051134464.0,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,1.0683570766448975,1677.4704,1.908,0.06,3906.835 +,,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,79665.0,,,TechxGenus/Meta-Llama-3-8B-GPTQ,2,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +True,0.68,,foak-fast-kernels,2e-05,,,66986.0,47584840192.0,32138562048.0,meta-llama/Meta-Llama-3-8B,4,,4,,,bfloat16,0.9608215522766114,511.7436,3.127,0.195,3201.603 +True,1.35,,foak-fast-kernels,2e-05,,,81081.5,68958272000.0,32138955264.0,meta-llama/Meta-Llama-3-8B,4,,8,,,bfloat16,0.9529002285003664,1447.107,2.211,0.069,2264.38 +True,,,foak-fast-kernels,2e-05,,,52173.0,,,meta-llama/Meta-Llama-3-8B,4,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-05,,,71790.0,,,meta-llama/Meta-Llama-3-8B,4,,32,,,bfloat16,,,,, +True,0.68,,foak-fast-kernels-liger,2e-05,,,54707.5,40168839680.0,32138562048.0,meta-llama/Meta-Llama-3-8B,4,,4,,,bfloat16,0.9608566761016846,512.2228,3.124,0.195,3198.608 +True,1.35,,foak-fast-kernels-liger,2e-05,,,63613.0,44061352448.0,32138955264.0,meta-llama/Meta-Llama-3-8B,4,,8,,,bfloat16,0.9528849792480468,980.3948,3.264,0.102,3342.327 +True,2.7,,foak-fast-kernels-liger,2e-05,,,79252.5,60306944512.0,32139741696.0,meta-llama/Meta-Llama-3-8B,4,,16,,,bfloat16,0.9459449291229248,1914.3344,3.343,0.052,3423.435 +True,,,foak-fast-kernels-liger,2e-05,,,77731.0,,,meta-llama/Meta-Llama-3-8B,4,,32,,,bfloat16,,,,, +True,0.68,,foak-fast-kernels,0.0002,16.0,0.1,43643.0,27551078912.0,4088154624.0,meta-llama/Meta-Llama-3-8B,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0161590099334716,456.0277,3.509,0.219,3592.764 +True,1.35,,foak-fast-kernels,0.0002,16.0,0.1,78865.0,48924510720.0,4088547840.0,meta-llama/Meta-Llama-3-8B,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0162527561187744,885.3585,3.614,0.113,3701.1 +True,,,foak-fast-kernels,0.0002,16.0,0.1,72808.5,,,meta-llama/Meta-Llama-3-8B,4,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,foak-fast-kernels,0.0002,16.0,0.1,79491.0,,,meta-llama/Meta-Llama-3-8B,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.68,,foak-fast-kernels-liger,0.0002,16.0,0.1,23197.0,19315503616.0,4088154624.0,meta-llama/Meta-Llama-3-8B,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0156009674072266,454.3664,3.521,0.22,3605.901 +True,1.35,,foak-fast-kernels-liger,0.0002,16.0,0.1,37981.0,32003797504.0,4088547840.0,meta-llama/Meta-Llama-3-8B,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0163630485534667,880.5812,3.634,0.114,3721.179 +True,2.7,,foak-fast-kernels-liger,0.0002,16.0,0.1,66909.0,57380385280.0,4089334272.0,meta-llama/Meta-Llama-3-8B,4,lora,16,16.0,q_proj k_proj v_proj 
o_proj,bfloat16,1.0160735416412354,1740.1593,3.678,0.057,3766.092 +True,,,foak-fast-kernels-liger,0.0002,16.0,0.1,80867.0,,,meta-llama/Meta-Llama-3-8B,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.68,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,42590.5,25112976896.0,1636421120.0,meta-llama/Meta-Llama-3-8B,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0383064079284667,423.731,3.776,0.236,3866.604 +True,1.35,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,77795.5,46486408704.0,1636814336.0,meta-llama/Meta-Llama-3-8B,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.039291534423828,817.7661,3.913,0.122,4007.014 +True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,66703.0,,,meta-llama/Meta-Llama-3-8B,4,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,77649.0,,,meta-llama/Meta-Llama-3-8B,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.68,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,21677.0,16453416448.0,1636421120.0,meta-llama/Meta-Llama-3-8B,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0390117359161375,421.9715,3.792,0.237,3882.726 +True,1.35,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,35376.0,29005395456.0,1636814336.0,meta-llama/Meta-Llama-3-8B,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.0401603603363037,812.9813,3.936,0.123,4030.597 +True,2.7,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,63405.0,54109353472.0,1637600768.0,meta-llama/Meta-Llama-3-8B,4,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,1.037994260787964,1602.714,3.993,0.062,4089.064 +True,,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,78893.5,,,meta-llama/Meta-Llama-3-8B,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +,0.68,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,44207.5,25122840064.0,1646284288.0,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.0636010646820069,439.0306,3.644,0.228,3731.858 +,1.35,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,79475.5,46496271872.0,1646677504.0,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.0646323871612549,853.2938,3.75,0.117,3840.178 +,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,67956.0,,,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,78902.0,,,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,0.68,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,23330.0,16464132096.0,1646284288.0,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.0635689449310304,436.8828,3.662,0.229,3750.205 +,1.35,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,36974.0,29016111104.0,1646677504.0,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.059765977859497,849.659,3.766,0.118,3856.606 +,2.7,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,65315.5,54120069120.0,1647463936.0,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,1.068314094543457,1676.21,3.818,0.06,3909.773 +,,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,79377.5,,,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, diff --git a/scripts/benchmarks/refs/requirements_liger.txt b/scripts/benchmarks/refs/requirements_liger.txt new file mode 100644 index 00000000..fff46200 --- /dev/null +++ 
b/scripts/benchmarks/refs/requirements_liger.txt @@ -0,0 +1,87 @@ +accelerate==1.0.1 +aiohappyeyeballs==2.4.3 +aiohttp==3.11.0 +aiosignal==1.3.1 +async-timeout==5.0.1 +attrs==24.2.0 +bitsandbytes==0.43.3 +certifi==2024.8.30 +charset-normalizer==3.4.0 +contourpy==1.3.1 +cycler==0.12.1 +datasets==2.21.0 +dill==0.3.8 +docstring_parser==0.16 +einops==0.8.0 +filelock==3.16.1 +flash-attn==2.7.0.post2 +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@d58960c46f990e3a805ce95a2d4cdee7dc831e19#egg=fms_acceleration&subdirectory=plugins/framework +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@d58960c46f990e3a805ce95a2d4cdee7dc831e19#egg=fms_acceleration_aadp&subdirectory=plugins/attention-and-distributed-packing +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@d58960c46f990e3a805ce95a2d4cdee7dc831e19#egg=fms_acceleration_foak&subdirectory=plugins/fused-ops-and-kernels +-e git+https://github.com/foundation-model-stack/fms-acceleration.git@d58960c46f990e3a805ce95a2d4cdee7dc831e19#egg=fms_acceleration_peft&subdirectory=plugins/accelerated-peft +fms-hf-tuning @ git+https://github.com/foundation-model-stack/fms-hf-tuning.git@398c2a8fe26d734344240555585d95e05299faa8 +fonttools==4.54.1 +frozenlist==1.5.0 +fsspec==2024.6.1 +huggingface-hub==0.26.2 +idna==3.10 +Jinja2==3.1.4 +kiwisolver==1.4.7 +llvmlite==0.43.0 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.9.2 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.1.0 +multiprocess==0.70.16 +networkx==3.4.2 +numba==0.60.0 +numpy==1.26.4 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.77 +nvidia-nvtx-cu12==12.1.105 +packaging==24.2 +pandas==2.2.3 +peft==0.13.2 +pillow==11.0.0 +propcache==0.2.0 +protobuf==5.28.3 +psutil==6.1.0 +pyarrow==18.0.0 +Pygments==2.18.0 +pyparsing==3.2.0 +python-dateutil==2.9.0.post0 +pytz==2024.2 +PyYAML==6.0.2 +regex==2024.11.6 +requests==2.32.3 +rich==13.9.4 +safetensors==0.4.5 +sentencepiece==0.2.0 +shtab==1.7.1 +simpleeval==0.9.13 +six==1.16.0 +sympy==1.13.3 +threadpoolctl==3.5.0 +tokenizers==0.20.3 +torch==2.4.1 +tqdm==4.67.0 +transformers==4.45.2 +triton==3.0.0 +trl==0.11.4 +typing_extensions==4.12.2 +tyro==0.8.14 +tzdata==2024.2 +urllib3==2.2.3 +xxhash==3.5.0 +yarl==1.17.1 From 45f1a892918b561623abfda5060acd6ea04e863b Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sat, 16 Nov 2024 03:44:02 +0000 Subject: [PATCH 10/14] refactor to fused_ops Signed-off-by: Yu Chin Fabian Lim --- plugins/fused-ops-and-kernels/.pylintrc | 4 ++- .../framework_plugin_fast_kernels.py | 4 +-- .../fused_ops/liger_ce/__init__.py | 27 +++++++++++++++++++ .../liger_ce}/cross_entropy.py | 0 .../fused_linear_cross_entropy_loss.py | 0 .../fms_acceleration_foak/models/granite.py | 4 +-- .../src/fms_acceleration_foak/models/llama.py | 4 +-- .../fms_acceleration_foak/models/mistral.py | 5 ++-- plugins/fused-ops-and-kernels/tox.ini | 1 + 9 files changed, 39 insertions(+), 10 deletions(-) create mode 100644 plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/__init__.py rename plugins/fused-ops-and-kernels/src/fms_acceleration_foak/{kernels/liger => fused_ops/liger_ce}/cross_entropy.py (100%) rename plugins/fused-ops-and-kernels/src/fms_acceleration_foak/{kernels/liger => 
fused_ops/liger_ce}/fused_linear_cross_entropy_loss.py (100%) diff --git a/plugins/fused-ops-and-kernels/.pylintrc b/plugins/fused-ops-and-kernels/.pylintrc index 31cb902c..cfe9aeb7 100644 --- a/plugins/fused-ops-and-kernels/.pylintrc +++ b/plugins/fused-ops-and-kernels/.pylintrc @@ -53,7 +53,9 @@ ignore=CVS,protobufs # format. Because '\\' represents the directory delimiter on Windows systems, # it can't be used as an escape character. # NOTE: do not lint code imported from unsloth -ignore-paths=.*fused_ops/unsloth_lora.*,.*kernels/unsloth* +ignore-paths=.*fused_ops/unsloth_lora.*, + .*fused_ops/liger_ce.*, + .*kernels/unsloth*, # Files or directories matching the regular expression patterns are skipped. # The regex matches against base names, not paths. The default value ignores diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py index a09b0253..049b26d4 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py @@ -23,8 +23,8 @@ import torch # Local -from .utils import lora_adapters_switch_ddp_from_fsdp from .models.utils import filter_mp_rules +from .utils import lora_adapters_switch_ddp_from_fsdp def validate_plugin_args(configurations): @@ -33,6 +33,7 @@ def validate_plugin_args(configurations): configurations["fused_linear_loss"] != configurations["fast_loss"] ), "If using `fused_linear_loss`, `fast_loss` must be set to False" + # consider rewriting register_foak_model_patch_rules into something # like this also def register_foak_model_patch_rules( @@ -139,7 +140,6 @@ def __init__(self, configurations: Dict[str, Dict]): validate_plugin_args(self.configurations) - @property def requires_agumentation(self): return True diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/__init__.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/__init__.py new file mode 100644 index 00000000..3a6da048 --- /dev/null +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 Byron Hsu & Linkedin team. All rights reserved. +# +# BSD 2-CLAUSE LICENSE +# Copyright 2024 LinkedIn Corporation +# All Rights Reserved. +# Redistribution and use in source and binary forms, with or +# without modification, are permitted provided that the following +# conditions are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from .fused_linear_cross_entropy_loss import lce_forward \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/cross_entropy.py similarity index 100% rename from plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/cross_entropy.py rename to plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/cross_entropy.py diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py similarity index 100% rename from plugins/fused-ops-and-kernels/src/fms_acceleration_foak/kernels/liger/fused_linear_cross_entropy_loss.py rename to plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/liger_ce/fused_linear_cross_entropy_loss.py diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py index 87cd48bf..e4b58572 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/granite.py @@ -26,10 +26,10 @@ from transformers import PretrainedConfig # Local +from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward from .utils import ( KEY_MLP, KEY_O, @@ -51,8 +51,8 @@ def get_mp_rules(base_type: str, config: PretrainedConfig = None): try: # Third Party from transformers.models.granite.modeling_granite import ( # pylint: disable=import-outside-toplevel - GraniteForCausalLM, GraniteAttention, + GraniteForCausalLM, GraniteMLP, GraniteRMSNorm, ) diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py index b65fb768..94fab82f 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/llama.py @@ -25,17 +25,17 @@ ) from transformers import PretrainedConfig from transformers.models.llama.modeling_llama import ( - LlamaForCausalLM, LlamaAttention, + LlamaForCausalLM, LlamaMLP, LlamaRMSNorm, ) # Local +from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward from .utils 
import ( KEY_MLP, KEY_O, diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py index 35bf57ac..64e65274 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/models/mistral.py @@ -25,18 +25,17 @@ ) from transformers import PretrainedConfig from transformers.models.mistral.modeling_mistral import ( - MistralForCausalLM, MistralAttention, + MistralForCausalLM, MistralMLP, MistralRMSNorm, ) # Local +from ..fused_ops.liger_ce.fused_linear_cross_entropy_loss import lce_forward from ..kernels.unsloth.cross_entropy_loss import FastCrossEntropyLoss from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm from ..kernels.unsloth.rope_embedding import fast_rope_embedding -from ..kernels.liger.fused_linear_cross_entropy_loss import lce_forward - from .utils import ( KEY_MLP, KEY_O, diff --git a/plugins/fused-ops-and-kernels/tox.ini b/plugins/fused-ops-and-kernels/tox.ini index c3a38721..37a66b45 100644 --- a/plugins/fused-ops-and-kernels/tox.ini +++ b/plugins/fused-ops-and-kernels/tox.ini @@ -42,6 +42,7 @@ deps = commands = # exclude the code ported from unsloth black --exclude .*unsloth.* src + black --exclude .*liger.* src black --exclude .*unsloth.* tests isort . From dc075e350b6d11da621c682ab258d7ffe74082aa Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sat, 16 Nov 2024 14:58:05 +0000 Subject: [PATCH 11/14] fix fmt + lint Signed-off-by: Yu Chin Fabian Lim --- plugins/fused-ops-and-kernels/.isort.cfg | 3 ++- plugins/fused-ops-and-kernels/pyproject.toml | 8 ++++++++ plugins/fused-ops-and-kernels/tox.ini | 6 ++---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/plugins/fused-ops-and-kernels/.isort.cfg b/plugins/fused-ops-and-kernels/.isort.cfg index 4aa62fac..a6206218 100644 --- a/plugins/fused-ops-and-kernels/.isort.cfg +++ b/plugins/fused-ops-and-kernels/.isort.cfg @@ -10,4 +10,5 @@ known_firstparty= known_localfolder=tuning # skip code imported from unsloth -skip_glob=**/unsloth*/** +skip_glob=**/unsloth*/**, + **/liger*/** diff --git a/plugins/fused-ops-and-kernels/pyproject.toml b/plugins/fused-ops-and-kernels/pyproject.toml index d9acec60..516b2756 100644 --- a/plugins/fused-ops-and-kernels/pyproject.toml +++ b/plugins/fused-ops-and-kernels/pyproject.toml @@ -29,3 +29,11 @@ only-include = ["src/fms_acceleration_foak"] [tool.hatch.build.targets.wheel.sources] "src" = "" + +[tool.black] +force-exclude = ''' +/( +.*unsloth.* +| .*liger.* +)/ +''' diff --git a/plugins/fused-ops-and-kernels/tox.ini b/plugins/fused-ops-and-kernels/tox.ini index 37a66b45..b436fbe7 100644 --- a/plugins/fused-ops-and-kernels/tox.ini +++ b/plugins/fused-ops-and-kernels/tox.ini @@ -40,10 +40,8 @@ deps = black>=22.12 isort>=5.11 commands = - # exclude the code ported from unsloth - black --exclude .*unsloth.* src - black --exclude .*liger.* src - black --exclude .*unsloth.* tests + black src + black tests isort . 
[testenv:build]

From a02a0a083f1b4048b5991890d86f0b98f20c8731 Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Mon, 18 Nov 2024 01:09:52 +0000
Subject: [PATCH 12/14] update full benches and readme

Signed-off-by: Yu Chin Fabian Lim
---
 plugins/fused-ops-and-kernels/README.md     | 13 ++-
 scripts/benchmarks/refs/a100_80gb_liger.csv | 96 +++++++++++++++++++++
 scripts/benchmarks/scenarios-liger.yaml     |  8 +-
 3 files changed, 112 insertions(+), 5 deletions(-)

diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md
index 0d66a357..4bfb8857 100644
--- a/plugins/fused-ops-and-kernels/README.md
+++ b/plugins/fused-ops-and-kernels/README.md
@@ -79,10 +79,21 @@ It is relatively easy by following an existing template; in what follows we use
 )
 ```
 
+### Running Liger Kernel Benchmarks
+
+The benchmarks were run separately for each `num_gpu` entry; they could be combined into a single command, but running each entry separately is more efficient. The positional arguments are the number of GPUs, the effective batch sizes to sweep, the results directory, and the scenarios file; the final argument is left as `none` here.
+
+```
+tox -e run-benches -- 1 "4 8 16 32" benchmark_outputs_1 scenarios-liger.yaml none
+tox -e run-benches -- 2 "8 16 32 64" benchmark_outputs_2 scenarios-liger.yaml none
+tox -e run-benches -- 4 "16 32 64 128" benchmark_outputs_3 scenarios-liger.yaml none
+```
+
+
 ## Known Issues
 - MixedPrecision `--fp16` or `--bf16` should be used with `fast_lora`.
 - `fast_lora` has issues with FSDP V1 with the `peft` style of FSDP wrapping.
   * This is because the adapter's forward functions are bypassed in the fused ops.
   * For AutoGPTQ/QLoRA this is addressed by distributing the adapters using DDP so they will be unsharded in time for the fused ops.
-- `fast_rope_embeddings` does not work with position_ids. Currently `position_ids` are ignored and could give wrong results.
\ No newline at end of file
+- `fast_rope_embeddings` does not work with `position_ids`; it seems HF has deprecated passing these ids into the rope embedding methods.
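+- `fused_linear_loss` and `fast_loss` are mutually exclusive: the plugin's `validate_plugin_args` asserts that exactly one of the two is enabled. A minimal sketch of the intended toggle, using the key names from `configs/fast_kernels.yaml` (surrounding keys and nesting elided):
+
+```yaml
+# enable the fused linear cross entropy loss kernel ...
+fused_linear_loss: True
+# ... which requires the standalone fast cross entropy kernel to be off
+fast_loss: False
+```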
\ No newline at end of file diff --git a/scripts/benchmarks/refs/a100_80gb_liger.csv b/scripts/benchmarks/refs/a100_80gb_liger.csv index 871df7fb..e43666db 100644 --- a/scripts/benchmarks/refs/a100_80gb_liger.csv +++ b/scripts/benchmarks/refs/a100_80gb_liger.csv @@ -95,3 +95,99 @@ True,,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,78893.5,,,meta-llama/Meta ,1.35,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,36974.0,29016111104.0,1646677504.0,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.059765977859497,849.659,3.766,0.118,3856.606 ,2.7,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,65315.5,54120069120.0,1647463936.0,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,1.068314094543457,1676.21,3.818,0.06,3909.773 ,,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,79377.5,,,TechxGenus/Meta-Llama-3-8B-GPTQ,4,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +True,0.15,,foak-fast-kernels,2e-05,,,76071.0,72432723456.0,43466827264.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.838842716217041,479.0213,0.835,0.209,3420.307 +True,0.29,,foak-fast-kernels,2e-05,,,70035.0,72433116672.0,43467220480.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.8388796520233154,932.3966,0.858,0.107,3514.384 +True,,,foak-fast-kernels,2e-05,,,79167.0,,,mistralai/Mistral-7B-v0.1,1,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-05,,,73171.0,,,mistralai/Mistral-7B-v0.1,1,,32,,,bfloat16,,,,, +True,0.15,,foak-fast-kernels-liger,2e-05,,,70285.0,72432723456.0,43466827264.0,mistralai/Mistral-7B-v0.1,1,,4,,,bfloat16,0.8386862182617187,479.4765,0.834,0.209,3417.06 +True,0.29,,foak-fast-kernels-liger,2e-05,,,74829.0,72433116672.0,43467220480.0,mistralai/Mistral-7B-v0.1,1,,8,,,bfloat16,0.8387984752655029,931.6364,0.859,0.107,3517.252 +True,0.58,,foak-fast-kernels-liger,2e-05,,,79041.0,77144641024.0,43468006912.0,mistralai/Mistral-7B-v0.1,1,,16,,,bfloat16,0.8310897159576416,1837.2742,0.871,0.054,3567.023 +True,,,foak-fast-kernels-liger,2e-05,,,80539.0,,,mistralai/Mistral-7B-v0.1,1,,32,,,bfloat16,,,,, +True,0.15,,foak-fast-kernels,0.0002,16.0,0.1,27511.0,23530188288.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8601302146911621,420.488,0.951,0.238,3896.426 +True,0.29,,foak-fast-kernels,0.0002,16.0,0.1,40271.0,32393276928.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8624934101104736,832.3628,0.961,0.12,3936.745 +True,0.58,,foak-fast-kernels,0.0002,16.0,0.1,65793.0,50119454208.0,14665688576.0,mistralai/Mistral-7B-v0.1,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8582536506652833,1660.5425,0.964,0.06,3946.662 +True,,,foak-fast-kernels,0.0002,16.0,0.1,73377.0,,,mistralai/Mistral-7B-v0.1,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.15,,foak-fast-kernels-liger,0.0002,16.0,0.1,23833.0,23530188288.0,14664508928.0,mistralai/Mistral-7B-v0.1,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8599755954742432,419.5123,0.953,0.238,3905.488 +True,0.29,,foak-fast-kernels-liger,0.0002,16.0,0.1,32921.0,32393276928.0,14664902144.0,mistralai/Mistral-7B-v0.1,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8625451850891114,831.0694,0.963,0.12,3942.872 +True,0.58,,foak-fast-kernels-liger,0.0002,16.0,0.1,51099.0,50119454208.0,14665688576.0,mistralai/Mistral-7B-v0.1,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8574607944488526,1658.6224,0.965,0.06,3951.231 
+True,,,foak-fast-kernels-liger,0.0002,16.0,0.1,78031.0,,,mistralai/Mistral-7B-v0.1,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.15,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,18809.0,13064809472.0,4306512384.0,mistralai/Mistral-7B-v0.1,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8750344657897949,403.2783,0.992,0.248,4062.703 +True,0.29,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,31953.0,21823466496.0,4306905600.0,mistralai/Mistral-7B-v0.1,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8779727077484131,786.5424,1.017,0.127,4166.082 +True,0.58,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,59137.0,39338994688.0,4307692032.0,mistralai/Mistral-7B-v0.1,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8731959533691406,1559.0449,1.026,0.064,4203.599 +True,1.16,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,78209.0,74371837952.0,4309264896.0,mistralai/Mistral-7B-v0.1,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.872735185623169,3097.7032,1.033,0.032,4231.264 +True,0.15,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,14171.0,13035024896.0,4306512384.0,mistralai/Mistral-7B-v0.1,1,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8749488735198975,401.8581,0.995,0.249,4077.061 +True,0.29,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,23323.0,21761798656.0,4306905600.0,mistralai/Mistral-7B-v0.1,1,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8784359264373779,785.9738,1.018,0.127,4169.096 +True,0.58,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,41883.0,39215346176.0,4307692032.0,mistralai/Mistral-7B-v0.1,1,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8728489208221436,1555.8043,1.028,0.064,4212.355 +True,1.16,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,79005.0,74122441216.0,4309264896.0,mistralai/Mistral-7B-v0.1,1,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8713663196563721,3091.0715,1.035,0.032,4240.342 +,0.15,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,18553.0,13095119872.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,0.9928509044647216,414.9265,0.964,0.241,3948.652 +,0.29,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,32337.0,21853776896.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,0.9894316577911376,816.7654,0.979,0.122,4011.923 +,0.58,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,59649.0,39369305088.0,4338002432.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,0.982344207763672,1619.3601,0.988,0.062,4047.031 +,1.15,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,78475.0,74402148352.0,4339575296.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,0.9788006114959716,3190.2664,1.003,0.031,4108.497 +,0.15,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,14235.0,13065335808.0,4336822784.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.0111017036437988,414.9576,0.964,0.241,3948.355 +,0.29,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,23707.0,21792109568.0,4337216000.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,0.9891231441497804,817.2045,0.979,0.122,4009.768 +,0.58,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,42395.0,39245657088.0,4338002432.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,0.9979404735565186,1618.9493,0.988,0.062,4048.057 
+,1.15,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,79517.0,74152752128.0,4339575296.0,TheBloke/Mistral-7B-v0.1-GPTQ,1,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,1.0506530570983887,3184.3913,1.005,0.031,4116.077 +True,0.29,,foak-fast-kernels,2e-05,,,75537.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8286205387115478,504.3425,1.586,0.198,3248.586 +True,0.58,,foak-fast-kernels,2e-05,,,81209.0,72435247104.0,57951766528.0,mistralai/Mistral-7B-v0.1,2,,8,,,bfloat16,0.8197736072540284,1094.8571,1.461,0.091,2992.902 +True,,,foak-fast-kernels,2e-05,,,80994.0,,,mistralai/Mistral-7B-v0.1,2,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels,2e-05,,,72883.0,,,mistralai/Mistral-7B-v0.1,2,,32,,,bfloat16,,,,, +True,0.29,,foak-fast-kernels-liger,2e-05,,,78251.0,72434853888.0,57951373312.0,mistralai/Mistral-7B-v0.1,2,,4,,,bfloat16,0.8286735343933106,503.738,1.588,0.199,3252.484 +True,0.58,,foak-fast-kernels-liger,2e-05,,,79908.0,72435247104.0,57951766528.0,mistralai/Mistral-7B-v0.1,2,,8,,,bfloat16,0.8198539209365845,959.9081,1.667,0.104,3413.66 +True,,,foak-fast-kernels-liger,2e-05,,,80553.0,,,mistralai/Mistral-7B-v0.1,2,,16,,,bfloat16,,,,, +True,,,foak-fast-kernels-liger,2e-05,,,78785.0,,,mistralai/Mistral-7B-v0.1,2,,32,,,bfloat16,,,,, +True,0.29,,foak-fast-kernels,0.0002,16.0,0.1,23845.0,21219418112.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8595081615447998,446.6717,1.791,0.224,3668.019 +True,0.58,,foak-fast-kernels,0.0002,16.0,0.1,38141.0,34109038592.0,7368636416.0,mistralai/Mistral-7B-v0.1,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8571683597564698,875.4589,1.828,0.114,3742.951 +True,1.16,,foak-fast-kernels,0.0002,16.0,0.1,66197.0,59888279552.0,7369422848.0,mistralai/Mistral-7B-v0.1,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8570511913299561,1727.5688,1.852,0.058,3793.539 +True,,,foak-fast-kernels,0.0002,16.0,0.1,80157.0,,,mistralai/Mistral-7B-v0.1,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.29,,foak-fast-kernels-liger,0.0002,16.0,0.1,24391.0,21219418112.0,7368243200.0,mistralai/Mistral-7B-v0.1,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.859750566482544,446.2936,1.793,0.224,3671.126 +True,0.58,,foak-fast-kernels-liger,0.0002,16.0,0.1,38599.0,34109038592.0,7368636416.0,mistralai/Mistral-7B-v0.1,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8571057224273682,874.4754,1.83,0.114,3747.161 +True,1.16,,foak-fast-kernels-liger,0.0002,16.0,0.1,67119.0,59888279552.0,7369422848.0,mistralai/Mistral-7B-v0.1,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8565524101257325,1726.117,1.854,0.058,3796.73 +True,,,foak-fast-kernels-liger,0.0002,16.0,0.1,78998.0,,,mistralai/Mistral-7B-v0.1,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.29,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,19778.0,15685985280.0,2244738048.0,mistralai/Mistral-7B-v0.1,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8795236587524414,412.1764,1.941,0.243,3974.997 +True,0.58,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,33700.0,28439290880.0,2245131264.0,mistralai/Mistral-7B-v0.1,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8764357566833496,798.0677,2.005,0.125,4105.917 +True,1.16,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,61093.0,53945902080.0,2245917696.0,mistralai/Mistral-7B-v0.1,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8760993480682373,1576.0592,2.03,0.063,4158.22 
+True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,79604.0,,,mistralai/Mistral-7B-v0.1,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.29,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,19963.0,15685985280.0,2244738048.0,mistralai/Mistral-7B-v0.1,2,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.878936014175415,411.7593,1.943,0.243,3979.023 +True,0.58,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,34082.0,28439290880.0,2245131264.0,mistralai/Mistral-7B-v0.1,2,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8762027359008789,797.1272,2.007,0.125,4110.762 +True,1.16,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,61730.0,53945902080.0,2245917696.0,mistralai/Mistral-7B-v0.1,2,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8754921627044677,1573.6818,2.033,0.064,4164.501 +True,,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,79707.0,,,mistralai/Mistral-7B-v0.1,2,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +,0.29,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,20581.0,15703516672.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.1056310272216796,425.0725,1.882,0.235,3854.401 +,0.58,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,34556.0,28456822272.0,2261810176.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,0.9874585056304932,830.9632,1.925,0.12,3943.375 +,1.15,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,61815.0,53963433472.0,2262596608.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,0.9895050239562988,1637.2755,1.954,0.061,4002.747 +,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,78173.0,,,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,0.29,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,20743.0,15703516672.0,2261416960.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.0293195629119871,425.1526,1.882,0.235,3853.675 +,0.58,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,34887.0,28456822272.0,2261810176.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,0.98430645942688,830.265,1.927,0.12,3946.692 +,1.15,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,62481.0,53963433472.0,2262596608.0,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,1.004049482345581,1635.4433,1.957,0.061,4007.232 +,,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,79547.0,,,TheBloke/Mistral-7B-v0.1-GPTQ,2,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +True,0.58,,foak-fast-kernels,2e-05,,,51099.5,36226152448.0,28984412160.0,mistralai/Mistral-7B-v0.1,4,,4,,,bfloat16,0.8197227716445923,492.7218,3.247,0.203,3325.203 +True,1.16,,foak-fast-kernels,2e-05,,,58746.5,39783938560.0,28984805376.0,mistralai/Mistral-7B-v0.1,4,,8,,,bfloat16,0.810418028831482,948.1775,3.375,0.105,3455.893 +True,2.33,,foak-fast-kernels,2e-05,,,76774.0,57299466752.0,28985591808.0,mistralai/Mistral-7B-v0.1,4,,16,,,bfloat16,0.7932196092605591,1860.555,3.44,0.054,3522.39 +True,,,foak-fast-kernels,2e-05,,,81199.0,,,mistralai/Mistral-7B-v0.1,4,,32,,,bfloat16,,,,, +True,0.58,,foak-fast-kernels-liger,2e-05,,,47698.5,36226152448.0,28984412160.0,mistralai/Mistral-7B-v0.1,4,,4,,,bfloat16,0.8197373056411743,492.5593,3.248,0.203,3326.3 +True,1.16,,foak-fast-kernels-liger,2e-05,,,55799.5,39732781056.0,28984805376.0,mistralai/Mistral-7B-v0.1,4,,8,,,bfloat16,0.8104191637039184,948.2377,3.375,0.105,3455.674 
+True,2.33,,foak-fast-kernels-liger,2e-05,,,71515.5,56381022208.0,28985591808.0,mistralai/Mistral-7B-v0.1,4,,16,,,bfloat16,0.7934608507156372,1858.6873,3.443,0.054,3525.929 +True,,,foak-fast-kernels-liger,2e-05,,,81213.0,,,mistralai/Mistral-7B-v0.1,4,,32,,,bfloat16,,,,, +True,0.58,,foak-fast-kernels,0.0002,16.0,0.1,20758.0,17544448000.0,3692847104.0,mistralai/Mistral-7B-v0.1,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8567444038391113,446.5088,3.583,0.224,3669.356 +True,1.16,,foak-fast-kernels,0.0002,16.0,0.1,34894.0,30434068480.0,3693240320.0,mistralai/Mistral-7B-v0.1,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8568990898132324,868.1502,3.686,0.115,3774.462 +True,2.33,,foak-fast-kernels,0.0002,16.0,0.1,62814.0,56213309440.0,3694026752.0,mistralai/Mistral-7B-v0.1,4,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8558178234100342,1715.4042,3.731,0.058,3820.441 +True,,,foak-fast-kernels,0.0002,16.0,0.1,80781.0,,,mistralai/Mistral-7B-v0.1,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.58,,foak-fast-kernels-liger,0.0002,16.0,0.1,21304.0,17544448000.0,3692847104.0,mistralai/Mistral-7B-v0.1,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8567856788635254,445.9824,3.588,0.224,3673.687 +True,1.16,,foak-fast-kernels-liger,0.0002,16.0,0.1,35551.0,30434068480.0,3693240320.0,mistralai/Mistral-7B-v0.1,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8569093990325928,867.0898,3.691,0.115,3779.078 +True,2.33,,foak-fast-kernels-liger,0.0002,16.0,0.1,63711.0,56213309440.0,3694026752.0,mistralai/Mistral-7B-v0.1,4,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8561593055725097,1712.623,3.737,0.058,3826.645 +True,,,foak-fast-kernels-liger,0.0002,16.0,0.1,81094.0,,,mistralai/Mistral-7B-v0.1,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.58,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,19553.5,14682360832.0,1241113600.0,mistralai/Mistral-7B-v0.1,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.873923568725586,414.3866,3.861,0.241,3953.796 +True,1.16,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,33586.0,27435666432.0,1241506816.0,mistralai/Mistral-7B-v0.1,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8746437931060791,800.1313,3.999,0.125,4095.328 +True,2.33,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,61156.0,52942277632.0,1242293248.0,mistralai/Mistral-7B-v0.1,4,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.873357572555542,1576.6662,4.059,0.063,4156.619 +True,,,accelerated-peft-bnb-foak,0.0002,16.0,0.1,80633.0,,,mistralai/Mistral-7B-v0.1,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, +True,0.58,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,19793.5,14682360832.0,1241113600.0,mistralai/Mistral-7B-v0.1,4,lora,4,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8740975952148438,413.8023,3.867,0.242,3959.378 +True,1.16,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,33930.0,27435666432.0,1241506816.0,mistralai/Mistral-7B-v0.1,4,lora,8,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8744520854949951,799.4148,4.003,0.125,4098.999 +True,2.33,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,61638.5,52942277632.0,1242293248.0,mistralai/Mistral-7B-v0.1,4,lora,16,16.0,q_proj k_proj v_proj o_proj,bfloat16,0.8735318374633789,1575.5873,4.062,0.063,4159.465 +True,,,accelerated-peft-bnb-foak-liger,0.0002,16.0,0.1,80901.0,,,mistralai/Mistral-7B-v0.1,4,lora,32,16.0,q_proj k_proj v_proj o_proj,bfloat16,,,,, 
+,0.58,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,21609.0,14693076480.0,1250976768.0,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,0.9958928966522216,427.3794,3.744,0.234,3833.596 +,1.15,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,35645.0,27446382080.0,1251369984.0,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.017822380065918,830.5326,3.853,0.12,3945.42 +,2.27,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,62819.5,52952993280.0,1252156416.0,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,1.0091158390045163,1638.8153,3.905,0.061,3998.986 +,,True,accelerated-peft-autogptq-foak,0.0002,16.0,0.1,78291.0,,,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, +,0.58,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,21592.5,14693076480.0,1250976768.0,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,4,16.0,q_proj k_proj v_proj o_proj,float16,1.024254894256592,428.6955,3.732,0.233,3821.827 +,1.15,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,35828.5,27446382080.0,1251369984.0,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,8,16.0,q_proj k_proj v_proj o_proj,float16,1.007179250717163,830.4849,3.853,0.12,3945.647 +,2.27,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,63458.0,52952993280.0,1252156416.0,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,16,16.0,q_proj k_proj v_proj o_proj,float16,0.996524419784546,1636.0771,3.912,0.061,4005.679 +,,True,accelerated-peft-autogptq-foak-liger,0.0002,16.0,0.1,80537.0,,,TheBloke/Mistral-7B-v0.1-GPTQ,4,lora,32,16.0,q_proj k_proj v_proj o_proj,float16,,,,, diff --git a/scripts/benchmarks/scenarios-liger.yaml b/scripts/benchmarks/scenarios-liger.yaml index cdd026d2..8004441d 100644 --- a/scripts/benchmarks/scenarios-liger.yaml +++ b/scripts/benchmarks/scenarios-liger.yaml @@ -43,7 +43,7 @@ scenarios: arguments: learning_rate: 2e-5 model_name_or_path: - # - 'mistralai/Mistral-7B-v0.1' + - 'mistralai/Mistral-7B-v0.1' - 'meta-llama/Meta-Llama-3-8B' torch_dtype: bfloat16 bf16: True @@ -62,7 +62,7 @@ scenarios: lora_dropout: 0.1 target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] model_name_or_path: - # - 'mistralai/Mistral-7B-v0.1' + - 'mistralai/Mistral-7B-v0.1' - 'meta-llama/Meta-Llama-3-8B' - name: accelerated-peft-bnb @@ -80,7 +80,7 @@ scenarios: per_device_train_batch_size: target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] model_name_or_path: - # - 'mistralai/Mistral-7B-v0.1' + - 'mistralai/Mistral-7B-v0.1' - 'meta-llama/Meta-Llama-3-8B' - name: accelerated-peft-gptq @@ -97,5 +97,5 @@ scenarios: lora_dropout: 0.1 target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] model_name_or_path: - # - 'TheBloke/Mistral-7B-v0.1-GPTQ' + - 'TheBloke/Mistral-7B-v0.1-GPTQ' - 'TechxGenus/Meta-Llama-3-8B-GPTQ' From 1a693148ec55a9a11f037fc84d2ed69aadc27a38 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Sun, 17 Nov 2024 16:42:54 +0000 Subject: [PATCH 13/14] fix fast foak configs Signed-off-by: Yu Chin Fabian Lim --- .../src/fms_acceleration/framework_plugin.py | 2 +- .../configs/fast_kernels.yaml | 3 -- .../configs/fast_kernels_liger.yaml | 5 +--- .../configs/fast_quantized_peft_liger.yaml | 5 +--- .../framework_plugin_fast_kernels.py | 30 +++++++------------ ...ogptq-foak-liger-sample-configuration.yaml | 5 +--- ...b-nf4-foak-liger-sample-configuration.yaml | 5 +--- ...st-kernels-liger-sample-configuration.yaml | 5 +--- ...oak-fast-kernels-sample-configuration.yaml | 3 -- 
scripts/generate_sample_configurations.py | 2 ++ 10 files changed, 18 insertions(+), 47 deletions(-) diff --git a/plugins/framework/src/fms_acceleration/framework_plugin.py b/plugins/framework/src/fms_acceleration/framework_plugin.py index cf1764d5..28fecebf 100644 --- a/plugins/framework/src/fms_acceleration/framework_plugin.py +++ b/plugins/framework/src/fms_acceleration/framework_plugin.py @@ -206,7 +206,7 @@ def _check_config_and_maybe_check_values( t = list(t.keys())[0] # otherwise take the first value if t not in values: - if default is None: + if t is not None or default is None: raise AccelerationPluginConfigError( f"{self.__class__.__name__}: Value at '{key}' was '{t}'. " f"Not found in expected set '{values}'." diff --git a/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml b/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml index 823af26f..45f0051e 100644 --- a/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml +++ b/plugins/fused-ops-and-kernels/configs/fast_kernels.yaml @@ -23,6 +23,3 @@ training: # fast RoPE embedding triton kernels fast_rope_embeddings: True - - # fused linear cross entropy loss - fused_linear_loss: False \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml b/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml index 8011db78..a154b95b 100644 --- a/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml +++ b/plugins/fused-ops-and-kernels/configs/fast_kernels_liger.yaml @@ -16,13 +16,10 @@ training: # - the FastQuantized version is all-or-nothing # fast loss triton kernels - fast_loss: False + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rms_layernorm: True # fast RoPE embedding triton kernels fast_rope_embeddings: True - - # fused linear cross entropy loss - fused_linear_loss: True \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml index 7f239849..c6655d34 100644 --- a/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml +++ b/plugins/fused-ops-and-kernels/configs/fast_quantized_peft_liger.yaml @@ -21,13 +21,10 @@ peft: fused_lora: True # fast loss triton kernels - fast_loss: False + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rsm_layernorm: True # fast RoPE embedding triton kernels fast_rope_embeddings: True - - # fused linear cross entropy loss - fused_linear_loss: True \ No newline at end of file diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py index 049b26d4..7948a98c 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py @@ -27,13 +27,6 @@ from .utils import lora_adapters_switch_ddp_from_fsdp -def validate_plugin_args(configurations): - # Consider making this a more graceful fallback? 
- assert ( - configurations["fused_linear_loss"] != configurations["fast_loss"] - ), "If using `fused_linear_loss`, `fast_loss` must be set to False" - - # consider rewriting register_foak_model_patch_rules into something # like this also def register_foak_model_patch_rules( @@ -80,10 +73,12 @@ def register_foak_model_patch_rules( # maybe this we should define envvars FILTER_MAP = { "fused_lora": {"qkvo", "mlp"}, - "fast_loss": "cross-ent", + "fast_loss": { + True: "cross-ent", + "fused_ce_liger": "fused-lce", + }, "fast_rms_layernorm": "rms", "fast_rope_embeddings": "rope", - "fused_linear_loss": "fused-lce", } @@ -117,29 +112,22 @@ def __init__(self, configurations: Dict[str, Dict]): key="base_layer", values=["auto_gptq", "bitsandbytes"], default="auto_gptq" ) self.configurations["fused_lora"] = self._check_config_and_maybe_check_values( - key="fused_lora", values=[False, True], default=True + key="fused_lora", values=[False, True], default=False ) self.configurations["fast_loss"] = self._check_config_and_maybe_check_values( - key="fast_loss", values=[False, True], default=True + key="fast_loss", values=[False, True, "fused_ce_liger"], default=False ) self.configurations["fast_rms_layernorm"] = ( self._check_config_and_maybe_check_values( - key="fast_rms_layernorm", values=[False, True], default=True + key="fast_rms_layernorm", values=[False, True], default=False ) ) self.configurations["fast_rope_embeddings"] = ( self._check_config_and_maybe_check_values( - key="fast_rope_embeddings", values=[False, True], default=True - ) - ) - self.configurations["fused_linear_loss"] = ( - self._check_config_and_maybe_check_values( - key="fused_linear_loss", values=[False, True], default=False + key="fast_rope_embeddings", values=[False, True], default=False ) ) - validate_plugin_args(self.configurations) - @property def requires_agumentation(self): return True @@ -177,6 +165,8 @@ def augmentation( if k in FILTER_MAP and k not in omitted: ts = FILTER_MAP[k] + if isinstance(ts, dict) and v in ts: + ts = ts[v] if isinstance(ts, str): ts = {ts} diff --git a/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml b/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml index 1abc5a11..1126b4f8 100644 --- a/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-autogptq-foak-liger-sample-configuration.yaml @@ -43,13 +43,10 @@ plugins: fused_lora: true # fast loss triton kernels - fast_loss: false + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rsm_layernorm: true # fast RoPE embedding triton kernels fast_rope_embeddings: true - - # fused linear cross entropy loss - fused_linear_loss: true diff --git a/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml b/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml index 4376182e..71c305ac 100644 --- a/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml +++ b/sample-configurations/accelerated-peft-bnb-nf4-foak-liger-sample-configuration.yaml @@ -38,13 +38,10 @@ plugins: fused_lora: true # fast loss triton kernels - fast_loss: false + fast_loss: fused_ce_liger # fast rms norm triton kernels fast_rsm_layernorm: true # fast RoPE embedding triton kernels fast_rope_embeddings: true - - # fused linear cross entropy loss - fused_linear_loss: true diff --git a/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml 
b/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml
index 7002026a..1752755f 100644
--- a/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml
+++ b/sample-configurations/foak-fast-kernels-liger-sample-configuration.yaml
@@ -21,13 +21,10 @@ plugins:
       # - the FastQuantized version is all-or-nothing
 
       # fast loss triton kernels
-      fast_loss: false
+      fast_loss: fused_ce_liger
 
       # fast rms norm triton kernels
       fast_rms_layernorm: true
 
       # fast RoPE embedding triton kernels
       fast_rope_embeddings: true
-
-      # fused linear cross entropy loss
-      fused_linear_loss: true
diff --git a/sample-configurations/foak-fast-kernels-sample-configuration.yaml b/sample-configurations/foak-fast-kernels-sample-configuration.yaml
index ba7669aa..b9d646b6 100644
--- a/sample-configurations/foak-fast-kernels-sample-configuration.yaml
+++ b/sample-configurations/foak-fast-kernels-sample-configuration.yaml
@@ -28,6 +28,3 @@ plugins:
 
       # fast RoPE embedding triton kernels
       fast_rope_embeddings: true
-
-      # fused linear cross entropy loss
-      fused_linear_loss: false
diff --git a/scripts/generate_sample_configurations.py b/scripts/generate_sample_configurations.py
index 157f55bb..6232dce6 100644
--- a/scripts/generate_sample_configurations.py
+++ b/scripts/generate_sample_configurations.py
@@ -224,7 +224,9 @@ def read_configuration(path: str) -> Dict:
     ("accelerated-peft-bnb-nf4-foak-padding-free", (KEY_AADP_PADDING_FREE,KEY_BNB_NF4, KEY_BNB_NF4_FOAK)),
     ("aadp-padding-free-multipack", (KEY_AADP_PADDING_FREE, KEY_AADP_MULTIPACK)),
     ("foak-fast-kernels", (KEY_FAST_KERNELS,)),
+    ("foak-fast-kernels-liger", (KEY_FAST_KERNELS_LIGER,)),
     ("moe-scattermoe-granite-ep1", (KEY_SCATTERMOE_EP1,)),
+    ("moe-scattermoe-granite-ep1-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP1,)),
     ("moe-scattermoe-granite-ep1-padding-free-foak", (KEY_AADP_PADDING_FREE, KEY_FAST_KERNELS, KEY_SCATTERMOE_EP1,)),
     ("moe-scattermoe-granite-ep2", (KEY_SCATTERMOE_EP2,)),
     ("moe-scattermoe-granite-ep2-padding-free", (KEY_AADP_PADDING_FREE, KEY_SCATTERMOE_EP2,)),

From 45951372ea86b3e7b29780c641c163ef2654b6a5 Mon Sep 17 00:00:00 2001
From: Anh Uong
Date: Mon, 2 Dec 2024 15:53:19 -0700
Subject: [PATCH 14/14] docs: update foak readme benchmarks

Signed-off-by: Anh Uong
---
 plugins/fused-ops-and-kernels/README.md | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/plugins/fused-ops-and-kernels/README.md b/plugins/fused-ops-and-kernels/README.md
index 4bfb8857..0331c550 100644
--- a/plugins/fused-ops-and-kernels/README.md
+++ b/plugins/fused-ops-and-kernels/README.md
@@ -81,12 +81,14 @@ It is realtively easy by following an existing template, in what follows we use
 
 ### Running Liger Kernel Benchmarks
 
-The benchmarks were ran seperately for each `num_gpu` entry; they can be run together in a single command, but this is more efficient.
+Using the [scenarios-liger.yaml](../../scripts/benchmarks/scenarios-liger.yaml) scenarios file, the benchmarks run full fine tuning, LoRA PEFT, AutoGPTQ LoRA PEFT, and bits-and-bytes LoRA PEFT, first with the base Triton kernels (fast RMS norm, RoPE, cross-entropy) and then with the Liger LigerFusedLinearCrossEntropy kernel alongside fast RMS norm and RoPE, so the two sets of results can be compared. Only Mistral and Llama models are covered.

+The benchmarks were run separately for each `num_gpu` entry; they can be run together in a single command, but running them separately is more efficient.
+ +```sh tox -e run-benches -- 1 "4 8 16 32" benchmark_outputs_1 scenarios-liger.yaml none -tox -e run-benches 2 "8 16 32 64" benchmark_outputs_2 scenarios-liger.yaml none -tox -e run-benches 4 "16 32 64 128" benchmark_outputs_3 scenarios-liger.yaml none +tox -e run-benches -- 2 "8 16 32 64" benchmark_outputs_2 scenarios-liger.yaml none +tox -e run-benches -- 4 "16 32 64 128" benchmark_outputs_3 scenarios-liger.yaml none ```
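
A note on the `_check_config_and_maybe_check_values` fix in PATCH 13: before the change, an out-of-range config value was silently replaced by the default whenever a default existed, and only a missing default raised; after the change, any explicit out-of-range value raises, and only an absent key (`t is None`) falls back to the default. A minimal sketch of the corrected behavior, where `check_value` is a hypothetical stand-in for the plugin method and the error class is simplified:

```python
class AccelerationPluginConfigError(Exception):
    pass

def check_value(t, values, default=None):
    # Mirrors the fixed condition `if t is not None or default is None`:
    # an explicit value outside `values` always raises, while a missing
    # key (t is None) falls back to the default when one is provided.
    if t not in values:
        if t is not None or default is None:
            raise AccelerationPluginConfigError(
                f"Value '{t}' not found in expected set '{values}'."
            )
        t = default
    return t

allowed = [False, True, "fused_ce_liger"]
assert check_value(None, allowed, default=False) is False   # absent key -> default
assert check_value("fused_ce_liger", allowed) == "fused_ce_liger"
try:
    check_value("bogus", allowed, default=False)            # bad value now raises
except AccelerationPluginConfigError as err:
    print(err)
```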
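The same patch retires the separate `fused_linear_loss` flag: `fast_loss` now accepts `False`, `True`, or `"fused_ce_liger"`, and `FILTER_MAP` maps the config *value*, not just the key, to a patch-rule suffix. A minimal, self-contained sketch of that lookup; `resolve_filters` is a hypothetical condensation of the loop inside `augmentation` and omits the real plugin's `omitted` handling:

```python
FILTER_MAP = {
    "fused_lora": {"qkvo", "mlp"},
    "fast_loss": {
        True: "cross-ent",              # standard cross-entropy triton kernel
        "fused_ce_liger": "fused-lce",  # liger fused linear cross-entropy
    },
    "fast_rms_layernorm": "rms",
    "fast_rope_embeddings": "rope",
}

def resolve_filters(configurations: dict) -> set:
    terms = set()
    for k, v in configurations.items():
        if k not in FILTER_MAP or not v:
            continue
        ts = FILTER_MAP[k]
        if isinstance(ts, dict) and v in ts:
            ts = ts[v]   # value-dependent: pick the suffix for this setting
        if isinstance(ts, str):
            ts = {ts}
        terms.update(ts)
    return terms

# `fast_loss: fused_ce_liger` selects the fused-lce rules; `fast_loss: True`
# would select cross-ent instead, so the two kernels stay mutually exclusive
# without the old `validate_plugin_args` assertion.
print(sorted(resolve_filters({"fast_loss": "fused_ce_liger", "fast_rms_layernorm": True})))
# -> ['fused-lce', 'rms']
```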
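Finally, some context for the benchmark CSV above: the `-liger` rows track the baseline loss values closely while several of the larger-batch rows allocate less memory, because the fused kernel computes the loss without materializing the full `(tokens, vocab)` logits tensor. A toy, unfused reference of the computation being replaced (illustrative only, not the Liger implementation):

```python
import torch
import torch.nn.functional as F

def linear_cross_entropy_reference(hidden, lm_head_weight, labels, ignore_index=-100):
    # Unfused path: project hidden states to vocabulary logits, then take
    # cross entropy. The intermediate (N, V) logits tensor is what dominates
    # activation memory at large vocab sizes and what the fused kernel avoids.
    logits = hidden @ lm_head_weight.t()
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)

N, H, V = 8, 16, 128  # toy sizes; real runs are thousands of tokens by a ~32k vocab
hidden = torch.randn(N, H)
weight = torch.randn(V, H)
labels = torch.randint(0, V, (N,))
print(linear_cross_entropy_reference(hidden, weight, labels))
```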