Commit

remove store_lm_head_input_hook()
ZX-ModelCloud committed Jan 7, 2025
1 parent fff2b67 commit c3f826f
Showing 1 changed file with 10 additions and 63 deletions.
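Note on the change: store_lm_head_input_hook() was a near-duplicate of store_input_hook, registering a second forward pre-hook on the lm_head module to capture its inputs into separate lm_head_* lists. As the diff below shows, those lists are removed and the lm_head branch now reads the shared layer_inputs, attention_masks, position_ids and layer_input_kwargs instead. Both hooks build on PyTorch's Module.register_forward_pre_hook(..., with_kwargs=True), which hands each forward call's positional and keyword arguments to the hook. A minimal sketch of that capture pattern, using only the public PyTorch API (the module and list names below are illustrative, not taken from gptqmodel):

    # Sketch only: stand-in names, not gptqmodel code.
    import torch
    import torch.nn as nn

    captured_args = []
    captured_kwargs = []

    def capture_inputs(module, args, kwargs):
        # Record the positional and keyword inputs the module receives during a
        # calibration forward pass. Returning None leaves the call unchanged.
        captured_args.append([a.detach() if torch.is_tensor(a) else a for a in args])
        captured_kwargs.append(dict(kwargs))

    layer = nn.Linear(8, 8)  # stand-in for layers[0] or the lm_head module
    handle = layer.register_forward_pre_hook(capture_inputs, with_kwargs=True)

    layer(torch.randn(2, 8))  # one calibration batch
    handle.remove()           # detach the hook once inputs are collected

    print(len(captured_args), len(captured_kwargs))  # -> 1 1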
73 changes: 10 additions & 63 deletions gptqmodel/models/base.py
@@ -398,12 +398,7 @@ def collate_batch(batch):
cur_layer_device = get_device(layers[0])
data_device = cur_layer_device if calibration_enable_gpu_cache else CPU

cur_layer_device = self.quantize_config.device
data_device = self.quantize_config.device
self.model.to(self.quantize_config.device)

print("self.model", self.model)
lm_head_module = None

# TODO check _tied_weights
if self.quantize_config.lm_head:
@@ -449,40 +444,7 @@ def store_input_hook(_, args, kwargs):
one_kwargs[k] = nested_move_to(v, data_device)
layer_input_kwargs.append(one_kwargs)

if not self.quantize_config.lm_head:
raise ValueError

lm_head_layer_inputs = []
lm_head_attention_masks = []
lm_head_position_ids = []
lm_head_layer_input_kwargs = []
def store_lm_head_input_hook(_, args, kwargs):
# Positional arguments.
layer_input = []
for inp in args:
layer_input.append(move_to(inp, data_device))
if len(layer_input) == 0:
# Some models put hidden_states in kwargs instead of args.
# For example, gptj ...
if kwargs.get("hidden_states") is not None:
layer_input.append(move_to(kwargs["hidden_states"], data_device))

lm_head_layer_inputs.append(layer_input)

# Keyword arguments.
if kwargs.get("attention_mask") is not None:
lm_head_attention_masks.append(kwargs["attention_mask"].to(data_device))
else:
lm_head_attention_masks.append(None)

pos_ids = kwargs.get("position_ids", None)
if pos_ids is not None:
lm_head_position_ids.append(move_to(pos_ids, data_device))
one_kwargs = {}
for (k, v) in kwargs.items(): # make sure other arguments also be captured
if k not in ["hidden_states", "attention_mask", "position_ids"]:
one_kwargs[k] = nested_move_to(v, data_device)
lm_head_layer_input_kwargs.append(one_kwargs)
raise ValueError

# move layer to target device
layers[0] = layers[0].to(self.quantize_config.device)
@@ -500,9 +462,6 @@ def store_lm_head_input_hook(_, args, kwargs):

# TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
if self.quantize_config.lm_head:
lm_head_handle = lm_head_module.register_forward_pre_hook(store_lm_head_input_hook, with_kwargs=True)
print("lm_head_handle", lm_head_handle)
is_ovis = self.__class__.__name__ == "OvisGPTQ"
for example in calibration_dataset:
for k, v in example.items():
@@ -523,9 +482,6 @@ def store_lm_head_input_hook(_, args, kwargs):
except ValueError:
pass
handle.remove()
if self.quantize_config.lm_head:
lm_head_handle.remove()


move_to(layers[0], CPU)
for module_name in self.base_modules:
@@ -562,22 +518,13 @@ def store_lm_head_input_hook(_, args, kwargs):

for i in layer_pb:
is_lm_head = i >= layer_count


if is_lm_head:
inputs = lm_head_layer_inputs
masks = lm_head_attention_masks
pos_ids = lm_head_position_ids
input_kwargs = lm_head_layer_input_kwargs
layer_pb.set_description(f"Quantizing lm_head")
layer = get_module(self.model, key=self.lm_head)
else:
inputs = layer_inputs
masks = attention_masks
pos_ids = position_ids
input_kwargs = layer_input_kwargs
layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
layer = layers[i]

if layer.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower():
# TODO FIXME: currently we not support quantizing cross attention layer (pixel_values)
continue
@@ -657,19 +604,19 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
fwd_start = time.time()
for j in range(num_batches):
layer_input = []
for k, layer_inp in enumerate(inputs[j]):
for k, layer_inp in enumerate(layer_inputs[j]):
layer_input.append(move_to(layer_inp, cur_layer_device))

mask = masks[j]
mask = attention_masks[j]
layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device)

additional_layer_inputs = {"attention_mask": layer_attention_mask}
layer_position_ids = (
None if not pos_ids else move_to(pos_ids[j], cur_layer_device)
None if not position_ids else move_to(position_ids[j], cur_layer_device)
)
if layer_position_ids is not None:
additional_layer_inputs["position_ids"] = layer_position_ids
for k, v in input_kwargs[j].items():
for k, v in layer_input_kwargs[j].items():
additional_layer_inputs[k] = nested_move_to(v, cur_layer_device)

with torch.no_grad():
@@ -749,17 +696,17 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):

for j in range(num_batches):
layer_input = []
for k, layer_inp in enumerate(inputs[j]):
for k, layer_inp in enumerate(layer_inputs[j]):
layer_input.append(move_to(layer_inp, cur_layer_device))

mask = masks[j]
mask = attention_masks[j]
layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device)

additional_layer_inputs = {"attention_mask": layer_attention_mask}
layer_position_ids = None if not pos_ids else move_to(pos_ids[j], cur_layer_device)
layer_position_ids = None if not position_ids else move_to(position_ids[j], cur_layer_device)
if layer_position_ids is not None:
additional_layer_inputs["position_ids"] = layer_position_ids
for k, v in input_kwargs[j].items():
for k, v in layer_input_kwargs[j].items():
additional_layer_inputs[k] = nested_move_to(v, cur_layer_device)

if hasattr(layer, "reuse_kv"):
