Commit 7398420

add SUPPORT_INFEATURES_DIVISIBLE_BY and SUPPORT_OUTFEATURES_DIVISIBLE_BY
ZX committed Oct 7, 2024
1 parent 5d647ef commit 7398420
Showing 8 changed files with 60 additions and 51 deletions.
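
In short, the commit moves per-kernel shape checks into the base class: each quantized-linear backend declares which divisors its kernel needs via the two new class attributes, and BaseQuantLinear._validate rejects incompatible infeatures/outfeatures before any buffers are allocated. An empty list imposes no restriction, because all() over an empty iterable is True. The sketch below illustrates the pattern in isolation; the class names and example shapes are mine, not gptqmodel's.

from typing import List, Optional, Tuple

class BaseQuantLinearSketch:
    # Empty list means "no restriction": all() over an empty iterable is True.
    SUPPORT_INFEATURES_DIVISIBLE_BY: List[int] = []
    SUPPORT_OUTFEATURES_DIVISIBLE_BY: List[int] = []

    @classmethod
    def _validate(cls, infeatures: Optional[int] = None,
                  outfeatures: Optional[int] = None) -> Tuple[bool, Optional[Exception]]:
        if infeatures is not None and not all(infeatures % n == 0 for n in cls.SUPPORT_INFEATURES_DIVISIBLE_BY):
            return False, NotImplementedError(f"`infeatures` must be divisible by {cls.SUPPORT_INFEATURES_DIVISIBLE_BY}")
        if outfeatures is not None and not all(outfeatures % n == 0 for n in cls.SUPPORT_OUTFEATURES_DIVISIBLE_BY):
            return False, NotImplementedError(f"`outfeatures` must be divisible by {cls.SUPPORT_OUTFEATURES_DIVISIBLE_BY}")
        return True, None

class ExllamaLikeSketch(BaseQuantLinearSketch):
    # The exllama kernels in this commit declare multiples of 32.
    SUPPORT_INFEATURES_DIVISIBLE_BY = [32]
    SUPPORT_OUTFEATURES_DIVISIBLE_BY = [32]

print(ExllamaLikeSketch._validate(infeatures=4096, outfeatures=11008))  # (True, None)
print(ExllamaLikeSketch._validate(infeatures=4096, outfeatures=1000))   # (False, NotImplementedError(...))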
57 changes: 35 additions & 22 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -13,77 +13,90 @@ class BaseQuantLinear(nn.Module):
SUPPORTED_SYM = [True, False]
SUPPORTED_SHARDS: bool = True
SUPPORTED_DEVICES = [DEVICE.CUDA]
# empty which means all
SUPPORT_INFEATURES_DIVISIBLE_BY = []
# empty which means all
SUPPORT_OUTFEATURES_DIVISIBLE_BY = []

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, *args, **kwargs):
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, *args, **kwargs):
super().__init__()
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym)
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, infeatures=infeatures,outfeatures=outfeatures)
if err:
raise err

if DEVICE.CUDA in self.SUPPORTED_DEVICES:
check_cuda()

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic=None) -> Tuple[bool, Optional[Exception]]:
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic=None) -> Tuple[
bool, Optional[Exception]]:
validate, err = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic)
return validate, err

@classmethod
def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic=None) -> Tuple[bool, Optional[Exception]]:
validate = True
err = ""
def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic=None, infeatures=None,
outfeatures=None) -> Tuple[bool, Optional[Exception]]:
if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS:
validate = False
err = f"{cls} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`"
elif cls.SUPPORTED_GROUP_SIZE and group_size not in cls.SUPPORTED_GROUP_SIZE:
validate = False
return False, NotImplementedError(err)
if cls.SUPPORTED_GROUP_SIZE and group_size not in cls.SUPPORTED_GROUP_SIZE:
err = f"{cls} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
elif cls.SUPPORTED_SYM and sym not in cls.SUPPORTED_SYM:
validate = False
return False, NotImplementedError(err)
if cls.SUPPORTED_SYM and sym not in cls.SUPPORTED_SYM:
err = f"{cls} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}`"
elif cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
validate = False
return False, NotImplementedError(err)
if cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
elif dynamic is not None:
return False, NotImplementedError(err)
if dynamic is not None:
if cls.SUPPORTED_BITS:
dynamic_bits = {}
for pattern, pattern_dict in dynamic.items():
dynamic_bits[pattern] = pattern_dict.get("bits", bits)
if len(cls.SUPPORTED_BITS) == 1:
validate = False
err = f"{cls} not supported dynamic_bits, only support `{cls.SUPPORTED_BITS}` bits"
return False, NotImplementedError(err)
else:
for layer, bits in dynamic_bits.items():
if bits not in cls.SUPPORTED_BITS:
validate = False
err = f"{cls} only supports `{cls.SUPPORTED_BITS}` bits: actual dynamic_bits = `{bits}` for layer `{layer}`"
break
return False, NotImplementedError(err)
if cls.SUPPORTED_GROUP_SIZE:
dynamic_group_size = {}
for pattern, pattern_dict in dynamic.items():
dynamic_group_size[pattern] = pattern_dict.get("group_size", group_size)
for layer, group_size in dynamic_group_size.items():
if group_size not in cls.SUPPORTED_GROUP_SIZE:
validate = False
err = f"{cls} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}` for layer `{layer}`"
break
return False, NotImplementedError(err)
if cls.SUPPORTED_SYM:
dynamic_sym = {}
for pattern, pattern_dict in dynamic.items():
dynamic_sym[pattern] = pattern_dict.get("sym", sym)
for layer, sym in dynamic_sym.items():
if sym not in cls.SUPPORTED_SYM:
validate = False
err = f"{cls} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}` for layer `{layer}`"
return False, NotImplementedError(err)
if cls.SUPPORTED_DESC_ACT:
dynamic_desc_act = {}
for pattern, pattern_dict in dynamic.items():
dynamic_desc_act[pattern] = pattern_dict.get("desc_act", desc_act)
for layer, desc_act in dynamic_desc_act.items():
if desc_act not in cls.SUPPORTED_DESC_ACT:
validate = False
err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}` for layer `{layer}`"
return validate, NotImplementedError(err) if err else None
return False, NotImplementedError(err)
if infeatures is not None:
validate = all(infeatures % in_fea == 0 for in_fea in cls.SUPPORT_INFEATURES_DIVISIBLE_BY)
if not validate:
err = f"{cls}: `infeatures` must be divisible by {cls.SUPPORT_INFEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)
if outfeatures is not None:
validate = all(outfeatures % out_fea == 0 for out_fea in cls.SUPPORT_OUTFEATURES_DIVISIBLE_BY)
if not validate:
err = f"{cls}: `outfeatures` must be divisible by {cls.SUPPORT_OUTFEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)

return True, None

@classmethod
def validate_device(cls, device_type: str):
6 changes: 3 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
@@ -76,6 +76,8 @@ class BitBLASQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [1, 2, 4]
SUPPORTED_DESC_ACT = [False]
SUPPORTED_SHARDS = True
SUPPORT_INFEATURES_DIVISIBLE_BY = [16]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [16]

OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512]
zeros_mode = "quantized" # "original" or "rescale" or "quantized"
@@ -105,7 +107,7 @@ def __init__(
layout: str = "nt",
**kwargs,
):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures,outfeatures=outfeatures,**kwargs)

import_bitblas()

@@ -126,8 +128,6 @@ def __init__(
def _validate_parameters(
self, group_size: int, infeatures: int, outfeatures: int
):
if infeatures % 16 != 0 or outfeatures % 16 != 0:
raise ValueError("`infeatures` and `outfeatures` must be divisible by 16.")
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")

14 changes: 6 additions & 8 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -36,29 +36,27 @@ def ext_q4_matmul(x, q4, q4_width):

class ExllamaQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]

SUPPORT_INFEATURES_DIVISIBLE_BY = [32]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [32]

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures

# auto pad
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=self.infeatures, outfeatures=self.outfeatures, **kwargs)

self.bits = bits

# backup original values
self.original_outfeatures = outfeatures
self.original_infeatures = infeatures

self.maxq = 2**self.bits - 1

assert self.infeatures % 32 == 0
assert self.outfeatures % 32 == 0

self.register_buffer(
"qweight",
torch.zeros((self.original_infeatures // 32 * self.bits, self.original_outfeatures), dtype=torch.int32),
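
The exllama constructor keeps its auto-padding but reorders the work so that the padded sizes are what the base validator sees, which replaces the two asserts removed above. The expression x + (-x % n) rounds x up to the next multiple of n, so the divisibility check always passes after padding. A tiny illustrative helper, not part of the commit:

def pad_to_multiple(x: int, n: int) -> int:
    # Same arithmetic as `outfeatures + (-outfeatures % 32)` in the constructor above.
    return x + (-x % n)

assert pad_to_multiple(4096, 32) == 4096  # already aligned, unchanged
assert pad_to_multiple(4100, 32) == 4128  # rounded up to the next multiple of 32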
17 changes: 8 additions & 9 deletions gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -97,31 +97,30 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):

class ExllamaV2QuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]
SUPPORT_INFEATURES_DIVISIBLE_BY = [32]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [32]

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs,):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
self.group_size = group_size if group_size != -1 else infeatures
# auto pad
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=self.infeatures, outfeatures=self.outfeatures, **kwargs)

self.q_handle = None
self.q_tensors = None

self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures

# auto pad
self.outfeatures = outfeatures + (-outfeatures % 32)
self.infeatures = infeatures + (-infeatures % self.group_size)

# backup original values
self.original_outfeatures = outfeatures
self.original_infeatures = infeatures
self.maxq = 2**self.bits - 1

assert self.infeatures % 32 == 0
assert self.outfeatures % 32 == 0

# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
self.register_buffer(
"qweight",
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -68,7 +68,7 @@ class MarlinQuantLinear(BaseQuantLinear):

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs)
if not torch.cuda.get_device_capability()[0] >= 8:
raise ValueError(
f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `backend=Backend.MARLIN`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).'
6 changes: 2 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin_inference.py
@@ -134,13 +134,11 @@ class MarlinInferenceQuantLinear(BaseQuantLinear):
SUPPORTED_GROUP_SIZE = [-1, 32, 64, 128]
SUPPORTED_DESC_ACT = [True, False]
SUPPORTED_SYM = [True]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [64]

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

self.original_infeatures = infeatures
self.original_outfeatures = outfeatures
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs)

self.pack_factor = 32 // bits # packed into int32

2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/qlinear_qbits.py
@@ -58,7 +58,7 @@ def __init__(
**kwargs,
):
self.sym = False
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, **kwargs)

self.infeatures = infeatures
self.outfeatures = outfeatures
7 changes: 4 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -15,6 +15,9 @@

class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
SUPPORTED_BITS = [2, 4, 8]
SUPPORT_INFEATURES_DIVISIBLE_BY = [32]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [32]

"""
Triton v2 quantized linear layer.
@@ -24,9 +27,7 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
"""

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures, outfeatures, bias, **kwargs,):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
if infeatures % 32 != 0 or outfeatures % 32 != 0:
raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs)
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
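
Taken together, the user-visible effect is that an out-of-spec shape is now rejected by the shared validator with a NotImplementedError instead of by kernel-specific asserts or ValueErrors. A hedged usage sketch against the TritonV2 kernel above; the shapes are invented for illustration, and _validate is normally called for you by __init__:

ok, err = TritonV2QuantLinear._validate(bits=4, group_size=128, desc_act=False, sym=True,
                                        infeatures=1000, outfeatures=4096)
# ok is False and err is a NotImplementedError: 1000 % 32 != 0, while
# TritonV2QuantLinear.SUPPORT_INFEATURES_DIVISIBLE_BY == [32].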
