
Commit 31ef797

Check if lm_head quantization is supported

ZX-ModelCloud committed Jan 7, 2025
1 parent 691e9bb

Showing 1 changed file with 23 additions and 18 deletions.
gptqmodel/models/base.py (23 additions, 18 deletions)
@@ -25,7 +25,7 @@
                         get_moe_layer_modules, move_to, nested_move_to, normalize_tokenizer, pack_model, get_module)
 from ..utils.progress import ProgressBar
 from ..utils.torch import torch_empty_cache
-from ._const import CPU, DEVICE, CUDA
+from ._const import CPU, DEVICE, CUDA, SUPPORTS_MODULE_TYPES
 from .loader import ModelLoader
 from .writer import (QUANT_LOG_DAMP, QUANT_LOG_FWD_TIME, QUANT_LOG_LAYER,
                      QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter)
@@ -382,6 +382,28 @@ def collate_batch(batch):
             self.quantized = True
             return
 
+        if self.quantize_config.lm_head:
+            if self.model.config.tie_word_embeddings and hasattr(self.model, "_tied_weights_keys"):
+                tied_keys = self.model._tied_weights_keys
+                for item in tied_keys:
+                    if self.lm_head in item:
+                        raise NotImplementedError("quantizing lm_head with tied weights is not "
+                                                  "currently supported")
+
+            lm_head_module = get_module(self.model, key=self.lm_head)
+            if lm_head_module is None:
+                raise ValueError(f"could not find layer {self.lm_head} in the model, exit...")
+
+            if not isinstance(lm_head_module, tuple(SUPPORTS_MODULE_TYPES)):
+                raise NotImplementedError(f"This type({type(lm_head_module)}) of lm_head quantization is currently not "
+                                          f"supported. SUPPORTS_MODULE_TYPES is {SUPPORTS_MODULE_TYPES}")
+
+            # default lm_head quant config: applied only when the user has not
+            # supplied their own dynamic override for this module
+            lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": False}
+            if self.quantize_config.dynamic is None:
+                self.quantize_config.dynamic = {self.lm_head: lm_head_quant_config}
+            elif self.quantize_config.dynamic_get(self.lm_head, default_value=None) is None:
+                self.quantize_config.dynamic[self.lm_head] = lm_head_quant_config
+
         forward_pass_use_cache = self.model.config.use_cache if hasattr(self.model.config, "use_cache") else False
         self.model.config.use_cache = False
 
@@ -397,23 +419,6 @@ def collate_batch(batch):
         cur_layer_device = get_device(layers[0])
         data_device = cur_layer_device if calibration_enable_gpu_cache else CPU
 
-        print("self.model", self.model)
-
-        # TODO check _tied_weights
-        if self.quantize_config.lm_head:
-            lm_head_module = get_module(self.model, key=self.lm_head)
-            if lm_head_module is None:
-                raise ValueError(f"could not find layer {self.lm_head} in the model, exit...")
-            # TODO torch.nn.Linear, transformers.modeling_utils.Conv1D check
-            print("lm_head_module", lm_head_module.weight)
-
-            # TODO warning overwrite dynamic
-            lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": False}
-            if self.quantize_config.dynamic is None:
-                self.quantize_config.dynamic = {"lm_head": lm_head_quant_config}
-            else:
-                self.quantize_config.dynamic["lm_head"] = lm_head_quant_config
-
         # TODO HookLinear add register_forward_pre_hook()
         def store_input_hook(_, args, kwargs):
             # Positional arguments.
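
For context, a minimal sketch of how the new guard is exercised from the public API. It assumes the QuantizeConfig fields referenced in the diff (lm_head, dynamic) are the user-facing configuration surface and that GPTQModel.load is the entry point; the model id and calibration text are placeholders, and exact argument names may differ by version.

from gptqmodel import GPTQModel, QuantizeConfig

# Toy calibration data; real runs need a representative dataset.
calibration_dataset = ["gptqmodel quantizes the lm_head when requested"]

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    lm_head=True,  # routes through the new tied-weights / module-type checks
)

model = GPTQModel.load("facebook/opt-125m", quant_config)

# If lm_head has no explicit dynamic override, quantize() now falls back to
# {"bits": 8, "group_size": 32, "sym": False} for that single module.
model.quantize(calibration_dataset)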
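The tied-weights condition itself can be reproduced outside the library with plain transformers calls. This is an illustrative check only; gpt2 is chosen because it ties lm_head.weight to the input embeddings, and the getattr guards cover models where either attribute is absent.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

tied = getattr(model.config, "tie_word_embeddings", False)
tied_keys = getattr(model, "_tied_weights_keys", None) or []

# Mirrors the commit's condition: quantizing lm_head is refused when the
# output projection shares its weight tensor with the embeddings.
if tied and any("lm_head" in key for key in tied_keys):
    print("lm_head is weight-tied; GPTQModel raises NotImplementedError here")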

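The precedence rule for the dynamic table is small enough to state as a standalone sketch: an explicit per-module entry always wins, and the 8-bit default is filled in only when no entry exists. resolve_lm_head_config below is a hypothetical helper written for illustration, not a GPTQModel API; the real code goes through quantize_config.dynamic_get, which may apply pattern matching, whereas a plain dict lookup is used here for brevity.

DEFAULT_LM_HEAD = {"bits": 8, "group_size": 32, "sym": False}

def resolve_lm_head_config(dynamic, lm_head_key="lm_head"):
    # Hypothetical mirror of the diff's merge logic.
    if dynamic is None:
        return {lm_head_key: DEFAULT_LM_HEAD}
    dynamic.setdefault(lm_head_key, DEFAULT_LM_HEAD)
    return dynamic

# No user config: the default is installed.
assert resolve_lm_head_config(None) == {"lm_head": DEFAULT_LM_HEAD}
# A user-supplied entry is left untouched.
assert resolve_lm_head_config({"lm_head": {"bits": 4}})["lm_head"] == {"bits": 4}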