Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ZX committed Oct 7, 2024
1 parent 84908bf commit 2e44278
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 62 deletions.
74 changes: 37 additions & 37 deletions gptqmodel/nn_modules/qlinear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@


class BaseQuantLinear(nn.Module):
SUPPORTED_BITS = []
SUPPORTED_GROUP_SIZE = []
SUPPORTED_DESC_ACT = [True, False]
SUPPORTED_SYM = [True, False]
SUPPORTED_SHARDS: bool = True
SUPPORTED_DEVICES = [DEVICE.CUDA]
SUPPORTS_BITS = []
SUPPORTS_GROUP_SIZE = []
SUPPORTS_DESC_ACT = [True, False]
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS: bool = True
SUPPORTS_DEVICES = [DEVICE.CUDA]
# an empty list means no divisibility constraint: any `infeatures` value is accepted
SUPPORT_INFEATURES_DIVISIBLE_BY = []
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = []
# an empty list means no divisibility constraint: any `outfeatures` value is accepted
SUPPORT_OUTFEATURES_DIVISIBLE_BY = []
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = []

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, *args, **kwargs):
super().__init__()
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, infeatures=infeatures,outfeatures=outfeatures)
if err:
raise err

if DEVICE.CUDA in self.SUPPORTED_DEVICES:
if DEVICE.CUDA in self.SUPPORTS_DEVICES:
check_cuda()

@classmethod
Expand All @@ -36,73 +36,73 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic
@classmethod
def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, dynamic=None, infeatures=None,
outfeatures=None) -> Tuple[bool, Optional[Exception]]:
if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS:
err = f"{cls} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`"
if cls.SUPPORTS_BITS and bits not in cls.SUPPORTS_BITS:
err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual bits = `{bits}`"
return False, NotImplementedError(err)
if cls.SUPPORTED_GROUP_SIZE and group_size not in cls.SUPPORTED_GROUP_SIZE:
err = f"{cls} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
if cls.SUPPORTS_GROUP_SIZE and group_size not in cls.SUPPORTS_GROUP_SIZE:
err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
return False, NotImplementedError(err)
if cls.SUPPORTED_SYM and sym not in cls.SUPPORTED_SYM:
err = f"{cls} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}`"
if cls.SUPPORTS_SYM and sym not in cls.SUPPORTS_SYM:
err = f"{cls} only supports `{cls.SUPPORTS_SYM}` bits: actual sym = `{sym}`"
return False, NotImplementedError(err)
if cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
if cls.SUPPORTS_DESC_ACT and desc_act not in cls.SUPPORTS_DESC_ACT:
err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
return False, NotImplementedError(err)
if dynamic is not None:
if cls.SUPPORTED_BITS:
if cls.SUPPORTS_BITS:
dynamic_bits = {}
for pattern, pattern_dict in dynamic.items():
dynamic_bits[pattern] = pattern_dict.get("bits", bits)
if len(cls.SUPPORTED_BITS) == 1:
err = f"{cls} not supported dynamic_bits, only support `{cls.SUPPORTED_BITS}` bits"
if len(cls.SUPPORTS_BITS) == 1:
err = f"{cls} not supported dynamic_bits, only support `{cls.SUPPORTS_BITS}` bits"
return False, NotImplementedError(err)
else:
for layer, bits in dynamic_bits.items():
if bits not in cls.SUPPORTED_BITS:
err = f"{cls} only supports `{cls.SUPPORTED_BITS}` bits: actual dynamic_bits = `{bits}` for layer `{layer}`"
if bits not in cls.SUPPORTS_BITS:
err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual dynamic_bits = `{bits}` for layer `{layer}`"
return False, NotImplementedError(err)
if cls.SUPPORTED_GROUP_SIZE:
if cls.SUPPORTS_GROUP_SIZE:
dynamic_group_size = {}
for pattern, pattern_dict in dynamic.items():
dynamic_group_size[pattern] = pattern_dict.get("group_size", group_size)
for layer, group_size in dynamic_group_size.items():
if group_size not in cls.SUPPORTED_GROUP_SIZE:
err = f"{cls} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}` for layer `{layer}`"
if group_size not in cls.SUPPORTS_GROUP_SIZE:
err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}` for layer `{layer}`"
return False, NotImplementedError(err)
if cls.SUPPORTED_SYM:
if cls.SUPPORTS_SYM:
dynamic_sym = {}
for pattern, pattern_dict in dynamic.items():
dynamic_sym[pattern] = pattern_dict.get("sym", sym)
for layer, sym in dynamic_sym.items():
if sym not in cls.SUPPORTED_SYM:
err = f"{cls} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}` for layer `{layer}`"
if sym not in cls.SUPPORTS_SYM:
err = f"{cls} only supports `{cls.SUPPORTS_SYM}` bits: actual sym = `{sym}` for layer `{layer}`"
return False, NotImplementedError(err)
if cls.SUPPORTED_DESC_ACT:
if cls.SUPPORTS_DESC_ACT:
dynamic_desc_act = {}
for pattern, pattern_dict in dynamic.items():
dynamic_desc_act[pattern] = pattern_dict.get("desc_act", desc_act)
for layer, desc_act in dynamic_desc_act.items():
if desc_act not in cls.SUPPORTED_DESC_ACT:
err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}` for layer `{layer}`"
if desc_act not in cls.SUPPORTS_DESC_ACT:
err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{desc_act}` for layer `{layer}`"
return False, NotImplementedError(err)
if infeatures is not None:
validate = all(infeatures % in_fea == 0 for in_fea in cls.SUPPORT_INFEATURES_DIVISIBLE_BY)
validate = all(infeatures % in_fea == 0 for in_fea in cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY)
if not validate:
err = f"{cls}: `infeatures` must be divisible by {cls.SUPPORT_INFEATURES_DIVISIBLE_BY}."
err = f"{cls}: `infeatures` must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)
if outfeatures is not None:
validate = all(outfeatures % out_fea == 0 for out_fea in cls.SUPPORT_OUTFEATURES_DIVISIBLE_BY)
validate = all(outfeatures % out_fea == 0 for out_fea in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY)
if not validate:
err = f"{cls}: `outfeatures` must be divisible by {cls.SUPPORT_OUTFEATURES_DIVISIBLE_BY}."
err = f"{cls}: `outfeatures` must be divisible by {cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)

return True, None

@classmethod
def validate_device(cls, device_type: str):
device = get_device_by_type(device_type)
if cls.SUPPORTED_DEVICES and device not in cls.SUPPORTED_DEVICES:
raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTED_DEVICES}` bits: actual device = `{device}`")
if cls.SUPPORTS_DEVICES and device not in cls.SUPPORTS_DEVICES:
raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTS_DEVICES}` bits: actual device = `{device}`")

# override me
def post_init(self):
Expand Down
10 changes: 5 additions & 5 deletions gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ def unpack_qzeros(qzeros, bits):


class BitBLASQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [1, 2, 4]
SUPPORTED_DESC_ACT = [False]
SUPPORTED_SHARDS = True
SUPPORT_INFEATURES_DIVISIBLE_BY = [16]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [16]
SUPPORTS_BITS = [1, 2, 4]
SUPPORTS_DESC_ACT = [False]
SUPPORTS_SHARDS = True
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [16]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [16]

OPT_FEATURES = [1, 16, 32, 64, 128, 256, 512]
zeros_mode = "quantized" # "original" or "rescale" or "quantized"
Expand Down
6 changes: 3 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ def ext_q4_matmul(x, q4, q4_width):


class ExllamaQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]
SUPPORT_INFEATURES_DIVISIBLE_BY = [32]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [32]
SUPPORTS_BITS = [4]
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

Expand Down
6 changes: 3 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):


class ExllamaV2QuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]
SUPPORT_INFEATURES_DIVISIBLE_BY = [32]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [32]
SUPPORTS_BITS = [4]
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

Expand Down
8 changes: 4 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def _get_perms():


class MarlinQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]
SUPPORTED_GROUP_SIZE = [128, -1]
SUPPORTED_DESC_ACT = [False]
SUPPORTED_SYM = [True]
SUPPORTS_BITS = [4]
SUPPORTS_GROUP_SIZE = [128, -1]
SUPPORTS_DESC_ACT = [False]
SUPPORTS_SYM = [True]

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs):
Expand Down
10 changes: 5 additions & 5 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,11 @@ def apply_gptq_marlin_linear(
return output.reshape(out_shape)

class MarlinInferenceQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4, 8]
SUPPORTED_GROUP_SIZE = [-1, 32, 64, 128]
SUPPORTED_DESC_ACT = [True, False]
SUPPORTED_SYM = [True]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [64]
SUPPORTS_BITS = [4, 8]
SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128]
SUPPORTS_DESC_ACT = [True, False]
SUPPORTS_SYM = [True]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [64]

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_qbits.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def convert_dtype_torch2str(dtype):


class QBitsQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [2, 3, 4, 8]
SUPPORTED_DEVICES = [DEVICE.CPU]
SUPPORTS_BITS = [2, 3, 4, 8]
SUPPORTS_DEVICES = [DEVICE.CPU]

def __init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@


class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
SUPPORTED_BITS = [2, 4, 8]
SUPPORT_INFEATURES_DIVISIBLE_BY = [32]
SUPPORT_OUTFEATURES_DIVISIBLE_BY = [32]
SUPPORTS_BITS = [2, 4, 8]
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32]

"""
Triton v2 quantized linear layer.
Expand Down

0 comments on commit 2e44278

Please sign in to comment.