Commit cd5ac26
If quant_inference=True and device != DEVICE.CPU, use torch.float16 (M…
ZX-ModelCloud authored Dec 20, 2024
1 parent 5b7163e commit cd5ac26
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion gptqmodel/models/loader.py
@@ -224,7 +224,7 @@ def from_quantized(

     if torch_dtype is None or torch_dtype == "auto":
         # TODO FIX ME for `dynamic`, non-quantized modules should be in native type
-        torch_dtype = auto_dtype_from_config(config=config, device=device, device_map=device_map)
+        torch_dtype = auto_dtype_from_config(config=config, device=device, device_map=device_map, quant_inference=True)
     elif not isinstance(torch_dtype, torch.dtype):
         raise ValueError(f"torch_dtype value of `{torch_dtype}` is not a torch.dtype instance.")

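Caller-side effect, roughly: with torch_dtype left as None or "auto", from_quantized now asks auto_dtype_from_config for a quant-inference dtype, so quantized inference on a non-CPU device defaults to torch.float16. The sketch below is illustrative only; GPTQModel.load and its keyword names are assumed from the package's public API, not from this diff, and the model id is just an example.

    # Hedged usage sketch (not part of this commit): loading with "auto" dtype
    # on a CUDA device should now resolve to torch.float16.
    from gptqmodel import GPTQModel

    model = GPTQModel.load(
        "TheBloke/Llama-2-7B-GPTQ",   # example quantized checkpoint
        device="cuda:0",
        torch_dtype="auto",           # triggers auto_dtype_from_config(..., quant_inference=True)
    )
    print(model.model.dtype)          # expected: torch.float16 on GPU after this commit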
8 changes: 7 additions & 1 deletion gptqmodel/utils/model.py
@@ -684,7 +684,13 @@ def get_checkpoints(model_id_or_path: str, extensions: List[str], possible_model


 # return the most stable tensor dtype for quantization while minimizing vram
-def auto_dtype_from_config(config: PretrainedConfig, device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, device: Optional[Union[str, int]] = None ) -> torch.dtype:
+def auto_dtype_from_config(config: PretrainedConfig,
+                           device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
+                           device: Optional[Union[str, int]] = None,
+                           quant_inference: bool = False) -> torch.dtype:
+    if quant_inference and device != DEVICE.CPU:
+        return torch.float16
+
     # TODO mps has errors with bfloat16, lock to float16 for now
     if sys.platform == "darwin" or "mps" in [device, device_map] or (
             isinstance(device_map, Dict) and "mps" in device_map.values()):

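Putting the two hunks together, the dtype-selection order after this commit looks roughly like the self-contained sketch below. It is a simplified approximation, not the library function: the real code also reads the model config's native dtype, and DEVICE.CPU is the library's device enum, stood in for here by the plain string "cpu".

    import sys
    from typing import Dict, Optional, Union

    import torch

    def pick_dtype_sketch(device: Optional[Union[str, int]] = None,
                          device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
                          quant_inference: bool = False) -> torch.dtype:
        # 1) New in this commit: quantized inference off-CPU short-circuits to
        #    float16, the dtype the quant kernels support most broadly.
        if quant_inference and device != "cpu":
            return torch.float16

        # 2) Existing rule: macOS / MPS still has bfloat16 issues, lock to float16.
        if sys.platform == "darwin" or "mps" in [device, device_map] or (
                isinstance(device_map, dict) and "mps" in device_map.values()):
            return torch.float16

        # 3) Otherwise fall back to the config's native dtype; bfloat16 stands in
        #    for that branch here, since most modern checkpoints ship bf16 weights.
        return torch.bfloat16

For example, pick_dtype_sketch(device="cuda:0", quant_inference=True) returns torch.float16, while a CPU load falls through to the config-driven path.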