add store_lm_head_input_hook()
ZX-ModelCloud committed Jan 9, 2025
1 parent 7db0daf commit 10c97a8
Showing 1 changed file with 25 additions and 8 deletions.
gptqmodel/models/base.py (33 changes: 25 additions & 8 deletions)
@@ -395,7 +395,7 @@ def collate_batch(batch):
raise NotImplementedError(f"This type({type(lm_head_module)}) of lm_head quantization is currently not "
f"supported. SUPPORTS_MODULE_TYPES is {SUPPORTS_MODULE_TYPES}")

lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": False}
lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": True}
if self.quantize_config.dynamic is None:
self.quantize_config.dynamic = {self.lm_head: lm_head_quant_config}
elif self.quantize_config.dynamic_get(self.lm_head, default_value=None) is None:
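The hunk above routes lm_head through the per-module dynamic override table (quantize_config.dynamic), falling back to an 8-bit, group_size 32, symmetric config when no override exists. A minimal standalone sketch of that override-merge idea, using illustrative names rather than GPTQModel's actual API:

# Minimal standalone sketch (not GPTQModel's actual API) of the per-module
# override idea: a "dynamic" table maps module names to partial quantization
# settings that take precedence over the base config.
base_config = {"bits": 4, "group_size": 128, "sym": True}
dynamic = {"lm_head": {"bits": 8, "group_size": 32, "sym": True}}

def effective_config(module_name: str) -> dict:
    # Start from the base settings and overlay any per-module override.
    cfg = dict(base_config)
    cfg.update(dynamic.get(module_name, {}))
    return cfg

print(effective_config("lm_head"))          # {'bits': 8, 'group_size': 32, 'sym': True}
print(effective_config("model.layers.0"))   # falls back to the base config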
@@ -413,6 +413,9 @@ def collate_batch(batch):
num_batches = len(calibration_dataset)
layers = get_module_by_name_prefix(self.model, self.layers_node)

+ # TODO lm_head need move model to same device
+ self.model.to(self.quantize_config.device)

cur_layer_device = get_device(layers[0])
data_device = cur_layer_device if calibration_enable_gpu_cache else CPU
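The new model.to(...) call addresses the device concern in the TODO above: for calibration batches to flow past the transformer blocks and reach lm_head, the whole model and the inputs need to sit on one device. A hedged sketch of that concern, with illustrative names only:

# Hedged sketch of the device concern behind the TODO above: if part of the
# model stays on CPU while the calibration inputs are on GPU, the forward
# fails before it ever reaches lm_head. Illustrative names only.
import torch
import torch.nn as nn

if torch.cuda.is_available():
    blocks = nn.Linear(16, 16).to("cuda")
    lm_head = nn.Linear(16, 100)              # accidentally left on CPU
    x = torch.randn(2, 16, device="cuda")
    try:
        lm_head(blocks(x))                    # device mismatch
    except RuntimeError as err:
        print("mismatch:", err)
    lm_head.to("cuda")                        # align everything on one device
    print(lm_head(blocks(x)).shape)           # torch.Size([2, 100])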

@@ -445,6 +448,22 @@ def store_input_hook(_, args, kwargs):
one_kwargs[k] = nested_move_to(v, data_device)
layer_input_kwargs.append(one_kwargs)

+     if not self.quantize_config.lm_head:
+         raise ValueError
+
+ lm_head_inputs = []
+ def store_lm_head_input_hook(_, args, kwargs):
+     # Positional arguments.
+     lm_head_layer_input = []
+     for inp in args:
+         lm_head_layer_input.append(move_to(inp, data_device))
+     if len(lm_head_layer_input) == 0:
+         # Some models put hidden_states in kwargs instead of args.
+         # For example, gptj ...
+         if kwargs.get("hidden_states") is not None:
+             lm_head_layer_input.append(move_to(kwargs["hidden_states"], data_device))
+
+     lm_head_inputs.append(lm_head_layer_input)
raise ValueError

# move layer to target device
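The new store_lm_head_input_hook follows the same capture pattern as store_input_hook: a forward pre-hook registered with with_kwargs=True records the positional inputs, falls back to kwargs["hidden_states"] for models such as gptj, and then raises ValueError so the caller can abort the rest of the forward. A self-contained sketch of that pattern, with the capture hook attached directly to the final projection and all names illustrative:

# Self-contained sketch of the capture pattern, not GPTQModel's actual code:
# only the PyTorch hook API (register_forward_pre_hook with with_kwargs=True)
# is taken as given; module and variable names are illustrative.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 1000))  # stand-in: blocks + head
head = model[1]

captured_inputs = []

def capture_input_hook(_module, args, kwargs):
    inputs = [a.detach().cpu() for a in args if torch.is_tensor(a)]
    if not inputs and torch.is_tensor(kwargs.get("hidden_states")):
        # Some models pass hidden_states as a keyword argument instead.
        inputs.append(kwargs["hidden_states"].detach().cpu())
    captured_inputs.append(inputs)
    raise ValueError  # stop the rest of the forward; the caller catches this

handle = head.register_forward_pre_hook(capture_input_hook, with_kwargs=True)
try:
    model(torch.randn(4, 32))
except ValueError:
    pass  # expected: the hook aborts after capturing
handle.remove()

print(len(captured_inputs), captured_inputs[0][0].shape)  # 1 torch.Size([4, 32])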
@@ -463,6 +482,8 @@ def store_input_hook(_, args, kwargs):

# TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
+ if self.quantize_config.lm_head:
+     lm_head_handle = layers[0].register_forward_pre_hook(store_lm_head_input_hook, with_kwargs=True)
is_ovis = self.__class__.__name__ == "OvisGPTQ"
for example in calibration_dataset:
for k, v in example.items():
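In this commit both hooks are registered on layers[0], and the conditional raise added above suggests store_input_hook now stops the forward only when lm_head is not being quantized: stop at the first block when lm_head inputs are not needed, otherwise let the forward continue so they can be captured later. A hedged sketch of that control flow, with the second hook placed on the head module itself for illustration:

# Hedged sketch of the conditional early-exit flow suggested by the added
# "if not self.quantize_config.lm_head: raise ValueError". Names are
# illustrative; here the capture hook sits on the head module itself.
import torch
import torch.nn as nn

quantize_lm_head = True  # stand-in for quantize_config.lm_head

blocks = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
lm_head = nn.Linear(8, 100)
model = nn.Sequential(blocks, lm_head)

def first_block_hook(_module, args, kwargs):
    # ... record the first block's inputs here ...
    if not quantize_lm_head:
        raise ValueError  # nothing further is needed, stop early

def lm_head_hook(_module, args, kwargs):
    # ... record lm_head's inputs here ...
    raise ValueError  # inputs captured, skip the projection itself

h1 = blocks[0].register_forward_pre_hook(first_block_hook, with_kwargs=True)
h2 = lm_head.register_forward_pre_hook(lm_head_hook, with_kwargs=True)
try:
    model(torch.randn(2, 8))
except ValueError:
    pass
h1.remove()
h2.remove()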
@@ -483,6 +504,8 @@ def store_input_hook(_, args, kwargs):
except ValueError:
pass
handle.remove()
+ if self.quantize_config.lm_head:
+     lm_head_handle.remove()

move_to(layers[0], CPU)
for module_name in self.base_modules:
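The cleanup above removes each handle right after the calibration loop. A small defensive variant, sketched with illustrative names, wraps the forwards in try/finally so hooks cannot leak if something other than the expected ValueError is raised:

# Defensive variant of the hook cleanup, sketched with illustrative names:
# a try/finally guarantees handles are removed even if the calibration
# forward fails with something other than the expected ValueError.
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
handle = layer.register_forward_pre_hook(lambda m, args, kwargs: None, with_kwargs=True)
try:
    try:
        layer(torch.randn(1, 4))
    except ValueError:
        pass
finally:
    handle.remove()  # never leaks, whatever the forward raised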
@@ -525,9 +548,7 @@ def store_input_hook(_, args, kwargs):
if is_lm_head:
layer_pb.set_description(f"Quantizing lm_head")
layer = get_module(self.model, key=self.lm_head)
- if only_quant_lm_head:
-     layer_inputs = torch.load(lm_head_layer_inputs_path)
-     print("loaded lm_head_layer_inputs.pt", layer_inputs)
+ layer_inputs = lm_head_inputs
else:
layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
layer = layers[i]
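The hunk above swaps the earlier disk round-trip (torch.save / torch.load of lm_head_layer_inputs.pt, removed further down) for the in-memory lm_head_inputs list filled by the new hook. A small standalone sketch of the two approaches; the names mirror the removed code but the snippet is illustrative:

# Sketch of the change above: the disk round-trip through
# lm_head_layer_inputs.pt is replaced by the in-memory list that the new
# hook fills. File name and variables are illustrative.
import os
import tempfile
import torch

captured = [[torch.randn(2, 8)]]  # shape of data the capture hook collects

# Previous approach: persist to disk, reload when quantizing the head.
path = os.path.join(tempfile.mkdtemp(), "lm_head_layer_inputs.pt")
torch.save(captured, path)
reloaded = torch.load(path)

# New approach: hand the captured list over directly.
layer_inputs = captured

assert torch.equal(reloaded[0][0], layer_inputs[0][0])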
@@ -744,10 +765,6 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
[],
) # TODO: is it really OK to cache only the first positional argument?

- if i == layer_count - 1:
-     print("saved lm_head_layer_inputs.pt", layer_inputs)
-     torch.save(layer_inputs, lm_head_layer_inputs_path)

torch_empty_cache()

logger.info(f"Quantization summary:\n{self.quant_log}")
