QuantizeConfig add "lm_head_low_gpu_mem_usage" field
ZX-ModelCloud committed Jan 10, 2025
1 parent e3a1af9 commit cff67f3
Showing 2 changed files with 43 additions and 6 deletions.
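The new QuantizeConfig field gates whether the whole model is moved onto the quantization device while lm_head calibration inputs are captured (False, the default, as wired up in base.py below) or whether the existing lower-memory path is kept (True). A minimal usage sketch, assuming the GPTQModel.load / quantize / save entry points and an illustrative model id and calibration set (none of these names come from this commit):

from gptqmodel import GPTQModel, QuantizeConfig

# Illustrative calibration data; any text/tokenized examples accepted by quantize() would do.
calibration_dataset = ["GPTQModel quantizes large language models post-training."] * 128

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    # Also quantize the lm_head; the new flag (default False) lets the whole model
    # be moved to the quantization device while lm_head inputs are captured.
    lm_head=True,
    lm_head_low_gpu_mem_usage=False,
)

model = GPTQModel.load("facebook/opt-125m", quant_config)  # model id is illustrative
model.quantize(calibration_dataset)
model.save("opt-125m-gptq-4bit")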
45 changes: 40 additions & 5 deletions gptqmodel/models/base.py
@@ -396,7 +396,7 @@ def collate_batch(batch):
raise NotImplementedError(f"This type({type(lm_head_module)}) of lm_head quantization is currently not "
f"supported. SUPPORTS_MODULE_TYPES is {SUPPORTS_MODULE_TYPES}")

lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": False}
lm_head_quant_config = {"bits": 8, "group_size": 32, "sym": True, "desc_act": False, "mse": 2.4}
if self.quantize_config.dynamic is None:
self.quantize_config.dynamic = {self.lm_head: lm_head_quant_config}
elif self.quantize_config.dynamic_get(self.lm_head, default_value=None) is None:
@@ -411,6 +411,9 @@ def collate_batch(batch):
layer_input_kwargs = []
layer_outputs = []

if self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
self.model.to(self.quantize_config.device)

num_batches = len(calibration_dataset)
layers = get_module_by_name_prefix(self.model, self.layers_node)

@@ -446,7 +449,24 @@ def store_input_hook(_, args, kwargs):
one_kwargs[k] = nested_move_to(v, data_device)
layer_input_kwargs.append(one_kwargs)

raise ValueError
if not self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
raise ValueError

lm_head_inputs = []
if self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
def store_lm_head_input_hook(_, args, kwargs):
# Positional arguments.
lm_head_layer_input = []
for inp in args:
lm_head_layer_input.append(move_to(inp, data_device))
if len(lm_head_layer_input) == 0:
# Some models put hidden_states in kwargs instead of args.
# For example, gptj ...
if kwargs.get("hidden_states") is not None:
lm_head_layer_input.append(move_to(kwargs["hidden_states"], data_device))

lm_head_inputs.append(lm_head_layer_input)
raise ValueError

# move layer to target device
layers[0] = layers[0].to(self.quantize_config.device)
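The new store_lm_head_input_hook mirrors the existing store_input_hook pattern: a forward pre-hook registered with with_kwargs=True records the positional inputs (or hidden_states passed as a keyword, as in gptj-style models) and then raises ValueError so the forward pass aborts once the calibration inputs are captured. A standalone sketch of that capture-and-abort pattern, using an illustrative layer and tensors:

import torch

captured_inputs = []

def capture_hook(module, args, kwargs):
    # Record positional tensor inputs; some models pass hidden_states as a kwarg instead.
    inputs = [a for a in args if torch.is_tensor(a)]
    if not inputs and kwargs.get("hidden_states") is not None:
        inputs.append(kwargs["hidden_states"])
    captured_inputs.append(inputs)
    raise ValueError  # abort the forward pass: only the inputs are needed for calibration

layer = torch.nn.Linear(16, 16)  # stand-in for layers[0]
handle = layer.register_forward_pre_hook(capture_hook, with_kwargs=True)
try:
    layer(torch.randn(2, 16))
except ValueError:
    pass  # expected: the hook aborts the forward on purpose
handle.remove()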
@@ -464,6 +484,8 @@ def store_input_hook(_, args, kwargs):

# TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
if self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
lm_head_handle = layers[0].register_forward_pre_hook(store_lm_head_input_hook, with_kwargs=True)
is_ovis = self.__class__.__name__ == "OvisGPTQ"
for example in calibration_dataset:
for k, v in example.items():
@@ -484,8 +506,14 @@ def store_input_hook(_, args, kwargs):
except ValueError:
pass
handle.remove()
if self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
lm_head_handle.remove()

if self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
self.model.to(CPU)
else:
move_to(layers[0], CPU)

move_to(layers[0], CPU)
for module_name in self.base_modules:
module = get_module_by_name_prefix(self.model, module_name)
if module is not None:
@@ -523,6 +551,8 @@ def store_input_hook(_, args, kwargs):
if is_lm_head:
layer_pb.set_description(f"Quantizing lm_head")
layer = get_module(self.model, key=self.lm_head)
if self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
layer_inputs = lm_head_inputs
else:
layer_pb.set_description(f"Quantizing layer {i} of {layer_count - 1}")
layer = layers[i]
@@ -574,6 +604,7 @@ def store_input_hook(_, args, kwargs):

bits = self.quantize_config.dynamic_get(layer_name, "bits", bits)
sym = self.quantize_config.dynamic_get(layer_name, "sym", sym)
mse = self.quantize_config.dynamic_get(layer_name, "mse", mse)
gptq[name] = GPTQ(subset[name])
gptq[name].quantizer.configure(
bits,
@@ -651,15 +682,19 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):

group_size = self.quantize_config.group_size
desc_act = self.quantize_config.desc_act
damp_percent = self.quantize_config.damp_percent
static_groups = self.quantize_config.static_groups
if self.quantize_config.dynamic is not None:
group_size = self.quantize_config.dynamic_get(layer_name, "group_size", group_size)
desc_act = self.quantize_config.dynamic_get(layer_name, "desc_act", desc_act)
damp_percent = self.quantize_config.dynamic_get(layer_name, "damp_percent", damp_percent)
static_groups = self.quantize_config.dynamic_get(layer_name, "static_groups", static_groups)

scale, zero, g_idx, duration, avg_loss, damp_percent = gptq[name].quantize(
percdamp=self.quantize_config.damp_percent,
percdamp=damp_percent,
group_size=group_size,
actorder=desc_act,
static_groups=self.quantize_config.static_groups,
static_groups=static_groups,
)
if task is not None:
task.get_logger().report_scalar(
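With the block above, damp_percent and static_groups are resolved through dynamic_get just like bits, sym, mse, group_size and desc_act, so they can be overridden per module via QuantizeConfig.dynamic. A hedged sketch of such a config (the dynamic key-matching rules are gptqmodel's own; a plain module name is used purely for illustration):

from gptqmodel import QuantizeConfig

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    damp_percent=0.01,
    dynamic={
        # Per-module overrides; values mirror the lm_head defaults set in base.py.
        "lm_head": {
            "bits": 8,
            "group_size": 32,
            "sym": True,
            "desc_act": False,
            "mse": 2.4,
            "damp_percent": 0.05,    # illustrative override, now honored per layer
            "static_groups": False,  # likewise resolved via dynamic_get
        },
    },
)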
4 changes: 3 additions & 1 deletion gptqmodel/quantization/config.py
@@ -130,6 +130,7 @@ class QuantizeConfig():
sym: bool = field(default=True)
true_sequential: bool = field(default=True)
lm_head: bool = field(default=False)
lm_head_low_gpu_mem_usage: bool = field(default=False)
quant_method: str = field(default=QUANT_METHOD.GPTQ)
# default to gptq v1 format for maximum compat with 3rd party inference libs with minimal loss vs v2
# if you inference with gptqmodel, save to gptq_v2 format for best result
@@ -202,7 +203,8 @@ def meta_set(self, key: str, value: Any):
def meta_get(self, key: str) -> Any:
return self.meta.get(key)

def dynamic_get(self, layer_name: str, key: str = None, default_value: Union[int, bool] = None) -> Union[Dict, int, bool]:
def dynamic_get(self, layer_name: str, key: str = None, default_value: Union[int, bool, float] = None
) -> Union[Dict, int, bool, float]:
return dynamic_get(self.dynamic, layer_name, key, default_value)

# versionable is a meta.property that pairs value with version i.e "value:1.0.0"
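The widened Union here simply reflects that per-layer overrides such as mse and damp_percent are floats. A simplified stand-in for the lookup this method delegates to (hypothetical, exact-match-only logic; gptqmodel's real dynamic_get helper defines the actual matching rules):

from typing import Dict, Optional, Union

def dynamic_get_sketch(dynamic: Optional[Dict[str, Dict]], layer_name: str,
                       key: Optional[str] = None,
                       default_value: Union[int, bool, float, None] = None,
                       ) -> Union[Dict, int, bool, float, None]:
    # Hypothetical simplification: exact layer-name match only.
    if not dynamic or layer_name not in dynamic:
        return default_value
    overrides = dynamic[layer_name]
    if key is None:
        return overrides                      # whole override dict for the layer
    return overrides.get(key, default_value)  # fall back to the caller's default

overrides = {"lm_head": {"bits": 8, "mse": 2.4}}
assert dynamic_get_sketch(overrides, "lm_head", "mse", 0.0) == 2.4
assert dynamic_get_sketch(overrides, "model.layers.0.self_attn.q_proj", "bits", 4) == 4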
