Skip to content

Commit

Permalink
base model added "modality" field
Browse files Browse the repository at this point in the history
  • Loading branch information
ZX-ModelCloud committed Dec 20, 2024
1 parent 71068ad commit f8ea146
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
33 changes: 21 additions & 12 deletions gptqmodel/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
nested_move_to,
pack_model,
simple_dispatch_model,
MODALITY,
)
from ..utils.progress import ProgressBar
from ..utils.torch import torch_empty_cache
Expand Down Expand Up @@ -87,6 +88,8 @@ class BaseGPTQModel(nn.Module):

supports_desc_act = [True, False]

modality: List[MODALITY] = [MODALITY.TEXT]

def __init__(
self,
model: PreTrainedModel,
Expand Down Expand Up @@ -265,13 +268,30 @@ def quantize(
if BITBLAS_AVAILABLE is False:
raise ValueError(BITBLAS_INSTALL_HINT)


device_map = self.hf_device_map
if device_map:
for name, device in device_map.items():
if device == "cpu" and best_device != CPU:
logger.info(f"truly offloading {name} to cpu with hook.")
module = get_module_by_name_suffix(self.model, name)
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, best_device)

calibration_dataset = self._prepare_dataset_for_quantization(calibration_dataset, batch_size, tokenizer,)

# Calculate the average length of input_ids in the calibration dataset
total_input_ids_length = 0
max_input_id_length = 0
for row in calibration_dataset:
input_ids = row["input_ids"]
if isinstance(input_ids, torch.Tensor):
input_ids_length = input_ids.numel()
if input_ids.dim() <= 2:
input_ids_length = input_ids.shape[-1]
else:
raise ValueError(
"Expected a 1-dimensional tensor or 2-dimensional tensor for 'input_ids', but got a tensor with {0} dimensions.".format(
input_ids.dim()))
else:
input_ids_length = len(input_ids)

Expand All @@ -284,17 +304,6 @@ def quantize(
logger.warning(f"The average length of input_ids of calibration_dataset should be greater than "
f"{min_calibration_dataset_input_ids_avg_length}: actual avg: {avg}.")

device_map = self.hf_device_map
if device_map:
for name, device in device_map.items():
if device == "cpu" and best_device != CPU:
logger.info(f"truly offloading {name} to cpu with hook.")
module = get_module_by_name_suffix(self.model, name)
remove_hook_from_module(module, recurse=True)
accelerate.cpu_offload_with_hook(module, best_device)

calibration_dataset = self._prepare_dataset_for_quantization(calibration_dataset, batch_size, tokenizer,)

if isinstance(self.quantize_config, AutoRoundQuantizeConfig):
from auto_round import AutoRound
from auto_round import __version__ as auto_round_version
Expand Down
6 changes: 6 additions & 0 deletions gptqmodel/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from typing import Dict, List, Optional, Tuple, Type, Union

import accelerate
Expand Down Expand Up @@ -778,3 +779,8 @@ def check_requires_version(requires_version, current_version):
return OPERATOR_MAP[op_symbol](current_version, required_version)
else:
return None

class MODALITY(str, Enum):
    """Closed set of modality identifiers a model class can declare.

    Every member is also a ``str``, so it compares equal to its raw value
    (``MODALITY.TEXT == "text"``) and can be recovered from that value with
    ``MODALITY("text")``.  Use ``.value`` when the plain string is needed in
    formatted output, because ``str()`` on a member returns the qualified
    member name rather than the value.
    """

    # Text modality; this is the value listed by default in the base model's
    # class-level ``modality`` attribute.
    TEXT = "text"
    # NOTE(review): name suggests image input producing text output --
    # confirm against the model definitions that use it.
    IMAGE_TO_TEXT = "image_to_text"
    # NOTE(review): name suggests text input producing image output --
    # confirm against the model definitions that use it.
    TEXT_TO_IMAGE = "text_to_image"

0 comments on commit f8ea146

Please sign in to comment.