diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 9b4d85ace..39ac14ac5 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -398,12 +398,7 @@ def collate_batch(batch):
 
         cur_layer_device = get_device(layers[0])
         data_device = cur_layer_device if calibration_enable_gpu_cache else CPU
-        cur_layer_device = self.quantize_config.device
-        data_device = self.quantize_config.device
-        self.model.to(self.quantize_config.device)
-        print("self.model", self.model)
-        lm_head_module = None
 
         # TODO check _tied_weights
         if self.quantize_config.lm_head:
@@ -449,40 +444,7 @@ def store_input_hook(_, args, kwargs):
                     one_kwargs[k] = nested_move_to(v, data_device)
             layer_input_kwargs.append(one_kwargs)
-            if not self.quantize_config.lm_head:
-                raise ValueError
-
-        lm_head_layer_inputs = []
-        lm_head_attention_masks = []
-        lm_head_position_ids = []
-        lm_head_layer_input_kwargs = []
-        def store_lm_head_input_hook(_, args, kwargs):
-            # Positional arguments.
-            layer_input = []
-            for inp in args:
-                layer_input.append(move_to(inp, data_device))
-            if len(layer_input) == 0:
-                # Some models put hidden_states in kwargs instead of args.
-                # For example, gptj ...
-                if kwargs.get("hidden_states") is not None:
-                    layer_input.append(move_to(kwargs["hidden_states"], data_device))
-
-            lm_head_layer_inputs.append(layer_input)
-
-            # Keyword arguments.
-            if kwargs.get("attention_mask") is not None:
-                lm_head_attention_masks.append(kwargs["attention_mask"].to(data_device))
-            else:
-                lm_head_attention_masks.append(None)
-
-            pos_ids = kwargs.get("position_ids", None)
-            if pos_ids is not None:
-                lm_head_position_ids.append(move_to(pos_ids, data_device))
-            one_kwargs = {}
-            for (k, v) in kwargs.items():  # make sure other arguments also be captured
-                if k not in ["hidden_states", "attention_mask", "position_ids"]:
-                    one_kwargs[k] = nested_move_to(v, data_device)
-            lm_head_layer_input_kwargs.append(one_kwargs)
+            raise ValueError
 
         # move layer to target device
         layers[0] = layers[0].to(self.quantize_config.device)
@@ -500,9 +462,6 @@ def store_lm_head_input_hook(_, args, kwargs):
         # TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
         handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
-        if self.quantize_config.lm_head:
-            lm_head_handle = lm_head_module.register_forward_pre_hook(store_lm_head_input_hook, with_kwargs=True)
-            print("lm_head_handle", lm_head_handle)
         is_ovis = self.__class__.__name__ == "OvisGPTQ"
         for example in calibration_dataset:
             for k, v in example.items():
@@ -523,9 +482,6 @@ def store_lm_head_input_hook(_, args, kwargs):
             except ValueError:
                 pass
         handle.remove()
-        if self.quantize_config.lm_head:
-            lm_head_handle.remove()
-
         move_to(layers[0], CPU)
 
         for module_name in self.base_modules:
@@ -562,22 +518,13 @@ def store_lm_head_input_hook(_, args, kwargs):
 
         for i in layer_pb:
             is_lm_head = i >= layer_count
-
-
             if is_lm_head:
-                inputs = lm_head_layer_inputs
-                masks = lm_head_attention_masks
-                pos_ids = lm_head_position_ids
-                input_kwargs = lm_head_layer_input_kwargs
                 layer_pb.set_description(f"Quantizing lm_head")
                 layer = get_module(self.model, key=self.lm_head)
             else:
-                inputs = layer_inputs
-                masks = attention_masks
-                pos_ids = position_ids
-                input_kwargs = layer_input_kwargs
                 layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
                 layer = layers[i]
+
             if layer.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower():
                 # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values)
                 continue
@@ -657,19 +604,19 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
             fwd_start = time.time()
             for j in range(num_batches):
                 layer_input = []
-                for k, layer_inp in enumerate(inputs[j]):
+                for k, layer_inp in enumerate(layer_inputs[j]):
                     layer_input.append(move_to(layer_inp, cur_layer_device))
 
-                mask = masks[j]
+                mask = attention_masks[j]
                 layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device)
 
                 additional_layer_inputs = {"attention_mask": layer_attention_mask}
                 layer_position_ids = (
-                    None if not pos_ids else move_to(pos_ids[j], cur_layer_device)
+                    None if not position_ids else move_to(position_ids[j], cur_layer_device)
                 )
                 if layer_position_ids is not None:
                     additional_layer_inputs["position_ids"] = layer_position_ids
-                for k, v in input_kwargs[j].items():
+                for k, v in layer_input_kwargs[j].items():
                     additional_layer_inputs[k] = nested_move_to(v, cur_layer_device)
 
                 with torch.no_grad():
@@ -749,17 +696,17 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
 
             for j in range(num_batches):
                 layer_input = []
-                for k, layer_inp in enumerate(inputs[j]):
+                for k, layer_inp in enumerate(layer_inputs[j]):
                     layer_input.append(move_to(layer_inp, cur_layer_device))
 
-                mask = masks[j]
+                mask = attention_masks[j]
                 layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device)
 
                 additional_layer_inputs = {"attention_mask": layer_attention_mask}
-                layer_position_ids = None if not pos_ids else move_to(pos_ids[j], cur_layer_device)
+                layer_position_ids = None if not position_ids else move_to(position_ids[j], cur_layer_device)
                 if layer_position_ids is not None:
                     additional_layer_inputs["position_ids"] = layer_position_ids
-                for k, v in input_kwargs[j].items():
+                for k, v in layer_input_kwargs[j].items():
                     additional_layer_inputs[k] = nested_move_to(v, cur_layer_device)
 
                 if hasattr(layer, "reuse_kv"):