From 419218f03869eb5b13382d2c49cc7cd69e6df709 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud
Date: Mon, 6 Jan 2025 11:22:01 +0000
Subject: [PATCH] quantize lm_head

---
 gptqmodel/models/base.py | 133 ++++++++++++++++++++++++++++++++-------
 gptqmodel/utils/model.py |  81 ++++++++++++++++++++++++
 2 files changed, 192 insertions(+), 22 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 105ab262f..db7dc850c 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import json
 import os
 import shutil
@@ -9,6 +10,7 @@
 import torch
 import torch.nn as nn
 from packaging import version
+from torch import autocast
 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils
 
 from ..nn_modules.hooked_linear import replace_linear_with_hooked_linear
@@ -20,10 +22,11 @@
 from ..utils.importer import select_quant_linear
 from ..utils.logger import setup_logger
 from ..utils.model import (MODALITY, check_to_quantized, find_layers, get_device, get_module_by_name_prefix,
-                           get_moe_layer_modules, move_to, nested_move_to, normalize_tokenizer, pack_model)
+                           get_moe_layer_modules, move_to, nested_move_to, normalize_tokenizer, pack_model, get_module,
+                           collect_best_params, pack_module)
 from ..utils.progress import ProgressBar
 from ..utils.torch import torch_empty_cache
-from ._const import CPU, DEVICE
+from ._const import CPU, DEVICE, CUDA
 from .loader import ModelLoader
 from .writer import (QUANT_LOG_DAMP, QUANT_LOG_FWD_TIME, QUANT_LOG_LAYER, QUANT_LOG_LOSS, QUANT_LOG_MODULE,
                      QUANT_LOG_TIME, ModelWriter)
@@ -217,8 +220,8 @@ def quantize(
                 "FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
             )
 
-        if self.quantize_config.lm_head and not isinstance(self.quantize_config, AutoRoundQuantizeConfig):
-            raise ValueError("`lm_head=True` quantization is only available with AutoRound quantizer. Please use `AutoRoundQuantizeConfig` instead of `QuantizeConfig` and set `lm_head=True` or set `lm_head=False`.")
+        # `lm_head=True` is no longer restricted to the AutoRound quantizer;
+        # lm_head quantization is handled natively below.
 
         if len(calibration_dataset) == 0:
             raise ValueError("Calibration dataset must not be empty.")
@@ -395,6 +398,21 @@ def collate_batch(batch):
         cur_layer_device = get_device(layers[0])
         data_device = cur_layer_device if calibration_enable_gpu_cache else CPU
 
+        cur_layer_device = self.quantize_config.device
+        data_device = self.quantize_config.device
+        self.model.to(self.quantize_config.device)
+
+        logger.debug(f"quantize: model moved to {self.quantize_config.device}")
+        lm_head_module = None
+
+        # TODO check _tied_weights
+        if self.quantize_config.lm_head:
+            lm_head_module = get_module(self.model, key=self.lm_head)
+            if lm_head_module is None:
+                raise ValueError(f"could not find lm_head layer `{self.lm_head}` in the model")
+            # TODO: verify the module is a torch.nn.Linear or transformers.modeling_utils.Conv1D
+            logger.debug(f"lm_head weight before quantization: {lm_head_module.weight}")
+
         def store_input_hook(_, args, kwargs):
             # Positional arguments.
             layer_input = []
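For context, the calibration capture used above relies on PyTorch's
`register_forward_pre_hook(..., with_kwargs=True)`. A minimal, self-contained
sketch of that pattern (toy module and names, assuming PyTorch >= 2.0; not code
from this patch):

    import torch
    import torch.nn as nn

    captured = []

    def capture_hook(module, args, kwargs):
        # A pre-hook registered with with_kwargs=True sees both the positional
        # and keyword inputs before forward() runs; stash them to replay later.
        captured.append((args, kwargs))

    lm_head = nn.Linear(16, 32)
    handle = lm_head.register_forward_pre_hook(capture_hook, with_kwargs=True)
    lm_head(torch.randn(1, 16))
    handle.remove()
    assert len(captured) == 1
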
@@ -409,7 +427,7 @@ def store_input_hook(_, args, kwargs):
             layer_inputs.append(layer_input)
 
             # Keyword arguments.
-            if kwargs["attention_mask"] is not None:
+            if kwargs.get("attention_mask") is not None:
                 attention_masks.append(kwargs["attention_mask"].to(data_device))
             else:
                 attention_masks.append(None)
@@ -422,7 +440,41 @@ def store_input_hook(_, args, kwargs):
                 if k not in ["hidden_states", "attention_mask", "position_ids"]:
                     one_kwargs[k] = nested_move_to(v, data_device)
             layer_input_kwargs.append(one_kwargs)
-            raise ValueError
+
+            if not self.quantize_config.lm_head:
+                raise ValueError  # stop the forward pass early; only the first layer's inputs are needed
+
+        lm_head_layer_inputs = []
+        lm_head_attention_masks = []
+        lm_head_position_ids = []
+        lm_head_layer_input_kwargs = []
+        def store_lm_head_input_hook(_, args, kwargs):
+            # Positional arguments.
+            layer_input = []
+            for inp in args:
+                layer_input.append(move_to(inp, data_device))
+            if len(layer_input) == 0:
+                # Some models put hidden_states in kwargs instead of args.
+                # For example, gptj ...
+                if kwargs.get("hidden_states") is not None:
+                    layer_input.append(move_to(kwargs["hidden_states"], data_device))
+
+            lm_head_layer_inputs.append(layer_input)
+
+            # Keyword arguments.
+            if kwargs.get("attention_mask") is not None:
+                lm_head_attention_masks.append(kwargs["attention_mask"].to(data_device))
+            else:
+                lm_head_attention_masks.append(None)
+
+            pos_ids = kwargs.get("position_ids", None)
+            if pos_ids is not None:
+                lm_head_position_ids.append(move_to(pos_ids, data_device))
+            one_kwargs = {}
+            for (k, v) in kwargs.items():  # make sure other arguments are also captured
+                if k not in ["hidden_states", "attention_mask", "position_ids"]:
+                    one_kwargs[k] = nested_move_to(v, data_device)
+            lm_head_layer_input_kwargs.append(one_kwargs)
 
         # move layer to target device
         layers[0] = layers[0].to(self.quantize_config.device)
@@ -440,6 +492,9 @@ def store_input_hook(_, args, kwargs):
         # TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
         handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
 
+        if self.quantize_config.lm_head:
+            lm_head_handle = lm_head_module.register_forward_pre_hook(store_lm_head_input_hook, with_kwargs=True)
+            logger.debug(f"registered lm_head input hook: {lm_head_handle}")
         is_ovis = self.__class__.__name__ == "OvisGPTQ"
         for example in calibration_dataset:
             for k, v in example.items():
@@ -460,6 +515,9 @@ def store_input_hook(_, args, kwargs):
             except ValueError:
                 pass
         handle.remove()
+        if self.quantize_config.lm_head:
+            lm_head_handle.remove()
+
         move_to(layers[0], CPU)
 
         for module_name in self.base_modules:
@@ -483,7 +541,7 @@ def store_input_hook(_, args, kwargs):
         quantizers = {}
 
         layer_count = len(layers)
-        layer_pb = ProgressBar(range(layer_count))
+        layer_pb = ProgressBar(range(layer_count + 1 if self.quantize_config.lm_head else layer_count))
         gpu_memorys = []
         cpu_memorys = []
         durations = []
@@ -495,8 +553,23 @@ def store_input_hook(_, args, kwargs):
             replace_linear_with_hooked_linear(self.model)
 
         for i in layer_pb:
-            layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
-            layer = layers[i]
+            is_lm_head = i >= layer_count
+
+
+            if is_lm_head:
+                inputs = lm_head_layer_inputs
+                masks = lm_head_attention_masks
+                pos_ids = lm_head_position_ids
+                input_kwargs = lm_head_layer_input_kwargs
+                layer_pb.set_description("Quantizing lm_head")
+                layer = get_module(self.model, key=self.lm_head)
+            else:
+                inputs = layer_inputs
+                masks = attention_masks
+                pos_ids = position_ids
+                input_kwargs = layer_input_kwargs
+                layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
+                layer = layers[i]
 
             if layer.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower():
                 # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values)
                 continue
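The loop above treats the lm_head as one extra pseudo-layer appended after the
decoder stack, which is why the progress bar range grows by one and
`is_lm_head = i >= layer_count` selects the lm_head calibration buffers. A
runnable toy of just that indexing (illustrative values, not code from this
patch):

    layer_count = 4          # decoder layers in a toy model
    lm_head_enabled = True   # mirrors quantize_config.lm_head

    for i in range(layer_count + 1 if lm_head_enabled else layer_count):
        is_lm_head = i >= layer_count  # the extra final index is the lm_head
        target = "lm_head" if is_lm_head else f"layers[{i}]"
        print(f"step {i}: quantizing {target}")
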
"MllamaCrossAttentionDecoderLayer".lower(): # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values) continue @@ -574,19 +647,19 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): fwd_start = time.time() for j in range(num_batches): layer_input = [] - for k, layer_inp in enumerate(layer_inputs[j]): + for k, layer_inp in enumerate(inputs[j]): layer_input.append(move_to(layer_inp, cur_layer_device)) - mask = attention_masks[j] + mask = masks[j] layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device) additional_layer_inputs = {"attention_mask": layer_attention_mask} layer_position_ids = ( - None if not position_ids else move_to(position_ids[j], cur_layer_device) + None if not pos_ids else move_to(pos_ids[j], cur_layer_device) ) if layer_position_ids is not None: additional_layer_inputs["position_ids"] = layer_position_ids - for k, v in layer_input_kwargs[j].items(): + for k, v in input_kwargs[j].items(): additional_layer_inputs[k] = nested_move_to(v, cur_layer_device) with torch.no_grad(): @@ -595,11 +668,11 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): if layer.reuse_kv: additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(i - 1) - layer_output = layer(*layer_input, **additional_layer_inputs) + layer_output = layer(*layer_input) if self.quantize_config.lm_head else layer(*layer_input, **additional_layer_inputs) if shared_kv_cache_dict.get(i) is None: shared_kv_cache_dict[i] = layer_output[-1] else: - layer(*layer_input, **additional_layer_inputs) + layer(*layer_input) if self.quantize_config.lm_head else layer(*layer_input, **additional_layer_inputs) del layer_input del additional_layer_inputs @@ -666,17 +739,17 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): for j in range(num_batches): layer_input = [] - for k, layer_inp in enumerate(layer_inputs[j]): + for k, layer_inp in enumerate(inputs[j]): layer_input.append(move_to(layer_inp, cur_layer_device)) - mask = attention_masks[j] + mask = masks[j] layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device) additional_layer_inputs = {"attention_mask": layer_attention_mask} - layer_position_ids = None if not position_ids else move_to(position_ids[j], cur_layer_device) + layer_position_ids = None if not pos_ids else move_to(pos_ids[j], cur_layer_device) if layer_position_ids is not None: additional_layer_inputs["position_ids"] = layer_position_ids - for k, v in layer_input_kwargs[j].items(): + for k, v in input_kwargs[j].items(): additional_layer_inputs[k] = nested_move_to(v, cur_layer_device) if hasattr(layer, "reuse_kv"): @@ -685,7 +758,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): with torch.no_grad(): layer_output = move_to( - layer(*layer_input, **additional_layer_inputs)[0], + layer(*layer_input)[0] if self.quantize_config.lm_head else layer(*layer_input, **additional_layer_inputs)[0], cur_layer_device if calibration_enable_gpu_cache else CPU, ) layer_outputs.append([layer_output]) @@ -693,8 +766,8 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): del layer_input del additional_layer_inputs - - layers[i] = move_to(layer, CPU) + if not is_lm_head: + layers[i] = move_to(layer, CPU) del layer del gptq del layer_inputs @@ -730,6 +803,22 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): parallel_packing=self.quantize_config.parallel_packing, ) + lm_head_module = get_module(self.model, key=self.lm_head) + self.qlinear_kernel = pack_module( + 
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 3fe895fe1..a1eacfd6f 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import copy
 import functools
 import hashlib
 import json
@@ -107,6 +108,27 @@ def get_module_by_name_suffix(model, module_name: str):
         if name.endswith(module_name):
             return module
 
+def get_module(module, key):
+    """Fetch a submodule from a model by its dotted name.
+
+    Args:
+        module (torch.nn.Module): the root module to search
+        key (str): dotted name of the submodule to fetch, e.g. "model.lm_head"
+    """
+    name_list = key.split(".")
+    for name in name_list:
+        module = getattr(module, name, None)
+    return module
+
+def collect_best_params(block):
+    params = {}
+    for n, m in block.named_modules():
+        if hasattr(m, "orig_layer"):
+            params[n] = {}
+            for key in m.params.keys():
+                params[n][key] = copy.deepcopy(m.params[key].data)
+    return params
+
 
 def make_quant(
     module,
@@ -338,6 +360,65 @@ def pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar):
         qlayers[name].to(layer_device)
         pbar.progress()
 
+def pack_module(
+    model,
+    quantizers,
+    bits,
+    group_size,
+    backend: BACKEND,
+    format: str | FORMAT,
+    desc_act=False,
+    sym: bool = True,
+    dynamic=None,
+    parallel_packing: bool = True,
+):
+    QuantLinear = select_quant_linear(
+        bits=bits,
+        dynamic=dynamic,
+        group_size=group_size,
+        desc_act=desc_act,
+        sym=sym,
+        backend=backend,
+        format=format,
+        pack=True,
+    )
+
+    model.to(CPU)
+
+    logger.info("Packing module...")
+
+    layers = find_layers(model)
+    layers = {n: layers[n] for n in quantizers}
+    make_quant(
+        model,
+        quantizers,
+        bits,
+        group_size,
+        backend=backend,
+        format=format,
+        desc_act=desc_act,
+        pack=True,
+        dynamic=dynamic,
+    )
+    qlayers = find_layers(model, [QuantLinear])
+    names = list(qlayers.keys())
+
+    if parallel_packing:
+        max_workers = 2
+    else:
+        max_workers = 1
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        with ProgressBar(total=len(names)) as pbar:
+            def wrapper(name):
+                pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar)
+
+            for _ in executor.map(wrapper, names):
+                pass
+
+    logger.info("Module packed.")
+    return QuantLinear
+
 
 def pack_model(
     model,