diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 4bbe1e358..330861666 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -395,7 +395,7 @@ def collate_batch(batch):
                 raise NotImplementedError(f"This type({type(lm_head_module)}) of lm_head quantization is currently not "
                                           f"supported. SUPPORTS_MODULE_TYPES is {SUPPORTS_MODULE_TYPES}")
 
-            lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": False}
+            lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": True}
             if self.quantize_config.dynamic is None:
                 self.quantize_config.dynamic = {self.lm_head: lm_head_quant_config}
             elif self.quantize_config.dynamic_get(self.lm_head, default_value=None) is None:
@@ -413,6 +413,9 @@ def collate_batch(batch):
         num_batches = len(calibration_dataset)
         layers = get_module_by_name_prefix(self.model, self.layers_node)
 
+        # TODO: lm_head quantization needs the whole model on the same device for the full forward pass
+        self.model.to(self.quantize_config.device)
+
         cur_layer_device = get_device(layers[0])
         data_device = cur_layer_device if calibration_enable_gpu_cache else CPU
 
@@ -445,6 +448,22 @@ def store_input_hook(_, args, kwargs):
                 one_kwargs[k] = nested_move_to(v, data_device)
             layer_input_kwargs.append(one_kwargs)
+            if not self.quantize_config.lm_head:
+                raise ValueError
+
+        lm_head_inputs = []
+        def store_lm_head_input_hook(_, args, kwargs):
+            # Positional arguments.
+            lm_head_layer_input = []
+            for inp in args:
+                lm_head_layer_input.append(move_to(inp, data_device))
+            if len(lm_head_layer_input) == 0:
+                # Some models put hidden_states in kwargs instead of args.
+                # For example, gptj ...
+                if kwargs.get("hidden_states") is not None:
+                    lm_head_layer_input.append(move_to(kwargs["hidden_states"], data_device))
+
+            lm_head_inputs.append(lm_head_layer_input)
 
             raise ValueError
 
         # move layer to target device
@@ -463,6 +482,8 @@ def store_input_hook(_, args, kwargs):
         # TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
         handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
+        if self.quantize_config.lm_head:
+            lm_head_handle = get_module(self.model, key=self.lm_head).register_forward_pre_hook(store_lm_head_input_hook, with_kwargs=True)
         is_ovis = self.__class__.__name__ == "OvisGPTQ"
         for example in calibration_dataset:
             for k, v in example.items():
@@ -483,6 +504,8 @@ def store_input_hook(_, args, kwargs):
             except ValueError:
                 pass
         handle.remove()
+        if self.quantize_config.lm_head:
+            lm_head_handle.remove()
         move_to(layers[0], CPU)
 
         for module_name in self.base_modules:
@@ -525,9 +548,7 @@ def store_input_hook(_, args, kwargs):
             if is_lm_head:
                 layer_pb.set_description(f"Quantizing lm_head")
                 layer = get_module(self.model, key=self.lm_head)
-                if only_quant_lm_head:
-                    layer_inputs = torch.load(lm_head_layer_inputs_path)
-                    print("loaded lm_head_layer_inputs.pt", layer_inputs)
+                layer_inputs = lm_head_inputs
             else:
                 layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
                 layer = layers[i]
@@ -744,10 +765,6 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
                 [],
             )  # TODO: is it really OK to cache only the first positional argument?
-            if i == layer_count - 1:
-                print("saved lm_head_layer_inputs.pt", layer_inputs)
-                torch.save(layer_inputs, lm_head_layer_inputs_path)
-                torch_empty_cache()
 
         logger.info(f"Quantization summary:\n{self.quant_log}")
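
Note for reviewers: the diff replaces the `torch.save`/`torch.load` round-trip with PyTorch's forward pre-hook mechanism to capture the lm_head calibration inputs in memory. Below is a minimal, self-contained sketch of that capture pattern (hook registered with `with_kwargs=True`, inputs recorded, then a bare `ValueError` raised to abort the rest of the forward pass, exactly the control-flow trick `store_input_hook` / `store_lm_head_input_hook` rely on). The toy model and names are illustrative only, not code from this repo.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model: two "layers" followed by an lm_head.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
lm_head = model[2]

captured_inputs = []  # plays the role of lm_head_inputs in the diff

def capture_hook(module, args, kwargs):
    # Positional inputs; some models pass hidden_states via kwargs instead.
    inputs = [a.detach().cpu() for a in args]
    if not inputs and kwargs.get("hidden_states") is not None:
        inputs.append(kwargs["hidden_states"].detach().cpu())
    captured_inputs.append(inputs)
    # Abort the remainder of the forward pass; the caller catches this, so
    # each calibration example only pays for the layers before this module.
    raise ValueError

handle = lm_head.register_forward_pre_hook(capture_hook, with_kwargs=True)
for _ in range(3):  # three "calibration examples"
    try:
        model(torch.randn(1, 16))
    except ValueError:
        pass
handle.remove()

print(len(captured_inputs), captured_inputs[0][0].shape)  # 3 torch.Size([1, 32])
```

Because the hook fires on the head itself, the captured tensors are the hidden states the head actually sees, which is also why the whole model now has to sit on one device for the calibration pass.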
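
For context on how this path gets exercised end to end, a hedged usage sketch, assuming the public `GPTQModel.load` / `QuantizeConfig` API shown in the project README; the model id and one-line calibration set are placeholders.

```python
from gptqmodel import GPTQModel, QuantizeConfig

# lm_head=True opts the output projection into quantization. base.py then
# injects a per-module override into quantize_config.dynamic, roughly:
#   {"lm_head": {"bits": 8, "group_size": 32, "sym": True}}
quant_config = QuantizeConfig(bits=4, group_size=128, lm_head=True)

calibration_dataset = ["gptqmodel is an LLM quantization toolkit."]  # placeholder

model = GPTQModel.load("facebook/opt-125m", quant_config)  # placeholder model id
model.quantize(calibration_dataset)
model.save("opt-125m-4bit")
```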