
Commit

Check QuantLinear Device
ZX-ModelCloud committed Jul 5, 2024
1 parent b39fa13 commit 9735bb6
Showing 10 changed files with 54 additions and 36 deletions.
5 changes: 5 additions & 0 deletions gptqmodel/models/_const.py
@@ -1,8 +1,13 @@
from torch import device

CPU = device("cpu")
CUDA = device("cuda")
CUDA_0 = device("cuda:0")


DEVICE_TYPE_CPU = "cuda"
DEVICE_TYPE_CUDA = "cuda"

SUPPORTED_MODELS = [
"bloom",
"gptj",
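Note (illustration, not part of the diff): the new DEVICE_TYPE_* strings are meant to match torch.device type names, so a layer can be checked against the device its weights actually live on. A minimal sketch, assuming the constants above with DEVICE_TYPE_CPU = "cpu":

import torch

from gptqmodel.models._const import DEVICE_TYPE_CPU, DEVICE_TYPE_CUDA

# torch.device(...).type returns the bare device-type string these constants mirror,
# with any index (e.g. ":0") stripped off.
assert torch.device("cpu").type == DEVICE_TYPE_CPU      # "cpu"
assert torch.device("cuda:0").type == DEVICE_TYPE_CUDA  # "cuda"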
6 changes: 3 additions & 3 deletions gptqmodel/models/base.py
@@ -33,7 +33,7 @@
gptqmodel_post_init, make_quant, move_to, nested_move_to, pack_model, simple_dispatch_model,
verify_model_hash, verify_sharded_model_hashes)
from ..version import __version__
from ._const import CPU, CUDA_0, SUPPORTED_MODELS
from ._const import CPU, DEVICE_TYPE_CUDA, CUDA_0, SUPPORTED_MODELS

logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
@@ -754,7 +754,7 @@ def from_quantized(
check_cuda()

if backend == Backend.QBITS:
device = torch.device("cpu")
device = CPU
try:
pass
except Exception as e:
@@ -961,7 +961,7 @@ def skip(*args, **kwargs):
if device is not None:
device = torch.device(device)
if not max_memory and not device_map:
device_map = {"": device.index if device.type == "cuda" else device.type}
device_map = {"": device.index if device.type == DEVICE_TYPE_CUDA else device.type}
if not isinstance(device_map, dict) and device_map != "sequential":
max_memory = accelerate.utils.get_balanced_memory(
model=model,
27 changes: 21 additions & 6 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -1,17 +1,31 @@
import torch.nn as nn

from ...models._const import DEVICE_TYPE_CUDA

class BaseQuantLinear(nn.Module):
SUPPORTED_BITS = []
SUPPORTED_GROUP_SIZE = []
SUPPORTED_DESC_ACT = [True, False]
SUPPORTED_SYM = [True, False]
SUPPORTED_SHARDS: bool = True
SUPPORTED_DEVICES = [DEVICE_TYPE_CUDA]

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, *args, **kwargs):
super().__init__()
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym)
if err:
raise NotImplementedError(err)

@classmethod
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_error: bool = True) -> bool:
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool) -> bool:
validate, _ = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym)
return validate

@classmethod
def _validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, ):
validate = True
print("cccc",cls, cls.SUPPORTED_BITS)
err = ""
if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS:
validate = False
@@ -25,11 +25,12 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_e
elif cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
validate = False
err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
return validate, err

if not validate and raise_error:
raise NotImplementedError(err)

return validate
@classmethod
def validate_device(cls, device_type: str):
if cls.SUPPORTED_DEVICES and device_type not in cls.SUPPORTED_DEVICES:
raise NotImplementedError(f"{cls} only supports `{cls.SUPPORTED_DEVICES}` bits: actual device = `{device_type}`")

# override me
def post_init(self):
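Note (illustration, not part of the diff): a minimal sketch of how a backend is expected to use the new SUPPORTED_DEVICES list and validate_device() hook. Only BaseQuantLinear, DEVICE_TYPE_CUDA and validate_device() come from the code above; ExampleQuantLinear and its qweight buffer are hypothetical.

import torch

from gptqmodel.models._const import DEVICE_TYPE_CUDA
from gptqmodel.nn_modules.qlinear import BaseQuantLinear


class ExampleQuantLinear(BaseQuantLinear):
    SUPPORTED_BITS = [4]
    SUPPORTED_DEVICES = [DEVICE_TYPE_CUDA]  # declare where this kernel can run

    def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool,
                 infeatures: int, outfeatures: int, **kwargs):
        # BaseQuantLinear.__init__ now runs _validate() and raises NotImplementedError
        # when bits / group_size / desc_act / sym fall outside the SUPPORTED_* lists.
        super().__init__(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, **kwargs)
        self.register_buffer(
            "qweight",
            torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32),
        )

    def post_init(self):
        # Fail fast if the packed weights ended up on a device this backend cannot
        # serve, e.g. CPU tensors handed to a CUDA-only kernel.
        self.validate_device(self.qweight.device.type)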
7 changes: 3 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
@@ -92,8 +92,8 @@ def __init__(
self,
bits: int,
group_size: int,
sym: bool,
desc_act: bool,
sym: bool,
infeatures: int,
outfeatures: int,
bias: bool,
@@ -104,13 +104,11 @@
layout: str = "nt",
**kwargs,
):
super().__init__()
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

# TODO: remove delayed import after bitblas whl support for 11.7, 11.8, 12.0 are added
import_bitblas()

self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)

self._validate_parameters(group_size, infeatures, outfeatures)

self.bits = bits
@@ -243,6 +241,7 @@ def reset_parameters(self):
self.q_params = None

def post_init(self):
self.validate_device(self.qweight.device.type)
# eliminate runtime overhead like exllama state
param_list = [self.qweight, self.scales, self.zeros]
if self.bitblas_matmul.config.with_bias:
8 changes: 4 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -10,6 +10,7 @@
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel_exllama_kernels import make_q4, q4_matmul
from gptqmodel.models._const import CUDA

logger = getLogger(__name__)

@@ -40,9 +41,8 @@ class ExllamaQuantLinear(BaseQuantLinear):

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
super().__init__()
self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, device: torch.device, infeatures: int, outfeatures: int, bias: bool, **kwargs,):
super().__init__(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, **kwargs)

self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
@@ -92,7 +92,7 @@ def __init__(self, bits: int, group_size: int , sym:bool, desc_act: bool, infeat
self.bias = None

def post_init(self):
assert self.qweight.device.type == "cuda"
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
7 changes: 3 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -100,10 +100,9 @@ class ExllamaV2QuantLinear(BaseQuantLinear):

"""Linear layer implementation with per-group 4-bit quantization of the weights"""

def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeatures: int, outfeatures: int,
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs,):
super().__init__()
self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

self.q_handle = None
self.q_tensors = None
@@ -156,7 +155,7 @@ def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeat
self.bias = None

def post_init(self, temp_dq):
assert self.qweight.device.type == "cuda"
self.validate_device(self.qweight.device.type)
assert self.qweight.device.index is not None

# resize due to padding after model weights have been loaded
8 changes: 5 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -67,10 +67,9 @@ class MarlinQuantLinear(BaseQuantLinear):
SUPPORTED_DESC_ACT = [False]
SUPPORTED_SYM = [True]

def __init__(self, bits: int, group_size: int, sym: bool, desc_act: bool, infeatures: int, outfeatures: int,
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int,
bias: bool, **kwargs):
super().__init__()
self.validate(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

if not torch.cuda.get_device_capability()[0] >= 8:
raise ValueError(
@@ -170,6 +169,9 @@ def forward(self, A):
C = C + self.bias if self.bias is not None else C
return C

def post_init(self):
self.validate_device(self.B.device.type)


# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
11 changes: 5 additions & 6 deletions gptqmodel/nn_modules/qlinear/qlinear_qbits.py
@@ -6,6 +6,7 @@
import torch.nn as nn
import transformers
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
from gptqmodel.models._const import DEVICE_TYPE_CPU

logger = getLogger(__name__)

@@ -38,11 +39,13 @@ def convert_dtype_torch2str(dtype):

class QBitsQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4, 8]
SUPPORTED_DEVICES = [DEVICE_TYPE_CPU]

def __init__(
self,
bits: int,
group_size: int,
desc_act: bool,
sym: bool,
infeatures: int,
outfeatures: int,
@@ -52,12 +55,8 @@ def __init__(
weight_dtype=torch.bfloat16,
**kwargs,
):
super().__init__()

self.sym = False


self.validate(bits=bits, group_size=group_size, sym=self.sym, desc_act=False)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)

self.infeatures = infeatures
self.outfeatures = outfeatures
@@ -102,9 +101,9 @@ def __init__(
self.trainable = trainable

def post_init(self):
self.validate_device(self.qweight.device.type)
from intel_extension_for_transformers import qbits

assert self.qweight.device.type == "cpu"
if self.bias is not None:
self.bias = self.bias.to(dtype=torch.float32)

9 changes: 4 additions & 5 deletions gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -14,6 +14,7 @@


class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
SUPPORTED_BITS = [2, 4, 8]
"""
Triton v2 quantized linear layer.
@@ -22,10 +23,8 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
dequant and matmul into single kernel.add()
"""

def __init__(self, bits, group_size, infeatures, outfeatures, bias, **kwargs,):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures, outfeatures, bias, **kwargs,):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, **kwargs)
if infeatures % 32 != 0 or outfeatures % 32 != 0:
raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
self.infeatures = infeatures
@@ -65,7 +64,7 @@ def __init__(self, bits, group_size, infeatures, outfeatures, bias, **kwargs,):
self.bias = None

def post_init(self):
pass
self.validate_device(self.qweight.device.type)

def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
2 changes: 1 addition & 1 deletion gptqmodel/utils/importer.py
@@ -43,7 +43,7 @@ def select_quant_linear(
allow_backends = format_dict[format]
for k, v in backend_dict.items():
in_allow_backends = k in allow_backends
validate = v.validate(bits, group_size, desc_act, sym, raise_error=False)
validate = v.validate(bits, group_size, desc_act, sym)
check_pack_func = hasattr(v, "pack") if pack else True
if in_allow_backends and validate and check_pack_func:
logger.info(f"Auto choose the fastest one based on quant model compatibility: {v}")
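Note (illustration, not part of the diff): with raise_error removed, validate() is a pure compatibility probe that returns a bool, while hard failures now happen in BaseQuantLinear.__init__ (quant parameters) and in post_init() via validate_device() (device). A usage sketch with illustrative values, assuming the exllama v2 kernels are installed:

from gptqmodel.nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear

# Returns True/False instead of raising, which is what select_quant_linear relies on.
ok = ExllamaV2QuantLinear.validate(bits=4, group_size=128, desc_act=False, sym=True)

# Device compatibility is enforced separately, once weights exist, e.g.
# layer.validate_device(layer.qweight.device.type) inside post_init().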
