From 476a9dd3455042fd96339fb77e347d447fbb6d89 Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Fri, 20 Sep 2024 15:00:22 +0000
Subject: [PATCH] feat(qbits): prioritize Marlin int4 kernel over AWQ

---
 optimum/quanto/tensor/weights/qbits.py | 58 ++++++++++++++++----------
 1 file changed, 36 insertions(+), 22 deletions(-)

diff --git a/optimum/quanto/tensor/weights/qbits.py b/optimum/quanto/tensor/weights/qbits.py
index 3afce3f5..644ac327 100644
--- a/optimum/quanto/tensor/weights/qbits.py
+++ b/optimum/quanto/tensor/weights/qbits.py
@@ -92,32 +92,46 @@ def create(qtype, axis, group_size, size, stride, data, scale, shift, requires_g
             a `WeightQBitsTensor` (can be a subclass).
         """
         from .awq import AWQWeightQBitsTensor
+        from .marlin import MarlinInt4WeightQBitsTensor
         from .tinygemm import TinyGemmWeightQBitsTensor

-        if (
-            qtype == qint4
-            and size[0] >= 128  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
-            and scale.dtype == torch.float16
-            and axis == 0
-            and group_size == 128
-            and len(size) == 2
-            and (data.device.type == "cuda" and torch.version.cuda)
-            and torch.cuda.get_device_capability(data.device)[0] >= 8
-        ):
-            if type(data) is PackedTensor:
-                data = data.unpack()
-            return AWQWeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
-        if qtype == qint4 and scale.dtype == torch.bfloat16 and axis == 0 and group_size == 128 and len(size) == 2:
-            if data.device.type == "cpu" or (
-                (data.device.type == "cuda" and torch.version.cuda)
-                and version.parse(torch.version.cuda).release >= (12, 1)
+        if qtype == qint4 and axis == 0 and group_size == 128 and len(size) == 2:
+            if (
+                scale.dtype == torch.float16
+                and data.device.type == "cuda" and torch.cuda.get_device_capability(data.device)[0] >= 8
             ):
-                if type(data) is PackedTensor:
-                    data = data.unpack()
-                return TinyGemmWeightQBitsTensor(
-                    qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad
-                )
+                out_features, in_features = size
+                # Marlin kernel uses two configurations for defining thread blocks
+                # along n (out_features) and k (in_features):
+                # - thread_n = 128, thread_k = 128
+                # - thread_n = 256, thread_k = 64
+                # This means that:
+                # - in_features must be divisible by 128 (and thus by 64 also)
+                # - out_features must be divisible by 256 (and thus by 128 also)
+                if in_features % 128 == 0 and out_features % 256 == 0:
+                    if type(data) is PackedTensor:
+                        data = data.unpack()
+                    return MarlinInt4WeightQBitsTensor(
+                        qtype, axis, group_size, size, stride, data, scale, shift, requires_grad
+                    )
+                if size[0] >= 128:  # FIXME Workaround AWQ GEMM crash (GEMV might work for short inputs)
+                    if type(data) is PackedTensor:
+                        data = data.unpack()
+                    return AWQWeightQBitsTensor(
+                        qtype, axis, group_size, size, stride, data, scale, shift, requires_grad
+                    )
+            if scale.dtype == torch.bfloat16:
+                if data.device.type == "cpu" or (
+                    data.device.type == "cuda"
+                    and version.parse(torch.version.cuda).release >= (12, 1)
+                    and torch.cuda.get_device_capability(data.device)[0] >= 8
+                ):
+                    if type(data) is PackedTensor:
+                        data = data.unpack()
+                    return TinyGemmWeightQBitsTensor(
+                        qtype, axis, group_size, size, stride, data, (scale, shift), requires_grad
+                    )
         return WeightQBitsTensor(qtype, axis, group_size, size, stride, data, scale, shift, requires_grad)
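
Note (not part of the patch): a minimal sketch of the shape constraint enforced by the new Marlin branch, to illustrate when a linear layer can take the Marlin int4 path. The helper name `supports_marlin_int4` is hypothetical; the actual dispatch in the patch additionally requires qint4 weights, float16 scales, axis == 0, group_size == 128, and a CUDA device with compute capability >= 8.

```python
# Hypothetical helper illustrating the Marlin shape constraint from this patch.
def supports_marlin_int4(out_features: int, in_features: int) -> bool:
    # Marlin tiles the GEMM with (thread_n=128, thread_k=128) or
    # (thread_n=256, thread_k=64), so k (in_features) must be a
    # multiple of 128 and n (out_features) a multiple of 256.
    return in_features % 128 == 0 and out_features % 256 == 0


# Example: a 4096 x 4096 projection qualifies, but a 1000 x 4096 head
# (out_features = 1000) would fall back to the AWQ kernel instead.
assert supports_marlin_int4(4096, 4096)
assert not supports_marlin_int4(1000, 4096)
```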