feat(qbits): prioritize Marlin int4 kernel over AWQ
dacorvo committed Oct 6, 2024
1 parent 5aee658 commit 476a9dd
Showing 1 changed file with 36 additions and 22 deletions.
optimum/quanto/tensor/weights/qbits.py (58 changes: 36 additions & 22 deletions)
@@ -92,32 +92,46 @@ def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad
             a `WeightQBitsTensor` (can be a subclass).
         """
         from .awq import AWQWeightQBitsTensor
+        from .marlin import MarlinInt4WeightQBitsTensor
         from .tinygemm import TinyGemmWeightQBitsTensor
 
-        if (
-            qtype == qint4
-            and size[0] >= 128  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
-            and scale.dtype == torch.float16
-            and axis == 0
-            and group_size == 128
-            and len(size) == 2
-            and (data.device.type == "cuda" and torch.version.cuda)
-            and torch.cuda.get_device_capability(data.device)[0] >= 8
-        ):
-            if type(data) is PackedTensor:
-                data = data.unpack()
-            return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
-        if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
-            if data.device.type == "cpu" or (
-                (data.device.type == "cuda" and torch.version.cuda)
-                and version.parse(torch.version.cuda).release >= (12, 1)
+        if qtype == qint4 and axis == 0 and group_size == 128 and len(size) == 2:
+            if (
+                scale.dtype == torch.float16
+                and data.device.type == "cuda"
                 and torch.cuda.get_device_capability(data.device)[0] >= 8
             ):
-                if type(data) is PackedTensor:
-                    data = data.unpack()
-                return TinyGemmWeightQBitsTensor(
-                    qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad
-                )
+                out_features, in_features = size
+                # Marlin kernel uses two configurations for defining thread blocks
+                # along n (out_features) and k (in_features):
+                # - thread_n = 128, thread_k = 128
+                # - thread_n = 256, thread_k = 64
+                # This means that:
+                # - in_features must be divisible by 128 (and thus by 64 also)
+                # - out_features must be divisible by 256 (and thus by 128 also)
+                if in_features % 128 == 0 and out_features % 256 == 0:
+                    if type(data) is PackedTensor:
+                        data = data.unpack()
+                    return MarlinInt4WeightQBitsTensor(
+                        qtype, axis, group_size, size, stride, data, scale, shift, requires_grad
+                    )
+                if size[0] >= 128:  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
+                    if type(data) is PackedTensor:
+                        data = data.unpack()
+                    return AWQWeightQBitsTensor(
+                        qtype, axis, group_size, size, stride, data, scale, shift, requires_grad
+                    )
+            if scale.dtype == torch.bfloat16:
+                if data.device.type == "cpu" or (
+                    data.device.type == "cuda"
+                    and version.parse(torch.version.cuda).release >= (12, 1)
+                    and torch.cuda.get_device_capability(data.device)[0] >= 8
+                ):
+                    if type(data) is PackedTensor:
+                        data = data.unpack()
+                    return TinyGemmWeightQBitsTensor(
+                        qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad
+                    )
 
         return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
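
For readers skimming the diff, the new selection order for qint4, axis-0, group_size-128 weights is: Marlin when the shape satisfies its thread-block constraints, then AWQ, then TinyGemm for bfloat16 scales, and finally the generic WeightQBitsTensor. The sketch below only mirrors the shape arithmetic behind the Marlin/AWQ choice for float16 scales on a CUDA device with compute capability >= 8; the helper name select_int4_kernel and the example shapes are illustrative, not part of the commit.

# Illustrative sketch (not from the commit): reproduces only the shape checks
# that pick between the Marlin and AWQ int4 kernels in WeightQBitsTensor.create.
def select_int4_kernel(out_features: int, in_features: int) -> str:
    # Marlin thread-block configurations (thread_n/thread_k = 128/128 or 256/64)
    # require in_features % 128 == 0 and out_features % 256 == 0.
    if in_features % 128 == 0 and out_features % 256 == 0:
        return "marlin"
    # AWQ GEMM fallback, restricted to out_features >= 128 (GEMM crash workaround).
    if out_features >= 128:
        return "awq"
    # Otherwise the generic WeightQBitsTensor packing is used.
    return "generic"

# A 4096 x 4096 projection satisfies both divisibility constraints:
assert select_int4_kernel(4096, 4096) == "marlin"
# out_features = 1000 is not divisible by 256, so AWQ is used instead:
assert select_int4_kernel(1000, 4096) == "awq"
# Very small out_features fall through to the generic representation:
assert select_int4_kernel(64, 4096) == "generic"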

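The bfloat16/TinyGemm branch is unchanged in substance but now sits behind the same qint4 shape check. Below is a minimal sketch of that device gate, assuming torch and packaging are installed; tinygemm_eligible is a hypothetical helper, not optimum-quanto API.

import torch
from packaging import version

# Hypothetical helper mirroring the TinyGemm eligibility gate for bfloat16 scales:
# CPU always qualifies; CUDA needs toolkit >= 12.1 and compute capability >= 8.
def tinygemm_eligible(device: torch.device) -> bool:
    if device.type == "cpu":
        return True
    return (
        device.type == "cuda"
        and torch.version.cuda is not None
        and version.parse(torch.version.cuda).release >= (12, 1)
        and torch.cuda.get_device_capability(device)[0] >= 8
    )

# Example: an A100 with a CUDA 12.x build reports capability (8, 0) and qualifies,
# while a CPU tensor always takes the TinyGemm path for bfloat16 scales.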
