Explicitly specify SUPPORTS_DEVICES
ZX-ModelCloud committed Dec 4, 2024
1 parent c0f8b11 commit 9b44f80
Showing 9 changed files with 19 additions and 3 deletions.
6 changes: 4 additions & 2 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -12,7 +12,7 @@ class BaseQuantLinear(nn.Module):
     SUPPORTS_DESC_ACT = [True, False]
     SUPPORTS_SYM = [True, False]
     SUPPORTS_SHARDS: bool = True
-    SUPPORTS_DEVICES = [DEVICE.CUDA]
+    SUPPORTS_DEVICES = [] # Empty or None means no device is supported.
     # empty which means all
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = []
     # empty which means all
@@ -101,7 +101,9 @@ def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynami
     @classmethod
     def validate_device(cls, device_type: str):
         device = get_device_by_type(device_type)
-        if cls.SUPPORTS_DEVICES and device not in cls.SUPPORTS_DEVICES:
+        if cls.SUPPORTS_DEVICES is None or len(cls.SUPPORTS_DEVICES) == 0:
+            raise NotImplementedError(f"{cls} does not support any devices, SUPPORTS_DEVICES is `{cls.SUPPORTS_DEVICES}`.")
+        if device not in cls.SUPPORTS_DEVICES:
             raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTS_DEVICES}` bits: actual device = `{device}`")

     # override me
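
For context, a minimal, self-contained sketch of how the stricter validate_device check behaves after this change. DEVICE and get_device_by_type here are simplified stand-ins for the real helpers elsewhere in gptqmodel, and CudaOnlyLinear is a hypothetical subclass used only for illustration:

from enum import Enum

class DEVICE(Enum):  # stand-in for gptqmodel.models._const.DEVICE
    CPU = "cpu"
    XPU = "xpu"
    CUDA = "cuda"

def get_device_by_type(device_type: str) -> DEVICE:
    # stand-in helper: map a device string to the DEVICE enum member
    return DEVICE(device_type)

class BaseQuantLinear:
    # Empty or None now means "no device is supported", not "all devices".
    SUPPORTS_DEVICES = []

    @classmethod
    def validate_device(cls, device_type: str):
        device = get_device_by_type(device_type)
        if cls.SUPPORTS_DEVICES is None or len(cls.SUPPORTS_DEVICES) == 0:
            raise NotImplementedError(
                f"{cls} does not support any devices, SUPPORTS_DEVICES is `{cls.SUPPORTS_DEVICES}`.")
        if device not in cls.SUPPORTS_DEVICES:
            raise NotImplementedError(
                f"{cls} only supports `{cls.SUPPORTS_DEVICES}`: actual device = `{device}`")

class CudaOnlyLinear(BaseQuantLinear):  # hypothetical kernel with explicit device support
    SUPPORTS_DEVICES = [DEVICE.CUDA]

CudaOnlyLinear.validate_device("cuda")      # passes silently
# CudaOnlyLinear.validate_device("cpu")     # raises NotImplementedError: device not supported
# BaseQuantLinear.validate_device("cuda")   # raises NotImplementedError: no devices declared

Every kernel touched by this commit declares its devices explicitly, so the "no devices declared" error should only fire for subclasses that forget to set SUPPORTS_DEVICES.
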
2 changes: 2 additions & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
@@ -12,6 +12,7 @@
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 
 from ...utils.logger import setup_logger
+from ...models._const import DEVICE
 
 logger = setup_logger()
 
@@ -78,6 +79,7 @@ class BitBLASQuantLinear(BaseQuantLinear):
     SUPPORTS_DESC_ACT = [False]
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [16]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [16]
+    SUPPORTS_DEVICES = [DEVICE.CUDA]
 
     OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512]
     zeros_mode = "quantized" # "original" or "rescale" or "quantized"
2 changes: 2 additions & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -2,6 +2,7 @@
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 from gptqmodel.utils.logger import setup_logger
 from gptqmodel.nn_modules.qlinear.qlinear_torch import TorchQuantLinear
+from ...models._const import DEVICE
 
 logger = setup_logger()
 
@@ -21,6 +22,7 @@
 
 class CudaQuantLinear(TorchQuantLinear):
     SUPPORTS_BITS = [2, 3, 4, 8]
+    SUPPORTS_DEVICES = [DEVICE.CUDA]
 
     def __init__(
         self,
2 changes: 2 additions & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -9,6 +9,7 @@
 import torch.nn.functional as F
 import transformers
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+from ...models._const import DEVICE
 
 exllama_import_exception = None
 try:
@@ -43,6 +44,7 @@ class ExllamaQuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [4]
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
+    SUPPORTS_DEVICES = [DEVICE.CUDA]
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
2 changes: 2 additions & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -7,6 +7,7 @@
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 
 from ...utils.logger import setup_logger
+from ...models._const import DEVICE
 
 exllama_v2_import_exception = None
 try:
@@ -105,6 +106,7 @@ class ExllamaV2QuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [4]
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
+    SUPPORTS_DEVICES = [DEVICE.CUDA]
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
2 changes: 2 additions & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -8,6 +8,7 @@
 from torch.nn.parameter import Parameter
 
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+from ...models._const import DEVICE
 
 marlin_import_exception = None
 try:
@@ -143,6 +144,7 @@ class MarlinQuantLinear(BaseQuantLinear):
     SUPPORTS_DESC_ACT = [True, False]
     SUPPORTS_SYM = [True]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64]
+    SUPPORTS_DEVICES = [DEVICE.CUDA]
 
     def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
                  bias: bool, **kwargs):
3 changes: 2 additions & 1 deletion gptqmodel/nn_modules/qlinear/qlinear_torch.py
@@ -8,12 +8,13 @@
 from gptqmodel.models._const import DEVICE
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 from gptqmodel.utils.logger import setup_logger
+from ...models._const import DEVICE
 
 logger = setup_logger()
 
 class TorchQuantLinear(BaseQuantLinear):
     SUPPORTS_BITS = [2, 3, 4, 8]
-    SUPPORTS_DEVICES = [] # empty means all devices are supported.
+    SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.XPU, DEVICE.CUDA]
 
     def __init__(
         self,
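
This file shows the semantic flip most directly: the removed comment read "empty means all devices are supported", and the old base-class guard (if cls.SUPPORTS_DEVICES and device not in cls.SUPPORTS_DEVICES:) skipped validation entirely for an empty list, because an empty list is falsy in Python, so every device was accepted. A small standalone illustration of that behavior, not repository code:

SUPPORTS_DEVICES = []          # old default on the base class

# Old-style guard: never fires for an empty list, so any device was accepted.
if SUPPORTS_DEVICES and "cpu" not in SUPPORTS_DEVICES:
    raise NotImplementedError("unsupported device")
print("empty list: device validation silently skipped")

# After this commit an empty list is an explicit error in validate_device,
# and TorchQuantLinear spells out DEVICE.CPU, DEVICE.XPU, DEVICE.CUDA instead.
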
2 changes: 2 additions & 0 deletions gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -9,6 +9,7 @@
 from ...utils.logger import setup_logger
 from ..triton_utils.mixin import TritonModuleMixin
 from . import BaseQuantLinear
+from ...models._const import DEVICE
 
 try:
     from triton import __version__ as triton_version
@@ -29,6 +30,7 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
     SUPPORTS_BITS = [2, 4, 8]
     SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
     SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]
+    SUPPORTS_DEVICES = [DEVICE.CUDA]
 
     """
     Triton v2 quantized linear layer.
1 change: 1 addition & 0 deletions gptqmodel/utils/backend.py
@@ -12,6 +12,7 @@ class BACKEND(Enum):
     VLLM = 7
     SGLANG = 8
     CUDA = 9
+    TORCH = 10
 
 def get_backend(backend: str):
     try:
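
Finally, a short sketch of the new TORCH member in use. Only the enum members and their values come from the diff; the lookup logic in get_backend below is an assumption for illustration, not the repository's actual implementation:

from enum import Enum

class BACKEND(Enum):
    # earlier members elided; tail of the enum as shown in the diff
    VLLM = 7
    SGLANG = 8
    CUDA = 9
    TORCH = 10  # new in this commit

def get_backend(backend: str) -> BACKEND:
    # assumed behavior: case-insensitive lookup by member name
    try:
        return BACKEND[backend.upper()]
    except KeyError:
        raise ValueError(f"Unknown backend: {backend}")

assert get_backend("torch") is BACKEND.TORCH
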
