From b16a4f088e4dba3636ed2231960aae2b53f96165 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Thu, 9 Jan 2025 09:20:04 +0000
Subject: [PATCH] fix lm_head layer forward error with marlin

---
 gptqmodel/models/base.py               |  4 +++-
 gptqmodel/models/loader.py             |  1 +
 gptqmodel/models/writer.py             |  1 +
 gptqmodel/nn_modules/qlinear/marlin.py |  6 +++++-
 gptqmodel/utils/model.py               | 11 +++++++++--
 5 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 330861666..b37033eeb 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -302,7 +302,7 @@ def quantize(
             raise ValueError(f"AutoRound version must be >= 0.3.0: actual = {auto_round_version}")
 
         if self.quantize_config.lm_head:
-            self.quantize_config.layer_config['lm_head'] = {"data_type": "int"}
+            self.quantize_config.layer_config[self.lm_head] = {"data_type": "int"}
 
         import torch.nn.functional as F
         from torch.utils.data import DataLoader
@@ -372,6 +372,7 @@ def collate_batch(batch):
             backend=backend,
             desc_act=self.quantize_config.desc_act,
             format=self.quantize_config.format,
+            lm_head_name=self.lm_head,
             parallel_packing=self.quantize_config.parallel_packing,
         )
 
@@ -789,6 +790,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             backend=backend,
             desc_act=self.quantize_config.desc_act,
             format=self.quantize_config.format,
+            lm_head_name=self.lm_head,
             dynamic=self.quantize_config.dynamic,
             parallel_packing=self.quantize_config.parallel_packing,
         )
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index 129048437..3023356ed 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -406,6 +406,7 @@ def skip(*args, **kwargs):
             quantize_config.group_size,
             backend=backend.AUTO if (backend == BACKEND.MARLIN and quantize_config.format == FORMAT.MARLIN) or backend == BACKEND.BITBLAS else backend,
             format=quantize_config.format,
+            lm_head_name=cls.lm_head,
             desc_act=quantize_config.desc_act,
             dynamic=quantize_config.dynamic,
             device=device,
diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 741d2b1c3..95c47383d 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -346,6 +346,7 @@ def skip(*args, **kwargs):
             quantize_config.group_size,
             backend=BACKEND.AUTO,
             format=quantize_config.format,
+            lm_head_name=cls.lm_head,
             desc_act=quantize_config.desc_act,
             pack=True,
         )
diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py
index fadac4133..c9468a2cd 100644
--- a/gptqmodel/nn_modules/qlinear/marlin.py
+++ b/gptqmodel/nn_modules/qlinear/marlin.py
@@ -284,6 +284,10 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
         else:
             self.bias = None
 
+        self.is_lm_head = False
+        if kwargs.get("name") is not None and kwargs.get("lm_head_name") is not None:
+            self.is_lm_head = kwargs["name"] == kwargs["lm_head_name"]
+
     @classmethod
     def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
         if marlin_import_exception is not None:
@@ -330,7 +334,7 @@ def forward(self, A: torch.Tensor):
             A = A.half()
 
         return apply_gptq_marlin_linear(
-            input=A,
+            input=A.contiguous() if self.is_lm_head else A,
             weight=self.qweight,
             weight_scale=self.scales,
             weight_zp=self.zp,
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 08aff251e..ab0610aeb 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -129,6 +129,7 @@ def make_quant(
     group_size: int,
     backend: BACKEND,
     format: str | FORMAT,
+    lm_head_name: str,
     desc_act: bool = False,
     sym: bool = True,
     pack: bool = False,
@@ -159,7 +160,8 @@ def make_quant(
             if linear is not QuantLinear:
                 logger.info(f"Use {QuantLinear} failed, try to use {linear} instead.")
 
-            result = create_quant_layer(linear, bits, desc_act, dynamic, group_size, module, names, sym, device)
+            result = create_quant_layer(linear, bits, desc_act, dynamic, group_size, module, names, sym, device
+                                        , lm_head_name)
             return result
         except NotImplementedError as e:
             # only fallback to other quant linears when backend is auto.
@@ -169,7 +171,8 @@
     raise ValueError("no support quant linear was found for this module.")
 
 
-def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module, names, sym, device) -> BaseQuantLinear:
+def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module, names, sym, device, lm_head_name: str
+                       ) -> BaseQuantLinear:
     if isinstance(module, QuantLinear):
         return QuantLinear
     named_modules = module.named_modules()
@@ -225,6 +228,8 @@ def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module,
                 outfeatures=out_features,
                 bias=bias,
                 weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype,
+                name=name,
+                lm_head_name=lm_head_name,
             )
             new_layer.device = ori_layer_device
             recurse_setattr(module, name, new_layer.to(ori_layer_device))
@@ -362,6 +367,7 @@ def pack_model(
     group_size,
     backend: BACKEND,
     format: str | FORMAT,
+    lm_head_name: str,
     desc_act=False,
     sym: bool = True,
     dynamic=None,
@@ -391,6 +397,7 @@ def pack_model(
         group_size,
         backend=backend,
         format=format,
+        lm_head_name=lm_head_name,
         desc_act=desc_act,
         pack=True,
         dynamic=dynamic,
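
For readers skimming the diff, below is a minimal, self-contained sketch of the pattern the patch introduces. `SketchQuantLinear` and its `nn.Linear` placeholder are illustrative stand-ins only, not GPTQModel's actual `MarlinQuantLinear` or its fused kernel; the sketch just mirrors the two pieces the diff adds: detecting the lm_head layer from the `name`/`lm_head_name` kwargs threaded through `make_quant`/`create_quant_layer`, and forcing that layer's input contiguous before the kernel call.

```python
# Minimal sketch (hypothetical class, not GPTQModel's real module) of the
# lm_head detection + contiguous-input pattern added by this patch.
import torch
import torch.nn as nn


class SketchQuantLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, **kwargs):
        super().__init__()
        # Placeholder for the packed Marlin kernel weights.
        self.linear = nn.Linear(in_features, out_features, bias=False)
        # Same detection logic as the patch: compare this layer's module name
        # against the model's lm_head name passed down from make_quant().
        self.is_lm_head = False
        if kwargs.get("name") is not None and kwargs.get("lm_head_name") is not None:
            self.is_lm_head = kwargs["name"] == kwargs["lm_head_name"]

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        # The lm_head can receive a non-contiguous slice of the hidden states
        # (e.g. last-token selection); fused kernels generally expect contiguous
        # input, so only the lm_head pays for the extra copy.
        A = A.contiguous() if self.is_lm_head else A
        return self.linear(A)


# Usage: only the layer whose name matches lm_head_name opts into contiguous().
head = SketchQuantLinear(16, 32, name="lm_head", lm_head_name="lm_head")
proj = SketchQuantLinear(16, 32, name="model.layers.0.self_attn.q_proj", lm_head_name="lm_head")
x = torch.randn(2, 3, 16)[:, -1, :]  # non-contiguous last-token slice
print(head(x).shape, proj(x).shape)
```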