[LoRA][Kernel] Remove the unused libentry module (vllm-project#10214)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored Nov 11, 2024
1 parent 58170d6 commit 36e4acd
Showing 7 changed files with 49 additions and 276 deletions.
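For context on the pattern being deleted: libentry was an extra decorator stacked above @triton.jit on the SGMV kernels, and the punica tests wrapped the bgmv kernels in LibEntry via unittest.mock.patch purely to exercise it. The snippet below is a hypothetical stand-in (the names LibEntrySketch and libentry_sketch are invented here) that only illustrates that wrapping shape; it is not the removed vllm.triton_utils.libentry implementation.

# Hypothetical stand-in for the removed LibEntry wrapper; illustration only.
# The real vllm.triton_utils.libentry module is not reproduced in this diff.
class LibEntrySketch:
    """Wrap a @triton.jit kernel while keeping the kernel[grid](...) syntax."""

    def __init__(self, jit_fn):
        self.jit_fn = jit_fn  # a triton.runtime.JITFunction

    def __getitem__(self, grid):
        # Delegate launches straight to the wrapped JIT function, so call
        # sites such as bgmv_shrink/bgmv_expand stay unchanged.
        return self.jit_fn[grid]


def libentry_sketch():
    # Decorator factory, mirroring how @libentry() sat above @triton.jit
    # before this commit.
    def decorator(jit_fn):
        return LibEntrySketch(jit_fn)

    return decorator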
73 changes: 24 additions & 49 deletions tests/lora/test_punica_sizes.py
@@ -4,8 +4,6 @@
whether the corresponding Triton kernel can run normally when tensor parallelism
is set to [1, 2, 4, 8, 16, 32, 64].
"""
from unittest.mock import patch

import pytest
import torch

@@ -16,7 +14,6 @@
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -235,9 +232,6 @@ def test_punica_bgmv(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)

@@ -262,33 +256,21 @@
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
@@ -324,7 +306,6 @@ def test_punica_expand_nslices(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)
@@ -374,22 +355,16 @@
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)

bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
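With the LibEntry patching removed, the bgmv branch of the test above reduces to direct calls into the Triton wrappers, which are then checked against ref_torch_groupgemm. A condensed sketch of that branch (tensor construction via the test's generate_data helper is elided; argument order follows the calls shown in the diff, and the shrink/expand comments reflect the usual LoRA A/B projection split):

# Condensed sketch of the simplified bgmv branch; tensors come from the
# test's generate_data(...) helper and are elided here.
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_shrink import bgmv_shrink


def run_bgmv(op_type, inputs_tensor, lora_weights, our_out_tensor, indices,
             scaling):
    if op_type == "shrink":
        # Shrink projects the inputs through the LoRA A weights
        # (hidden size -> rank), scaled by `scaling`.
        bgmv_shrink(inputs_tensor, lora_weights, our_out_tensor, indices,
                    scaling)
    else:
        # Expand projects through the LoRA B weights (rank -> output size);
        # add_inputs=True accumulates into the existing output tensor.
        bgmv_expand(inputs_tensor, lora_weights, our_out_tensor, indices,
                    add_inputs=True)
    return our_out_tensor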
73 changes: 24 additions & 49 deletions tests/lora/test_punica_variation.py
@@ -3,8 +3,6 @@
under different conditions, including various batches, numbers of LoRAs, and
maximum ranks.
"""
from unittest.mock import patch

import pytest
import torch

@@ -15,7 +13,6 @@
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from vllm.platforms import current_platform
from vllm.triton_utils.libentry import LibEntry

from .utils import (generate_data, generate_data_for_expand_nslices,
ref_torch_groupgemm)
@@ -150,8 +147,6 @@ def test_punica_bgmv(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)
@@ -177,33 +172,22 @@
device,
)
if op_type == "shrink":
# The current _bgmv_shrink_kernel does not require the libentry
# decoration. The purpose of adding this patch is to test the
# correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
LibEntry(_bgmv_shrink_kernel),
):
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
bgmv_shrink(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
scaling,
)
else:
# ditto
with patch(
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
LibEntry(_bgmv_expand_kernel),
):
bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)

bgmv_expand(
inputs_tensor,
lora_weights,
our_out_tensor,
indices,
add_inputs=True,
)
ref_torch_groupgemm(
ref_out_tensor,
inputs_tensor,
@@ -239,8 +223,6 @@ def test_punica_expand_nslices(
seed: int,
device: str,
):
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel

torch.set_default_device(device)
current_platform.seed_everything(seed)

@@ -289,22 +271,15 @@
add_inputs=True,
)
else:
# The current _bgmv_expand_slice_kernel does not require the
# libentry decoration. The purpose of adding this patch is to test
# the correctness of libentry.
with patch(
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
LibEntry(_bgmv_expand_slice_kernel),
):
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
bgmv_expand_slice(
inputs_tensor,
lora_weights,
our_outputs,
indices,
slice_offset,
slice_size=hidden_size,
add_inputs=True,
)
ref_torch_groupgemm(
ref_outputs[:, slice_offset:slice_offset + hidden_size],
inputs_tensor,
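Both test modules now run without any mock patching; one way to invoke them from Python (a CUDA-capable device is assumed, since the punica kernels are Triton kernels):

# Runs both updated test modules via pytest's Python API.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main([
            "tests/lora/test_punica_sizes.py",
            "tests/lora/test_punica_variation.py",
            "-q",
        ]))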
3 changes: 0 additions & 3 deletions vllm/lora/ops/sgmv_expand.py
@@ -9,10 +9,7 @@
import triton
import triton.language as tl

from vllm.triton_utils import libentry


@libentry()
@triton.jit
def _sgmv_expand_kernel(
input_ptr,
3 changes: 0 additions & 3 deletions vllm/lora/ops/sgmv_expand_slice.py
@@ -9,10 +9,7 @@
import triton
import triton.language as tl

from vllm.triton_utils import libentry


@libentry()
@triton.jit
def _sgmv_expand_slice_kernel(
input_ptr,
3 changes: 0 additions & 3 deletions vllm/lora/ops/sgmv_shrink.py
@@ -9,10 +9,7 @@
import triton
import triton.language as tl

from vllm.triton_utils import libentry


@libentry()
@triton.jit
def _sgmv_shrink_kernel(
input_ptr,
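Each of the three SGMV kernel files receives the same change: drop the `from vllm.triton_utils import libentry` import and the `@libentry()` line, leaving only `@triton.jit`. Schematically (the kernel signature and body are omitted here, so this stub is not meant to be launched):

# Schematic of the decorator change applied to _sgmv_shrink_kernel,
# _sgmv_expand_kernel and _sgmv_expand_slice_kernel; bodies omitted.
#
# Before:
#     from vllm.triton_utils import libentry
#
#     @libentry()
#     @triton.jit
#     def _sgmv_shrink_kernel(input_ptr, ...):
#         ...
#
# After (this commit): only the plain Triton decorator remains.
import triton


@triton.jit
def _sgmv_shrink_kernel(input_ptr):  # full parameter list omitted
    ...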
3 changes: 1 addition & 2 deletions vllm/triton_utils/__init__.py
@@ -6,6 +6,5 @@

from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)
from vllm.triton_utils.libentry import libentry

__all__ += ["maybe_set_triton_cache_manager", "libentry"]
__all__ += ["maybe_set_triton_cache_manager"]
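After this hunk, vllm.triton_utils no longer re-exports libentry, so maybe_set_triton_cache_manager is the only name this block keeps exporting, and any stale libentry import from the package would now raise. A quick illustration:

# After this commit, importing libentry from vllm.triton_utils fails, while
# the cache-manager helper continues to be re-exported.
from vllm.triton_utils import maybe_set_triton_cache_manager  # still valid

try:
    from vllm.triton_utils import libentry  # removed by this commit
except ImportError:
    libentry = None  # expected path after the removal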
167 changes: 0 additions & 167 deletions vllm/triton_utils/libentry.py
(diff not shown: entire module removed)
