diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py
index a23163542..411240b31 100644
--- a/gptqmodel/models/_const.py
+++ b/gptqmodel/models/_const.py
@@ -1,8 +1,13 @@
 from torch import device
 
 CPU = device("cpu")
+CUDA = device("cuda")
 CUDA_0 = device("cuda:0")
 
+
+DEVICE_TYPE_CPU = "cpu"
+DEVICE_TYPE_CUDA = "cuda"
+
 SUPPORTED_MODELS = [
     "bloom",
     "gptj",
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index e61354abe..e8d8ceaba 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -33,7 +33,7 @@
                      gptqmodel_post_init, make_quant, move_to, nested_move_to, pack_model, simple_dispatch_model,
                      verify_model_hash, verify_sharded_model_hashes)
 from ..version import __version__
-from ._const import CPU, CUDA_0, SUPPORTED_MODELS
+from ._const import CPU, CUDA_0, DEVICE_TYPE_CUDA, SUPPORTED_MODELS
 
 logger = logging.getLogger(__name__)
 handler = logging.StreamHandler()
@@ -754,7 +754,7 @@ def from_quantized(
         check_cuda()
 
         if backend == Backend.QBITS:
-            device = torch.device("cpu")
+            device = CPU
             try:
                 pass
             except Exception as e:
@@ -961,7 +961,7 @@ def skip(*args, **kwargs):
         if device is not None:
             device = torch.device(device)
             if not max_memory and not device_map:
-                device_map = {"": device.index if device.type == "cuda" else device.type}
+                device_map = {"": device.index if device.type == DEVICE_TYPE_CUDA else device.type}
         if not isinstance(device_map, dict) and device_map != "sequential":
             max_memory = accelerate.utils.get_balanced_memory(
                 model=model,
diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py
index 50337f8ee..7854e3b21 100644
--- a/gptqmodel/nn_modules/qlinear/__init__.py
+++ b/gptqmodel/nn_modules/qlinear/__init__.py
@@ -1,17 +1,30 @@
 import torch.nn as nn
 
+from ...models._const import DEVICE_TYPE_CUDA
 
 class BaseQuantLinear(nn.Module):
     SUPPORTED_BITS = []
     SUPPORTED_GROUP_SIZE = []
     SUPPORTED_DESC_ACT = [True, False]
     SUPPORTED_SYM = [True, False]
     SUPPORTED_SHARDS: bool = True
+    SUPPORTED_DEVICES = [DEVICE_TYPE_CUDA]
+
+    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, *args, **kwargs):
+        super().__init__()
+        _, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym)
+        if err:
+            raise NotImplementedError(err)
 
     @classmethod
-    def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_error: bool = True) -> bool:
+    def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool) -> bool:
+        validate, _ = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym)
+        return validate
+
+    @classmethod
+    def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool):
         validate = True
         err = ""
         if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS:
             validate = False
@@ -25,11 +38,12 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_e
         elif cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
             validate = False
             err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
+        return validate, err
 
-        if not validate and raise_error:
-            raise NotImplementedError(err)
-
-        return validate
+    @classmethod
+    def validate_device(cls, device_type: str):
+        if cls.SUPPORTED_DEVICES and device_type not in cls.SUPPORTED_DEVICES:
+            raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTED_DEVICES}` devices: actual device = `{device_type}`")
 
     # override me
     def post_init(self):
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py b/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
index 1b264fdb5..35ea18c84 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
@@ -92,8 +92,8 @@ def __init__(
         self,
         bits: int,
         group_size: int,
-        sym: bool,
         desc_act: bool,
+        sym: bool,
         infeatures: int,
         outfeatures: int,
         bias: bool,
@@ -104,13 +104,11 @@ def __init__(
         layout: str = "nt",
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
 
         # TODO: remove delayed import after bitblas whl support for 11.7, 11.8, 12.0 are added
         import_bitblas()
 
-        self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
-
         self._validate_parameters(group_size, infeatures, outfeatures)
 
         self.bits = bits
@@ -243,6 +241,7 @@ def reset_parameters(self):
         self.q_params = None
 
     def post_init(self):
+        self.validate_device(self.qweight.device.type)
         # eliminate runtime overhead like exllama state
         param_list = [self.qweight, self.scales, self.zeros]
         if self.bitblas_matmul.config.with_bias:
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
index 26842fed6..469bc867f 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -10,6 +10,7 @@
 import transformers
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
 from gptqmodel_exllama_kernels import make_q4, q4_matmul
+from gptqmodel.models._const import CUDA
 
 logger = getLogger(__name__)
@@ -40,9 +41,8 @@ class ExllamaQuantLinear(BaseQuantLinear):
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
-    def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
-        super().__init__()
-        self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
+    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, device: torch.device, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
+        super().__init__(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, **kwargs)
 
         self.bits = bits
         self.group_size = group_size if group_size != -1 else infeatures
@@ -92,7 +92,7 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat
             self.bias = None
 
     def post_init(self):
-        assert self.qweight.device.type == "cuda"
+        self.validate_device(self.qweight.device.type)
         assert self.qweight.device.index is not None
 
         # resize due to padding after model weights have been loaded
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
index ae9a1b0f6..6e32039f1 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -100,10 +100,9 @@ class ExllamaV2QuantLinear(BaseQuantLinear):
 
     """Linear layer implementation with per-group 4-bit quantization of the weights"""
 
-    def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeatures: int, outfeatures: int,
+    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
                  bias: bool, **kwargs,):
-        super().__init__()
-        self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
+        super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
 
         self.q_handle = None
         self.q_tensors = None
@@ -156,7 +155,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
             self.bias = None
 
     def post_init(self, temp_dq):
-        assert self.qweight.device.type == "cuda"
+        self.validate_device(self.qweight.device.type)
         assert self.qweight.device.index is not None
 
         # resize due to padding after model weights have been loaded
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
index 41d223d25..90dc215a9 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -67,10 +67,9 @@ class MarlinQuantLinear(BaseQuantLinear):
     SUPPORTED_DESC_ACT = [False]
     SUPPORTED_SYM = [True]
 
-    def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeatures: int, outfeatures: int,
+    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
                  bias: bool, **kwargs):
-        super().__init__()
-        self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
+        super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
 
         if not torch.cuda.get_device_capability()[0] >= 8:
             raise ValueError(
@@ -170,6 +169,9 @@ def forward(self, A):
         C = C + self.bias if self.bias is not None else C
         return C
 
+    def post_init(self):
+        self.validate_device(self.B.device.type)
+
 
 # Copied from https://github.com/IST-DASLab/marlin/pull/1
 @torch.no_grad()
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_qbits.py b/gptqmodel/nn_modules/qlinear/qlinear_qbits.py
index bcd6ce74d..4996ab3ca 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_qbits.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_qbits.py
@@ -6,6 +6,7 @@
 import torch.nn as nn
 import transformers
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+from gptqmodel.models._const import DEVICE_TYPE_CPU
 
 logger = getLogger(__name__)
@@ -38,11 +39,13 @@ def convert_dtype_torch2str(dtype):
 
 class QBitsQuantLinear(BaseQuantLinear):
     SUPPORTED_BITS = [4, 8]
+    SUPPORTED_DEVICES = [DEVICE_TYPE_CPU]
 
     def __init__(
         self,
         bits: int,
         group_size: int,
+        desc_act: bool,
         sym: bool,
         infeatures: int,
         outfeatures: int,
@@ -52,12 +55,8 @@ def __init__(
         weight_dtype=torch.bfloat16,
         **kwargs,
     ):
-        super().__init__()
-
-        self.sym = False
-
-        self.validate(bits=bits, group_size=group_size, sym=self.sym, desc_act=False)
+        super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
 
         self.infeatures = infeatures
         self.outfeatures = outfeatures
@@ -102,9 +101,9 @@ def __init__(
         self.trainable = trainable
 
     def post_init(self):
+        self.validate_device(self.qweight.device.type)
         from intel_extension_for_transformers import qbits
 
-        assert self.qweight.device.type == "cpu"
         if self.bias is not None:
             self.bias = self.bias.to(dtype=torch.float32)
diff --git a/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py b/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
index f8ff12ec0..abcdd409d 100644
--- a/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
+++ b/gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -14,6 +14,7 @@
 
 
 class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
+    SUPPORTED_BITS = [2, 4, 8]
     """
     Triton v2 quantized linear layer.
@@ -22,10 +23,8 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
     dequant and matmul into single kernel.add()
     """
 
-    def __init__(self, bits, group_size, infeatures, outfeatures, bias, **kwargs,):
-        super().__init__()
-        if bits not in [2, 4, 8]:
-            raise NotImplementedError("Only 2,4,8 bits are supported.")
+    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures, outfeatures, bias, **kwargs,):
+        super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
         if infeatures % 32 != 0 or outfeatures % 32 != 0:
             raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
         self.infeatures = infeatures
@@ -65,7 +64,7 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, **kwargs,):
            self.bias = None
 
     def post_init(self):
-        pass
+        self.validate_device(self.qweight.device.type)
 
     def pack(self, linear, scales, zeros, g_idx=None):
         W = linear.weight.data.clone()
diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py
index adb40419a..efb597ca6 100644
--- a/gptqmodel/utils/importer.py
+++ b/gptqmodel/utils/importer.py
@@ -43,7 +43,7 @@ def select_quant_linear(
     allow_backends = format_dict[format]
    for k, v in backend_dict.items():
         in_allow_backends = k in allow_backends
-        validate = v.validate(bits, group_size, desc_act, sym, raise_error=False)
+        validate = v.validate(bits, group_size, desc_act, sym)
         check_pack_func = hasattr(v, "pack") if pack else True
         if in_allow_backends and validate and check_pack_func:
             logger.info(f"Auto choose the fastest one based on quant model compatibility: {v}")
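
For reference: after this change, quant-config checks run in BaseQuantLinear.__init__ (which raises NotImplementedError on an unsupported bits/group_size/sym/desc_act combination), validate() is a non-raising probe used by select_quant_linear(), and each kernel's post_init() calls validate_device() instead of a hard-coded device assert. Below is a minimal sketch of how a backend subclass plugs into this flow; ToyQuantLinear and its class attributes are hypothetical, for illustration only, and assume gptqmodel with this patch applied.

# Hypothetical subclass sketch (not part of this patch).
from gptqmodel.models._const import DEVICE_TYPE_CUDA
from gptqmodel.nn_modules.qlinear import BaseQuantLinear


class ToyQuantLinear(BaseQuantLinear):
    SUPPORTED_BITS = [4]                    # advertise what this kernel can handle
    SUPPORTED_DEVICES = [DEVICE_TYPE_CUDA]  # checked later by validate_device()

    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, **kwargs):
        # Base __init__ runs _validate() and raises NotImplementedError on a bad config,
        # replacing the per-backend self.validate(...) calls removed above.
        super().__init__(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, **kwargs)

    def post_init(self):
        # Replaces the old hard-coded `assert self.qweight.device.type == "cuda"` checks;
        # assumes the subclass registered a `qweight` buffer before post_init() runs.
        self.validate_device(self.qweight.device.type)


# select_quant_linear() probes compatibility without raising:
ToyQuantLinear.validate(bits=4, group_size=128, desc_act=False, sym=True)  # True
ToyQuantLinear.validate(bits=3, group_size=128, desc_act=False, sym=True)  # False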