lm_head uses a special quantize config
ZX-ModelCloud committed Jan 7, 2025
1 parent 7ee016a commit fff2b67
Showing 1 changed file with 13 additions and 5 deletions.
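For context: the commit pins lm_head to its own quantization settings (8-bit, group_size 32, asymmetric) by writing an entry into quantize_config.dynamic, creating the dict when it is absent and overwriting any existing "lm_head" entry otherwise. Below is a minimal, standalone sketch of that injection logic; QuantConfig is a hypothetical stand-in for the real quantize_config object, and only the lm_head values are taken from the diff.

# Hypothetical stand-in for the relevant part of quantize_config; only the
# lm_head values come from the diff, everything else is illustrative.
class QuantConfig:
    def __init__(self, dynamic=None):
        self.dynamic = dynamic  # None, or {module_name: per-module settings}

def apply_lm_head_override(cfg: QuantConfig) -> None:
    lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": False}
    if cfg.dynamic is None:
        cfg.dynamic = {"lm_head": lm_head_quant_config}
    else:
        # Silently replaces a user-supplied "lm_head" entry
        # (the diff carries a TODO to warn about this overwrite).
        cfg.dynamic["lm_head"] = lm_head_quant_config

cfg = QuantConfig(dynamic={"model.layers.0.mlp": {"bits": 4}})
apply_lm_head_override(cfg)
assert cfg.dynamic["lm_head"]["bits"] == 8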
18 changes: 13 additions & 5 deletions gptqmodel/models/base.py
@@ -413,6 +413,14 @@ def collate_batch(batch):
# TODO torch.nn.Linear, transformers.modeling_utils.Conv1D check
print("lm_head_module", lm_head_module.weight)

# TODO warning overwrite dynamic
lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": False}
if self.quantize_config.dynamic is None:
self.quantize_config.dynamic = {"lm_head": lm_head_quant_config}
else:
self.quantize_config.dynamic["lm_head"] = lm_head_quant_config

# TODO HookLinear add register_forward_pre_hook()
def store_input_hook(_, args, kwargs):
# Positional arguments.
layer_input = []
@@ -600,15 +608,15 @@ def store_lm_head_input_hook(_, args, kwargs):
modules = [[self.lm_head]] if is_lm_head else layer_modules
for names in modules:
subset = {n: full[n] for n in names if n in full}
print("subset",subset)

skipped_modules = []
gptq = {}
for name in subset:
bits = self.quantize_config.bits
sym = self.quantize_config.sym
mse = self.quantize_config.mse
if self.quantize_config.dynamic is not None:
layer_name = f"{self.layers_node}.{i}.{name}"
layer_name = self.lm_head if is_lm_head else f"{self.layers_node}.{i}.{name}"

if self.quantize_config.dynamic_get(layer_name=layer_name) == False: # noqa: E712
logger.info(f"skip module: {layer_name}")
@@ -625,7 +633,7 @@ def store_lm_head_input_hook(_, args, kwargs):
sym=sym,
mse=mse,
)
print("gptq[name]",gptq[name])

for name in skipped_modules:
subset.pop(name)

@@ -690,12 +698,12 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
subset[name].forward_hook = None

for name_index, name in enumerate(subset):
layer_name = self.lm_head if is_lm_head else f"{self.layers_node}.{i}.{name}"
layer_pb.set_description(f"Quantizing {name} in layer {i} of {layer_count - 1}")

group_size = self.quantize_config.group_size
desc_act = self.quantize_config.desc_act
if self.quantize_config.dynamic is not None:
layer_name = f"{self.layers_node}.{i}.{name}"
group_size = self.quantize_config.dynamic_get(layer_name, "group_size", group_size)
desc_act = self.quantize_config.dynamic_get(layer_name, "desc_act", desc_act)

@@ -731,7 +739,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
self.quant_log.append(stat)
logger.info(stat)

quantizers[self.lm_head if is_lm_head else f"{self.layers_node}.{i}.{name}"] = (
quantizers[layer_name] = (
gptq[name].quantizer.to(CPU),
move_to(scale, CPU),
move_to(zero, CPU),
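The quantize loop above consults quantize_config.dynamic through dynamic_get: an entry of False skips the module entirely, otherwise per-key overrides (bits, group_size, desc_act, ...) fall back to the base config values. The sketch below only mirrors the call patterns visible in the diff and is not the actual GPTQModel implementation, which may differ (for example by matching layer-name patterns rather than exact names).

# Minimal sketch of a dynamic_get-style lookup, mirroring the calls in the diff.
def dynamic_get(dynamic, layer_name, key=None, default=None):
    if dynamic is None or layer_name not in dynamic:
        return default          # no override: keep the base quantize_config value
    entry = dynamic[layer_name]
    if entry is False:
        return False            # explicit opt-out: the quantize loop skips this module
    if key is None or not isinstance(entry, dict):
        return entry
    return entry.get(key, default)

# Example using the lm_head override added by this commit.
dynamic = {"lm_head": {"bits": 8, "group_size": 32, "sym": False}}
assert dynamic_get(dynamic, "lm_head", "bits", 4) == 8
assert dynamic_get(dynamic, "model.layers.0.self_attn.q_proj", "group_size", 128) == 128
assert dynamic_get(dynamic, "lm_head") is not False   # lm_head is quantized, not skipped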
