Revert "[FEATURE] Add GPTQModel.shard_quantized() api (ModelCloud#271)"
This reverts commit 88392c7.
ZX-ModelCloud committed Jul 23, 2024
1 parent 119fc4d commit 36f602a
Showing 4 changed files with 4 additions and 107 deletions.
20 changes: 0 additions & 20 deletions gptqmodel/models/auto.py
@@ -139,23 +139,3 @@ def from_quantized(
**kwargs,
)

@classmethod
def shard_quantized(cls,
quantized_model_path_or_id: str,
max_shard_size: str,
save_dir: str,
safetensors_metadata: Optional[Dict[str, str]] = None,
use_safetensors: bool = True,
model_base_name: Optional[str] = None):
model_type = check_and_get_model_type(quantized_model_path_or_id)
shard_quantized_func = MODEL_MAP[model_type].shard_quantized

return shard_quantized_func(
quantized_model_path_or_id=quantized_model_path_or_id,
save_dir=save_dir,
max_shard_size=max_shard_size,
safetensors_metadata=safetensors_metadata,
use_safetensors=use_safetensors,
model_base_name=model_base_name,
)
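For reference, the classmethod removed above was the public entry point for re-sharding an already-quantized checkpoint in one call. Below is a minimal usage sketch of the reverted API, based only on the signature visible in this diff; the model id, shard size, and output directory are hypothetical placeholders.

from gptqmodel import GPTQModel  # import path assumed from the package name

# Re-shard a quantized checkpoint into files of at most 2GB (reverted API).
GPTQModel.shard_quantized(
    quantized_model_path_or_id="ModelCloud/some-quantized-model",  # placeholder id
    max_shard_size="2GB",                                          # placeholder size
    save_dir="./resharded-model",                                  # placeholder path
    use_safetensors=True,
)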

41 changes: 3 additions & 38 deletions gptqmodel/models/base.py
@@ -79,14 +79,12 @@ def __init__(
quantized: bool,
quantize_config: QuantizeConfig,
qlinear_kernel: nn.Module = None,
from_quantized: bool = False,
):
super().__init__()

self.model = model
self.model_type = self.model.config.model_type
self._quantized = quantized
self.from_quantized = from_quantized
self.quantize_config = quantize_config
self.config = self.model.config

@@ -578,32 +576,6 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
"""shortcut for model.prepare_inputs_for_generation"""
return self.model.prepare_inputs_for_generation(*args, **kwargs)

@classmethod
def shard_quantized(cls,
quantized_model_path_or_id: str,
max_shard_size: str,
save_dir: str,
safetensors_metadata: Optional[Dict[str, str]] = None,
use_safetensors: bool = True,
model_base_name: Optional[str] = None
):
# gptqmodel_post_init will check if the device matches.
# Here, the CPU is always used, so you need to skip it.
quantized_model = cls.from_quantized(quantized_model_path_or_id,
device="cpu",
backend=BACKEND.TRITON,
use_safetensors=use_safetensors,
safetensors_metadata=safetensors_metadata,
model_basename=model_base_name,
skip_gptqmodel_post_init=True,)
# Skip from_quantized check
quantized_model.from_quantized = False
quantized_model.save_quantized(save_dir,
safetensors_metadata=safetensors_metadata,
use_safetensors=use_safetensors,
max_shard_size=max_shard_size,
model_base_name=model_base_name)

def save_quantized(
self,
save_dir: str,
@@ -613,9 +585,6 @@ def save_quantized(
model_base_name: Optional[str] = None
):
"""save quantized model and configs to local disk"""
if self.from_quantized:
raise NotImplementedError("Saving a loaded quantized model is not supported. If you need to re-shard the model, please use `GPTQModel.shard_quantized()` api.")

os.makedirs(save_dir, exist_ok=True)

# write gptqmodel tooling fingerprint to config
@@ -869,7 +838,7 @@ def skip(*args, **kwargs):
@classmethod
def from_quantized(
cls,
model_name_or_path: str,
model_name_or_path: Optional[str],
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
@@ -886,7 +855,6 @@ def from_quantized(
):
if backend == BACKEND.VLLM or backend == BACKEND.SGLANG:
import os

# to optimize vllm inference, set an environment variable 'VLLM_ATTENTION_BACKEND' to 'FLASHINFER'.
os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER'

@@ -1265,10 +1233,8 @@ def skip(*args, **kwargs):
logger.warning("can't get model's sequence length from model config, will set to 4096.")
model.seqlen = 4096

skip_gptqmodel_post_init = kwargs.pop("skip_gptqmodel_post_init", None)
if skip_gptqmodel_post_init is None:
# Any post-initialization that require device information, for example buffers initialization on device.
model = gptqmodel_post_init(model, use_act_order=quantize_config.desc_act, quantize_config=quantize_config)
# Any post-initialization that require device information, for example buffers initialization on device.
model = gptqmodel_post_init(model, use_act_order=quantize_config.desc_act, quantize_config=quantize_config)

model.eval()

@@ -1277,7 +1243,6 @@ def skip(*args, **kwargs):
quantized=True,
quantize_config=quantize_config,
qlinear_kernel=qlinear_kernel,
from_quantized=True,
)

def __getattr__(self, item):
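With the from_quantized guard in save_quantized also reverted (the NotImplementedError removed above), re-sharding a quantized checkpoint falls back to the two-step flow the deleted helper wrapped: load the model, then save it with a shard-size limit. A rough sketch, assuming the from_quantized and save_quantized signatures shown in this diff; the model id, output directory, and shard size are hypothetical.

from gptqmodel import GPTQModel  # import path assumed

# Load the quantized checkpoint on CPU (as the removed helper did), then
# write it back out with a maximum shard size.
model = GPTQModel.from_quantized(
    "ModelCloud/some-quantized-model",  # placeholder id
    device="cpu",
)
# The removed helper also forced backend=BACKEND.TRITON; pass a backend here
# as well if plain CPU loading is not supported in your setup.
model.save_quantized(
    "./resharded-model",   # placeholder output directory
    max_shard_size="2GB",  # placeholder shard size
    use_safetensors=True,
)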
48 changes: 0 additions & 48 deletions tests/test_shard_quantized.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/test_vllm.py
@@ -69,4 +69,4 @@ def test_load_shared_vllm(self):
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
self.assertEquals(generated_text,
" Paris.\n2. Who has a national flag with a white field surrounded by")
" Paris.\n2. Who has a national flag with a white field surrounded by")
