From 32b0e7de55df9a2aa597dfe779dcb88c45fc2dc2 Mon Sep 17 00:00:00 2001
From: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Date: Mon, 16 Dec 2024 18:32:21 +0800
Subject: [PATCH] [Fix] all tensors not same device (#5)

* fix device error

* update gptqmodel version

* fix test
---
 optimum/gptq/quantizer.py       | 27 +++++++++++++--------------
 optimum/gptq/utils.py           | 14 ++++++++++++++
 optimum/utils/import_utils.py   |  2 +-
 tests/gptq/test_quantization.py |  5 ++++-
 4 files changed, 32 insertions(+), 16 deletions(-)

diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 61ba67b3030..976f6418b3b 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -53,7 +53,7 @@
     from gptqmodel import exllama_set_max_input_length
     from gptqmodel.quantization import GPTQ
     from gptqmodel.utils.importer import hf_select_quant_linear
-    from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
+    from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format, nested_move_to
     from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
     from gptqmodel.version import __version__ as gptqmodel_version
 
@@ -511,9 +511,11 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
 
         blocks = recurse_getattr(model, self.block_name_to_quantize)
 
+        cur_layer_device = get_device(blocks[0])
+
         if not has_device_map:
             # put modules from module_name_preceding_first_block on cuda or xpu or cpu
-            to_device = 0 if has_device_more_than_cpu() else "cpu"
+            to_device = cur_layer_device
             for module_name in self.module_name_preceding_first_block:
                 module = recurse_getattr(model, module_name)
                 if module is None:
@@ -525,14 +527,14 @@ def store_input_hook(_, input, *args):
             kwargs = args[0]
             if input is None:
                 if "hidden_states" in kwargs:
-                    input = (kwargs["hidden_states"],)
+                    input = (nested_move_to(kwargs["hidden_states"], cur_layer_device),)
                 else:
                     raise ValueError("No input value found in the foward pass")
             layer_inputs.append(input)
             other_kwargs = {}
             for k, v in kwargs.items():  # make sure other arguments also be captured
                 if k not in ["hidden_states"]:
-                    other_kwargs[k] = v
+                    other_kwargs[k] = nested_move_to(v, cur_layer_device)
             layer_input_kwargs.append(other_kwargs)
             raise ValueError
 
@@ -540,11 +542,7 @@ def store_input_hook(_, input, *args):
         handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
         for data in dataset:
             for k, v in data.items():
-                # put the data on gpu, we won't put them back to cpu
-                if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
-                    data[k] = v.to(0)
-                else:
-                    data[k] = v.to(device)
+                data[k] = nested_move_to(v, cur_layer_device)
             try:
                 model(**data)
             except ValueError:
@@ -571,11 +569,7 @@ def store_input_hook(_, input, *args):
                 handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
                 for data in dataset:
                     for k, v in data.items():
-                        # put the data on gpu, we won't put them back to cpu
-                        if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
-                            data[k] = v.to(0)
-                        else:
-                            data[k] = v.to(device)
+                        data[k] = nested_move_to(v, cur_layer_device)
                     try:
                         model(**data)
                     except ValueError:
@@ -587,6 +581,7 @@ def store_input_hook(_, input, *args):
             if (not has_device_map or get_device(block) == torch.device("cpu")) and has_device_more_than_cpu():
                 block = block.to(0)
             layers = get_layers(block)
+            block_device = get_device(block)
             if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
                 if self.true_sequential:
                     layers_name_list = self.modules_in_block_to_quantize
@@ -620,6 +615,10 @@ def tmp(_, input, output):
                 for j in range(len(dataset)):
                     # the args are already on the gpu
                     # don't need to store the output
+                    layer_inputs[j] = nested_move_to(layer_inputs[j], block_device)
+                    for k, v in layer_input_kwargs[j].items():
+                        layer_input_kwargs[j][k] = nested_move_to(v, block_device)
+
                     block(*layer_inputs[j], **layer_input_kwargs[j])
                 # remove hook
                 for h in handles:
diff --git a/optimum/gptq/utils.py b/optimum/gptq/utils.py
index a5f9afdaaef..a4582a3c0a7 100644
--- a/optimum/gptq/utils.py
+++ b/optimum/gptq/utils.py
@@ -113,3 +113,17 @@ def get_seqlen(model: nn.Module):
         "We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`"
     )
     return 2048
+
+def move_to(obj: torch.Tensor | nn.Module, device: torch.device):
+    if get_device(obj) != device:
+        obj = obj.to(device)
+    return obj
+
+
+def nested_move_to(v, device):
+    if isinstance(v, torch.Tensor):
+        return move_to(v, device)
+    elif isinstance(v, (list, tuple)):
+        return type(v)([nested_move_to(e, device) for e in v])
+    else:
+        return v
diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py
index 8f7635ce043..d0f4c85db2b 100644
--- a/optimum/utils/import_utils.py
+++ b/optimum/utils/import_utils.py
@@ -52,7 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
 DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
 AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99")  # Allows 0.5.0.dev0
-GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.1")  # Allows 1.4.0.dev0
+GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2")
 
 
 # This is the minimal required version to support some ONNX Runtime features
diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py
index dbbedcec983..b6b50fb617d 100644
--- a/tests/gptq/test_quantization.py
+++ b/tests/gptq/test_quantization.py
@@ -316,7 +316,10 @@ def test_exllama_serialization(self):
             # quantized models are more compatible with device map than
             # device context managers (they're never used in transformers testing suite)
             _ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
-            _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
+            if is_gptqmodel_available():
+                _ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
+            else:
+                _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
 
 
 class GPTQTestNoBlockCaching(GPTQTestCUDA):
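
Note (not part of the patch): a minimal sketch of how the new `nested_move_to` helper is intended to be used, mirroring the patched replay loop in `quantize_model`. The toy `nn.Linear` block and the random calibration tensors below are hypothetical stand-ins for a real transformer block and its cached inputs; the sketch only assumes that `get_device` and `nested_move_to` are importable from `optimum.gptq.utils` once this patch is applied.

# Illustrative only; the toy block and tensor shapes are assumptions, not part of the patch.
import torch
from torch import nn

from optimum.gptq.utils import get_device, nested_move_to

# Stand-in for a transformer block; in quantize_model this would be blocks[i].
block = nn.Linear(8, 8)
if torch.cuda.is_available():
    block = block.to("cuda:0")
block_device = get_device(block)

# Cached calibration inputs may live on the CPU (or on another device). nested_move_to
# walks tuples/lists, moves every tensor it finds to the target device, and leaves
# non-tensor values (e.g. floats) untouched.
layer_inputs = [(torch.randn(2, 8),)]
layer_input_kwargs = [{"scale": 1.0, "mask": torch.ones(2, 8)}]

for j in range(len(layer_inputs)):
    layer_inputs[j] = nested_move_to(layer_inputs[j], block_device)
    for k, v in layer_input_kwargs[j].items():
        layer_input_kwargs[j][k] = nested_move_to(v, block_device)
    # The real loop calls block(*layer_inputs[j], **layer_input_kwargs[j]); the toy
    # Linear block only accepts the positional input.
    block(*layer_inputs[j])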