add shard_quantized()
ZX-ModelCloud committed Jul 23, 2024
1 parent 5f5eae6 commit 21c5712
Showing 2 changed files with 70 additions and 1 deletion.
36 changes: 36 additions & 0 deletions gptqmodel/models/auto.py
@@ -139,3 +139,39 @@ def from_quantized(
             **kwargs,
         )
 
+    @classmethod
+    def shard_quantized(cls,
+                        model_name_or_path: str,
+                        save_dir: str,
+                        max_shard_size: str,
+                        device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
+                        max_memory: Optional[dict] = None,
+                        device: Optional[Union[str, int]] = None,
+                        quantize_config: Optional[QuantizeConfig] = None,
+                        model_basename: Optional[str] = None,
+                        use_safetensors: bool = True,
+                        trust_remote_code: bool = False,
+                        allow_unsafe_loading: bool = False,
+                        verify_hash: Optional[Union[str, List[str]]] = None,
+                        safetensors_metadata: Optional[Dict[str, str]] = None,
+                        **kwargs,
+                        ):
+        model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
+        shard_quantized_func = MODEL_MAP[model_type].shard_quantized
+
+        return shard_quantized_func(
+            model_name_or_path=model_name_or_path,
+            save_dir=save_dir,
+            max_shard_size=max_shard_size,
+            device_map=device_map,
+            max_memory=max_memory,
+            device=device,
+            quantize_config=quantize_config,
+            model_basename=model_basename,
+            use_safetensors=use_safetensors,
+            trust_remote_code=trust_remote_code,
+            allow_unsafe_loading=allow_unsafe_loading,
+            verify_hash=verify_hash,
+            safetensors_metadata=safetensors_metadata,
+            **kwargs,
+        )
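The new GPTQModel.shard_quantized() entry point resolves the model type and delegates to the matching model class, so resharding an already-quantized checkpoint becomes a single call. A minimal usage sketch, assuming the package-level GPTQModel import used in the project's examples; the model id, output directory, and shard size below are illustrative, not taken from this commit:

from gptqmodel import GPTQModel

# Re-save an already-quantized checkpoint as smaller safetensors shards.
# max_shard_size uses a transformers-style size string such as "2GB".
GPTQModel.shard_quantized(
    "ModelCloud/some-4bit-gptq-model",  # hypothetical quantized model id
    save_dir="sharded-output",
    max_shard_size="2GB",
)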

35 changes: 34 additions & 1 deletion gptqmodel/models/base.py
@@ -79,12 +79,14 @@ def __init__(
         quantized: bool,
         quantize_config: QuantizeConfig,
         qlinear_kernel: nn.Module = None,
+        from_quantized: bool = False,
     ):
         super().__init__()
 
         self.model = model
         self.model_type = self.model.config.model_type
         self._quantized = quantized
+        self.from_quantized = from_quantized
         self.quantize_config = quantize_config
         self.config = self.model.config
 
@@ -576,6 +578,33 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
         """shortcut for model.prepare_inputs_for_generation"""
         return self.model.prepare_inputs_for_generation(*args, **kwargs)
 
+    @classmethod
+    def shard_quantized(cls,
+                        model_name_or_path: str,
+                        save_dir: str,
+                        max_shard_size: str,
+                        device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
+                        max_memory: Optional[dict] = None,
+                        device: Optional[Union[str, int]] = None,
+                        torch_dtype: Union[str, torch.dtype] = "auto",
+                        quantize_config: Optional[QuantizeConfig] = None,
+                        model_basename: Optional[str] = None,
+                        use_safetensors: bool = True,
+                        trust_remote_code: bool = False,
+                        format: Optional[FORMAT] = None,
+                        allow_unsafe_loading: bool = False,
+                        verify_hash: Optional[Union[str, List[str]]] = None,
+                        safetensors_metadata: Optional[Dict[str, str]] = None,
+                        **kwargs,
+                        ):
+        quantized_model = cls.from_quantized(model_name_or_path, device_map=device_map, max_memory=max_memory, device=device, backend=BACKEND.TRITON, torch_dtype=torch_dtype,
+                                             quantize_config=quantize_config, model_basename=model_basename, use_safetensors=use_safetensors, trust_remote_code=trust_remote_code,
+                                             format=format, allow_unsafe_loading=allow_unsafe_loading, verify_hash=verify_hash, safetensors_metadata=safetensors_metadata, **kwargs)
+        # Clear the flag so the save_quantized() guard below does not reject the re-save
+        quantized_model.from_quantized = False
+        quantized_model.save_quantized(save_dir, safetensors_metadata=safetensors_metadata, use_safetensors=use_safetensors,
+                                       max_shard_size=max_shard_size, model_base_name=model_basename)
+
     def save_quantized(
         self,
         save_dir: str,
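The base implementation is deliberately thin: it reloads the checkpoint via from_quantized() pinned to the TRITON backend, clears the from_quantized flag so the guard added in the next hunk does not fire, and re-saves with the requested max_shard_size. A rough hand-written equivalent, assuming BACKEND and GPTQModel are exported at the package top level; model id and paths are hypothetical:

from gptqmodel import BACKEND, GPTQModel

model = GPTQModel.from_quantized("ModelCloud/some-4bit-gptq-model",  # hypothetical id
                                 backend=BACKEND.TRITON)
model.from_quantized = False  # disarm the re-save guard added in the next hunk
model.save_quantized("sharded-output", max_shard_size="2GB")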
@@ -585,6 +614,10 @@ def save_quantized(
         model_base_name: Optional[str] = None
     ):
         """save quantized model and configs to local disk"""
+        if self.from_quantized:
+            raise NotImplementedError("Saving a quantized model again is not supported.\n"
+                                      "If you need to shard the model file, refer to shard_quantized().")
+
         os.makedirs(save_dir, exist_ok=True)
 
         # write gptqmodel tooling fingerprint to config
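Assuming from_quantized() constructs the model with from_quantized=True (that wiring is not shown in these hunks), the guard turns a plain re-save of a loaded checkpoint into an immediate error:

model = GPTQModel.from_quantized("ModelCloud/some-4bit-gptq-model")  # hypothetical id
model.save_quantized("out")
# NotImplementedError: Saving a quantized model again is not supported. ...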
@@ -838,7 +871,7 @@ def skip(*args, **kwargs):
     @classmethod
     def from_quantized(
         cls,
-        model_name_or_path: Optional[str],
+        model_name_or_path: str,
         device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
         max_memory: Optional[dict] = None,
         device: Optional[Union[str, int]] = None,
