quantize lm_head

ZX-ModelCloud committed Jan 6, 2025
1 parent b0a674d commit 419218f
Showing 2 changed files with 192 additions and 22 deletions.
133 changes: 111 additions & 22 deletions gptqmodel/models/base.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import json
import os
import shutil
@@ -9,6 +10,7 @@
import torch
import torch.nn as nn
from packaging import version
from torch import autocast
from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils

from ..nn_modules.hooked_linear import replace_linear_with_hooked_linear
@@ -20,10 +22,11 @@
from ..utils.importer import select_quant_linear
from ..utils.logger import setup_logger
from ..utils.model import (MODALITY, check_to_quantized, find_layers, get_device, get_module_by_name_prefix,
get_moe_layer_modules, move_to, nested_move_to, normalize_tokenizer, pack_model)
get_moe_layer_modules, move_to, nested_move_to, normalize_tokenizer, pack_model, get_module,
collect_best_params, pack_module)
from ..utils.progress import ProgressBar
from ..utils.torch import torch_empty_cache
from ._const import CPU, DEVICE
from ._const import CPU, DEVICE, CUDA
from .loader import ModelLoader
from .writer import (QUANT_LOG_DAMP, QUANT_LOG_FWD_TIME, QUANT_LOG_LAYER,
QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter)
@@ -217,8 +220,8 @@ def quantize(
"FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
)

if self.quantize_config.lm_head and not isinstance(self.quantize_config, AutoRoundQuantizeConfig):
raise ValueError("`lm_head=True` quantization is only available with AutoRound quantizer. Please use `AutoRoundQuantizeConfig` instead of `QuantizeConfig` and set `lm_head=True` or set `lm_head=False`.")
# if self.quantize_config.lm_head and not isinstance(self.quantize_config, AutoRoundQuantizeConfig):
# raise ValueError("`lm_head=True` quantization is only available with AutoRound quantizer. Please use `AutoRoundQuantizeConfig` instead of `QuantizeConfig` and set `lm_head=True` or set `lm_head=False`.")

if len(calibration_dataset) == 0:
raise ValueError("Calibration dataset must not be empty.")
@@ -395,6 +398,21 @@ def collate_batch(batch):
cur_layer_device = get_device(layers[0])
data_device = cur_layer_device if calibration_enable_gpu_cache else CPU

cur_layer_device = self.quantize_config.device
data_device = self.quantize_config.device
self.model.to(self.quantize_config.device)

print("self.model", self.model)
lm_head_module = None

# TODO check _tied_weights
if self.quantize_config.lm_head:
lm_head_module = get_module(self.model, key=self.lm_head)
if lm_head_module is None:
raise ValueError(f"could not find layer {self.lm_head} in the model, exit...")
# TODO torch.nn .Linear, transformers.modeling_utils.Conv1D check
print("lm_head_module", lm_head_module.weight)

def store_input_hook(_, args, kwargs):
# Positional arguments.
layer_input = []
@@ -409,7 +427,7 @@ def store_input_hook(_, args, kwargs):
layer_inputs.append(layer_input)

# Keyword arguments.
if kwargs["attention_mask"] is not None:
if kwargs.get("attention_mask") is not None:
attention_masks.append(kwargs["attention_mask"].to(data_device))
else:
attention_masks.append(None)
@@ -422,7 +440,41 @@ def store_input_hook(_, args, kwargs):
if k not in ["hidden_states", "attention_mask", "position_ids"]:
one_kwargs[k] = nested_move_to(v, data_device)
layer_input_kwargs.append(one_kwargs)
raise ValueError

if not self.quantize_config.lm_head:
raise ValueError

lm_head_layer_inputs = []
lm_head_attention_masks = []
lm_head_position_ids = []
lm_head_layer_input_kwargs = []
def store_lm_head_input_hook(_, args, kwargs):
# Positional arguments.
layer_input = []
for inp in args:
layer_input.append(move_to(inp, data_device))
if len(layer_input) == 0:
# Some models put hidden_states in kwargs instead of args.
# For example, gptj ...
if kwargs.get("hidden_states") is not None:
layer_input.append(move_to(kwargs["hidden_states"], data_device))

lm_head_layer_inputs.append(layer_input)

# Keyword arguments.
if kwargs.get("attention_mask") is not None:
lm_head_attention_masks.append(kwargs["attention_mask"].to(data_device))
else:
lm_head_attention_masks.append(None)

pos_ids = kwargs.get("position_ids", None)
if pos_ids is not None:
lm_head_position_ids.append(move_to(pos_ids, data_device))
one_kwargs = {}
for (k, v) in kwargs.items(): # make sure other arguments also be captured
if k not in ["hidden_states", "attention_mask", "position_ids"]:
one_kwargs[k] = nested_move_to(v, data_device)
lm_head_layer_input_kwargs.append(one_kwargs)

# move layer to target device
layers[0] = layers[0].to(self.quantize_config.device)
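
The new store_lm_head_input_hook mirrors store_input_hook but is attached to the lm_head module so its calibration inputs can be captured as well. A minimal, self-contained sketch of the underlying PyTorch mechanism (toy module below, not code from this repository): register_forward_pre_hook(..., with_kwargs=True) lets a hook observe a module's positional and keyword inputs before forward runs.

import torch
import torch.nn as nn

captured_inputs = []

def capture_hook(module, args, kwargs):
    # With with_kwargs=True the hook receives (module, args, kwargs);
    # stash the tensors the head sees during the calibration forward pass.
    captured_inputs.append([t.detach().cpu() for t in args if torch.is_tensor(t)])

lm_head = nn.Linear(16, 100, bias=False)   # toy stand-in for a causal LM head
handle = lm_head.register_forward_pre_hook(capture_hook, with_kwargs=True)

hidden_states = torch.randn(2, 8, 16)      # (batch, seq, hidden)
lm_head(hidden_states)                     # hook fires before forward
handle.remove()                            # detach once inputs are collected

print(len(captured_inputs), captured_inputs[0][0].shape)  # 1 torch.Size([2, 8, 16])
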
@@ -440,6 +492,9 @@ def store_input_hook(_, args, kwargs):

# TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
if self.quantize_config.lm_head:
lm_head_handle = lm_head_module.register_forward_pre_hook(store_lm_head_input_hook, with_kwargs=True)
print("lm_head_handle", lm_head_handle)
is_ovis = self.__class__.__name__ == "OvisGPTQ"
for example in calibration_dataset:
for k, v in example.items():
@@ -460,6 +515,9 @@ def store_input_hook(_, args, kwargs):
except ValueError:
pass
handle.remove()
if self.quantize_config.lm_head:
lm_head_handle.remove()


move_to(layers[0], CPU)
for module_name in self.base_modules:
@@ -483,7 +541,7 @@ def store_input_hook(_, args, kwargs):
quantizers = {}

layer_count = len(layers)
layer_pb = ProgressBar(range(layer_count))
layer_pb = ProgressBar(range(layer_count + 1 if self.quantize_config.lm_head else layer_count))
gpu_memorys = []
cpu_memorys = []
durations = []
@@ -495,8 +553,23 @@ def store_input_hook(_, args, kwargs):
replace_linear_with_hooked_linear(self.model)

for i in layer_pb:
layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
layer = layers[i]
is_lm_head = i >= layer_count


if is_lm_head:
inputs = lm_head_layer_inputs
masks = lm_head_attention_masks
pos_ids = lm_head_position_ids
input_kwargs = lm_head_layer_input_kwargs
layer_pb.set_description(f"Quantizing lm_head")
layer = get_module(self.model, key=self.lm_head)
else:
inputs = layer_inputs
masks = attention_masks
pos_ids = position_ids
input_kwargs = layer_input_kwargs
layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
layer = layers[i]
if layer.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower():
# TODO FIXME: currently we not support quantizing cross attention layer (pixel_values)
continue
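
For reference, the indexing trick above reduces to this pattern: the progress bar gets one extra iteration and the final index is treated as the lm_head (toy values below, not library code).

layer_count = 3                      # stand-in for len(layers)
quantize_lm_head = True

total = layer_count + 1 if quantize_lm_head else layer_count
for i in range(total):
    is_lm_head = i >= layer_count    # the appended iteration maps to the lm_head
    target = "lm_head" if is_lm_head else f"layer {i}"
    print(f"step {i}: quantizing {target}")
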
@@ -574,19 +647,19 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
fwd_start = time.time()
for j in range(num_batches):
layer_input = []
for k, layer_inp in enumerate(layer_inputs[j]):
for k, layer_inp in enumerate(inputs[j]):
layer_input.append(move_to(layer_inp, cur_layer_device))

mask = attention_masks[j]
mask = masks[j]
layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device)

additional_layer_inputs = {"attention_mask": layer_attention_mask}
layer_position_ids = (
None if not position_ids else move_to(position_ids[j], cur_layer_device)
None if not pos_ids else move_to(pos_ids[j], cur_layer_device)
)
if layer_position_ids is not None:
additional_layer_inputs["position_ids"] = layer_position_ids
for k, v in layer_input_kwargs[j].items():
for k, v in input_kwargs[j].items():
additional_layer_inputs[k] = nested_move_to(v, cur_layer_device)

with torch.no_grad():
Expand All @@ -595,11 +668,11 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
if layer.reuse_kv:
additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(i - 1)

layer_output = layer(*layer_input, **additional_layer_inputs)
layer_output = layer(*layer_input) if self.quantize_config.lm_head else layer(*layer_input, **additional_layer_inputs)
if shared_kv_cache_dict.get(i) is None:
shared_kv_cache_dict[i] = layer_output[-1]
else:
layer(*layer_input, **additional_layer_inputs)
layer(*layer_input) if self.quantize_config.lm_head else layer(*layer_input, **additional_layer_inputs)

del layer_input
del additional_layer_inputs
@@ -666,17 +739,17 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):

for j in range(num_batches):
layer_input = []
for k, layer_inp in enumerate(layer_inputs[j]):
for k, layer_inp in enumerate(inputs[j]):
layer_input.append(move_to(layer_inp, cur_layer_device))

mask = attention_masks[j]
mask = masks[j]
layer_attention_mask = mask if mask is None else move_to(mask, cur_layer_device)

additional_layer_inputs = {"attention_mask": layer_attention_mask}
layer_position_ids = None if not position_ids else move_to(position_ids[j], cur_layer_device)
layer_position_ids = None if not pos_ids else move_to(pos_ids[j], cur_layer_device)
if layer_position_ids is not None:
additional_layer_inputs["position_ids"] = layer_position_ids
for k, v in layer_input_kwargs[j].items():
for k, v in input_kwargs[j].items():
additional_layer_inputs[k] = nested_move_to(v, cur_layer_device)

if hasattr(layer, "reuse_kv"):
Expand All @@ -685,16 +758,16 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):

with torch.no_grad():
layer_output = move_to(
layer(*layer_input, **additional_layer_inputs)[0],
layer(*layer_input)[0] if self.quantize_config.lm_head else layer(*layer_input, **additional_layer_inputs)[0],
cur_layer_device if calibration_enable_gpu_cache else CPU,
)
layer_outputs.append([layer_output])

del layer_input
del additional_layer_inputs


layers[i] = move_to(layer, CPU)
if not is_lm_head:
layers[i] = move_to(layer, CPU)
del layer
del gptq
del layer_inputs
@@ -730,6 +803,22 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
parallel_packing=self.quantize_config.parallel_packing,
)

lm_head_module = get_module(self.model, key=self.lm_head)
self.qlinear_kernel = pack_module(
model=lm_head_module,
quantizers=quantizers,
bits=self.quantize_config.bits,
group_size=self.quantize_config.group_size,
backend=backend,
desc_act=self.quantize_config.desc_act,
format=self.quantize_config.format,
dynamic=self.quantize_config.dynamic,
parallel_packing=self.quantize_config.parallel_packing,
)


print("lm_head_module end", lm_head_module.weight)

self.model.config.use_cache = forward_pass_use_cache

self.quantized = True
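
With the AutoRound-only restriction commented out above, lm_head=True can now be set on the regular quantize config and the extra progress-bar step handles the head like one more layer. A hedged usage sketch follows; only QuantizeConfig(lm_head=...) and quantize(calibration_dataset) come from the changed code, while the loader and saver entry points are assumptions about the library's public API and may differ by version.

from gptqmodel import GPTQModel, QuantizeConfig  # top-level import path assumed

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    lm_head=True,   # also quantize the output projection (this commit)
)

# GPTQModel.load / model.save are assumed entry points; adjust to the installed version.
model = GPTQModel.load("facebook/opt-125m", quant_config)
calibration_dataset = ["gptqmodel is an easy-to-use llm quantization library."]  # tiny toy set
model.quantize(calibration_dataset)
model.save("opt-125m-gptq-4bit-lmhead")
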
81 changes: 81 additions & 0 deletions gptqmodel/utils/model.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import functools
import hashlib
import json
@@ -107,6 +108,27 @@ def get_module_by_name_suffix(model, module_name: str):
if name.endswith(module_name):
return module

def get_module(module, key):
"""Get module from model by key name.
Args:
module (torch.nn.Module): original model
key (str): module name to be replaced
"""
name_list = key.split(".")
for name in name_list:
module = getattr(module, name, None)
return module

def collect_best_params(block):
params = {}
for n, m in block.named_modules():
if hasattr(m, "orig_layer"):
params[n] = {}
for key in m.params.keys():
params[n][key] = copy.deepcopy(m.params[key].data)
return params


def make_quant(
module,
@@ -338,6 +360,65 @@ def pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar):
qlayers[name].to(layer_device)
pbar.progress()

def pack_module(
model,
quantizers,
bits,
group_size,
backend: BACKEND,
format: str | FORMAT,
desc_act=False,
sym: bool = True,
dynamic=None,
parallel_packing: bool = True,
):
QuantLinear = select_quant_linear(
bits=bits,
dynamic=dynamic,
group_size=group_size,
desc_act=desc_act,
sym=sym,
backend=backend,
format=format,
pack=True,
)

model.to(CPU)

logger.info("Packing model...")

layers = find_layers(model)
layers = {n: layers[n] for n in quantizers}
make_quant(
model,
quantizers,
bits,
group_size,
backend=backend,
format=format,
desc_act=desc_act,
pack=True,
dynamic=dynamic,
)
qlayers = find_layers(model, [QuantLinear])
names = list(qlayers.keys())

if parallel_packing:
max_workers = 2
else:
max_workers = 1

with ThreadPoolExecutor(max_workers=max_workers) as executor:
with ProgressBar(total=len(names)) as pbar:
def wrapper(name):
pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar)

for _ in executor.map(wrapper, names):
pass

logger.info("Model packed.")
return QuantLinear


def pack_model(
model,
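
The new get_module helper simply walks a dotted attribute path with getattr and returns None as soon as a segment is missing. A self-contained sketch of its behaviour on a toy module (not a real LM):

import torch.nn as nn

from gptqmodel.utils.model import get_module  # helper added in this commit

class ToyDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)

class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyDecoder()
        self.lm_head = nn.Linear(16, 100, bias=False)

toy = ToyLM()
print(get_module(toy, key="lm_head"))        # Linear(in_features=16, out_features=100, bias=False)
print(get_module(toy, key="model.embed"))    # Embedding(100, 16)
print(get_module(toy, key="model.missing"))  # None; getattr's default keeps the walk from raising

In base.py the resolved lm_head module, together with the quantizer state collected during calibration, is then handed to the new pack_module, which selects a QuantLinear kernel, rebuilds the layer via make_quant, and packs the weights with up to two worker threads when parallel_packing is enabled.
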
