
Commit

update
ZX-ModelCloud committed Jan 7, 2025
1 parent 419218f commit 0f0048b
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions gptqmodel/models/base.py
@@ -596,9 +596,11 @@ def store_lm_head_input_hook(_, args, kwargs):
     move_to(layer, self.quantize_config.device)
 
     cur_layer_device = get_device(layer)
-    full = find_layers(layer)
-    for names in layer_modules:
+    full = find_layers(layer, name=self.lm_head if is_lm_head else "")
+    modules = [[self.lm_head]] if is_lm_head else layer_modules
+    for names in modules:
         subset = {n: full[n] for n in names if n in full}
+        print("subset",subset)
         skipped_modules = []
         gptq = {}
         for name in subset:
@@ -791,6 +793,9 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
     task.get_logger().report_plotly('avg_loss', 'avg_loss', loss_fig)
     task.get_logger().report_plotly('quant_time', 'quant_time', time_fig)
 
+    if self.quantize_config.lm_head:
+        lm_quantizers = {self.lm_head: quantizers.pop(self.lm_head)}
+
     self.qlinear_kernel = pack_model(
         model=self.model,
         quantizers=quantizers,
@@ -803,21 +808,22 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
         parallel_packing=self.quantize_config.parallel_packing,
     )
 
-    lm_head_module = get_module(self.model, key=self.lm_head)
-    self.qlinear_kernel = pack_module(
-        model=lm_head_module,
-        quantizers=quantizers,
-        bits=self.quantize_config.bits,
-        group_size=self.quantize_config.group_size,
-        backend=backend,
-        desc_act=self.quantize_config.desc_act,
-        format=self.quantize_config.format,
-        dynamic=self.quantize_config.dynamic,
-        parallel_packing=self.quantize_config.parallel_packing,
-    )
-
+    if self.quantize_config.lm_head:
+        lm_head_module = get_module(self.model, key=self.lm_head)
+        self.qlinear_kernel = pack_module(
+            module=lm_head_module,
+            module_name=self.lm_head,
+            quantizers=lm_quantizers,
+            bits=self.quantize_config.bits,
+            group_size=self.quantize_config.group_size,
+            backend=backend,
+            desc_act=self.quantize_config.desc_act,
+            format=self.quantize_config.format,
+            dynamic=self.quantize_config.dynamic,
+            parallel_packing=self.quantize_config.parallel_packing,
+        )
 
-    print("lm_head_module end", lm_head_module.weight)
+        print("lm_head_module end", lm_head_module.weight)
 
     self.model.config.use_cache = forward_pass_use_cache
