fix lm_head layer forward error with marlin
ZX-ModelCloud committed Jan 9, 2025
1 parent 10c97a8 commit b16a4f0
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 4 deletions.
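In short: each Marlin quant-linear now learns whether it is the model's lm_head (its module name is compared against the lm_head name threaded down from the loader), and if so it makes its input contiguous before calling the fused kernel. Below is a minimal PyTorch sketch of the kind of problem that guard addresses; shapes and variable names are illustrative, not the GPTQModel API:

import torch

hidden = torch.randn(2, 8, 16)   # [batch, seq, hidden] activations
last = hidden[:, -1, :]          # a sliced view, e.g. the last-token hidden state
print(last.is_contiguous())      # False: the slice shares storage with `hidden`
safe = last.contiguous() if not last.is_contiguous() else last
print(safe.is_contiguous())      # True: safe for a kernel that assumes row-major memory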
gptqmodel/models/base.py: 4 changes (3 additions & 1 deletion)
@@ -302,7 +302,7 @@ def quantize(
raise ValueError(f"AutoRound version must be >= 0.3.0: actual = {auto_round_version}")

if self.quantize_config.lm_head:
- self.quantize_config.layer_config['lm_head'] = {"data_type": "int"}
+ self.quantize_config.layer_config[self.lm_head] = {"data_type": "int"}

import torch.nn.functional as F
from torch.utils.data import DataLoader
@@ -372,6 +372,7 @@ def collate_batch(batch):
backend=backend,
desc_act=self.quantize_config.desc_act,
format=self.quantize_config.format,
+ lm_head_name=self.lm_head,
parallel_packing=self.quantize_config.parallel_packing,
)

@@ -789,6 +790,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
backend=backend,
desc_act=self.quantize_config.desc_act,
format=self.quantize_config.format,
+ lm_head_name=self.lm_head,
dynamic=self.quantize_config.dynamic,
parallel_packing=self.quantize_config.parallel_packing,
)
gptqmodel/models/loader.py: 1 change (1 addition & 0 deletions)
@@ -406,6 +406,7 @@ def skip(*args, **kwargs):
quantize_config.group_size,
backend=backend.AUTO if (backend == BACKEND.MARLIN and quantize_config.format == FORMAT.MARLIN) or backend == BACKEND.BITBLAS else backend,
format=quantize_config.format,
+ lm_head_name=cls.lm_head,
desc_act=quantize_config.desc_act,
dynamic=quantize_config.dynamic,
device=device,
gptqmodel/models/writer.py: 1 change (1 addition & 0 deletions)
@@ -346,6 +346,7 @@ def skip(*args, **kwargs):
quantize_config.group_size,
backend=BACKEND.AUTO,
format=quantize_config.format,
+ lm_head_name=cls.lm_head,
desc_act=quantize_config.desc_act,
pack=True,
)
gptqmodel/nn_modules/qlinear/marlin.py: 6 changes (5 additions & 1 deletion)
@@ -284,6 +284,10 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
else:
self.bias = None

+ self.is_lm_head = False
+ if kwargs.get("name") is not None and kwargs.get("lm_head_name") is not None:
+     self.is_lm_head = kwargs["name"] == kwargs["lm_head_name"]
+
@classmethod
def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
if marlin_import_exception is not None:
@@ -330,7 +334,7 @@ def forward(self, A: torch.Tensor):
A = A.half()

return apply_gptq_marlin_linear(
- input=A,
+ input=A.contiguous() if self.is_lm_head else A,
weight=self.qweight,
weight_scale=self.scales,
weight_zp=self.zp,
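Taken together, the two marlin.py hunks reduce to the guard below; this is a simplified sketch, not the real MarlinQuantLinear class or its full argument list:

class MarlinLinearSketch:
    def __init__(self, **kwargs):
        # the layer only treats itself as the lm_head when the caller passes
        # both its own module name and the model's lm_head name, and they match
        self.is_lm_head = (
            kwargs.get("name") is not None
            and kwargs.get("lm_head_name") is not None
            and kwargs["name"] == kwargs["lm_head_name"]
        )

    def forward(self, A):
        # only the lm_head input is forced contiguous; every other layer keeps the old path
        A = A.contiguous() if self.is_lm_head else A
        return A  # the real forward hands A to apply_gptq_marlin_linear(input=A, ...)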
gptqmodel/utils/model.py: 11 changes (9 additions & 2 deletions)
@@ -129,6 +129,7 @@ def make_quant(
group_size: int,
backend: BACKEND,
format: str | FORMAT,
+ lm_head_name: str,
desc_act: bool = False,
sym: bool = True,
pack: bool = False,
@@ -159,7 +160,8 @@
if linear is not QuantLinear:
logger.info(f"Use {QuantLinear} failed, try to use {linear} instead.")

- result = create_quant_layer(linear, bits, desc_act, dynamic, group_size, module, names, sym, device)
+ result = create_quant_layer(linear, bits, desc_act, dynamic, group_size, module, names, sym, device
+     , lm_head_name)
return result
except NotImplementedError as e:
# only fallback to other quant linears when backend is auto.
@@ -169,7 +171,8 @@
raise ValueError("no support quant linear was found for this module.")


- def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module, names, sym, device) -> BaseQuantLinear:
+ def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module, names, sym, device, lm_head_name: str
+     ) -> BaseQuantLinear:
if isinstance(module, QuantLinear):
return QuantLinear
named_modules = module.named_modules()
@@ -225,6 +228,8 @@ def create_quant_layer(QuantLinear, bits, desc_act, dynamic, group_size, module,
outfeatures=out_features,
bias=bias,
weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype,
+ name=name,
+ lm_head_name=lm_head_name,
)
new_layer.device = ori_layer_device
recurse_setattr(module, name, new_layer.to(ori_layer_device))
@@ -362,6 +367,7 @@ def pack_model(
group_size,
backend: BACKEND,
format: str | FORMAT,
+ lm_head_name: str,
desc_act=False,
sym: bool = True,
dynamic=None,
@@ -391,6 +397,7 @@
group_size,
backend=backend,
format=format,
+ lm_head_name=lm_head_name,
desc_act=desc_act,
pack=True,
dynamic=dynamic,
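The utils/model.py changes are the plumbing that makes the marlin.py check possible: lm_head_name travels from make_quant() and pack_model() through create_quant_layer() into each quant-linear's kwargs, next to the layer's own module name. A small self-contained sketch of that flow, using dummy stand-ins rather than the real GPTQModel classes:

import torch.nn as nn

class DummyQuantLinear(nn.Linear):
    def __init__(self, in_features, out_features, name=None, lm_head_name=None):
        super().__init__(in_features, out_features, bias=False)
        # mirrors the kwargs comparison added in marlin.py
        self.is_lm_head = name is not None and name == lm_head_name

    def forward(self, x):
        # only the lm_head path forces a contiguous input
        return super().forward(x.contiguous() if self.is_lm_head else x)

def make_quant(model, names, lm_head_name):
    # stand-in for make_quant() -> create_quant_layer(): every selected Linear is
    # replaced by a quant layer that receives its own name plus the lm_head name
    for name, mod in list(model.named_modules()):
        if name in names and isinstance(mod, nn.Linear):
            new_layer = DummyQuantLinear(mod.in_features, mod.out_features,
                                         name=name, lm_head_name=lm_head_name)
            setattr(model, name, new_layer)

model = nn.Sequential()
model.add_module("proj", nn.Linear(16, 16, bias=False))
model.add_module("lm_head", nn.Linear(16, 32, bias=False))
make_quant(model, names=["proj", "lm_head"], lm_head_name="lm_head")
print(model.proj.is_lm_head, model.lm_head.is_lm_head)  # False True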
