From c0f8b11ed6e0a649d32891a253acdd9e1e335a19 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Wed, 4 Dec 2024 12:06:33 +0800 Subject: [PATCH] add TorchQuantLinear --- gptqmodel/nn_modules/qlinear/qlinear_cuda.py | 325 ++++-------------- gptqmodel/nn_modules/qlinear/qlinear_torch.py | 242 +++++++++++++ gptqmodel/utils/importer.py | 6 +- 3 files changed, 313 insertions(+), 260 deletions(-) create mode 100644 gptqmodel/nn_modules/qlinear/qlinear_torch.py diff --git a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py index ded0cc86c..212c06e46 100644 --- a/gptqmodel/nn_modules/qlinear/qlinear_cuda.py +++ b/gptqmodel/nn_modules/qlinear/qlinear_cuda.py @@ -1,92 +1,42 @@ -import math - -import numpy as np import torch -import torch.nn as nn -import transformers from gptqmodel.nn_modules.qlinear import BaseQuantLinear from gptqmodel.utils.logger import setup_logger +from gptqmodel.nn_modules.qlinear.qlinear_torch import TorchQuantLinear logger = setup_logger() +cuda_import_exception = None try: import gptqmodel_cuda_64 import gptqmodel_cuda_256 _gptqmodel_cuda_available = True -except ImportError: +except ImportError as e: + cuda_import_exception = e logger.warning("CUDA extension not installed.") gptqmodel_cuda_256 = None gptqmodel_cuda_64 = None _gptqmodel_cuda_available = False -class CudaQuantLinear(BaseQuantLinear): + +class CudaQuantLinear(TorchQuantLinear): SUPPORTS_BITS = [2, 3, 4, 8] def __init__( - self, - bits: int, - group_size: int, - sym: bool, - desc_act: bool, - infeatures: int, - outfeatures: int, - bias: bool, - kernel_switch_threshold=128, - weight_dtype=torch.float16, - **kwargs, + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + infeatures: int, + outfeatures: int, + bias: bool, + kernel_switch_threshold=128, + weight_dtype=torch.float16, + **kwargs, ): - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs) - global _gptqmodel_cuda_available - - self.infeatures = infeatures - self.outfeatures = outfeatures - self.bits = bits - self.group_size = group_size if group_size != -1 else infeatures - self.maxq = 2**self.bits - 1 - - self.register_buffer( - "qweight", - torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(infeatures / self.group_size), - outfeatures // 32 * self.bits, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), - dtype=weight_dtype, - ), - ) - self.register_buffer( - "g_idx", - torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), - ) - if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype)) - else: - self.bias = None - - # is performed by unpacking the weights and using torch.matmul - if self.bits in [2, 4, 8]: - self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0) - elif self.bits == 3: - self.wf = torch.tensor( - [ - [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], - [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], - [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], - ], - dtype=torch.int32, - ).reshape(1, 3, 12) + super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, + outfeatures=outfeatures, **kwargs) self.kernel_switch_threshold = kernel_switch_threshold self.gptqmodel_cuda_available = _gptqmodel_cuda_available @@ -97,204 +47,61 @@ def __init__( if infeatures % 64 != 0 or outfeatures % 64 != 0: self.gptqmodel_cuda_available = False - def post_init(self): - pass - - def pack(self, linear, scales, zeros, g_idx=None): - W = linear.weight.data.clone() - if isinstance(linear, nn.Conv2d): - W = W.flatten(1) - if isinstance(linear, transformers.pytorch_utils.Conv1D): - W = W.t() - - self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx - - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() - scale_zeros = zeros * scales - self.scales = scales.clone().to(dtype=linear.weight.dtype) - if linear.bias is not None: - self.bias = linear.bias.clone().to(dtype=linear.weight.dtype) - - intweight = [] - for idx in range(self.infeatures): - intweight.append( - torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[ - :, None - ] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = intweight.numpy().astype(np.uint32) - - i = 0 - row = 0 - qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) - while row < qweight.shape[0]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qweight[row] |= intweight[j] << (self.bits * (j - i)) - i += 32 // self.bits - row += 1 - elif self.bits == 3: - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i)) - i += 10 - qweight[row] |= intweight[i] << 30 - row += 1 - qweight[row] |= (intweight[i] >> 2) & 1 - i += 1 - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 1) - i += 10 - qweight[row] |= intweight[i] << 31 - row += 1 - qweight[row] |= (intweight[i] >> 1) & 0x3 - i += 1 - for j in range(i, i + 10): - qweight[row] |= intweight[j] << (3 * (j - i) + 2) - i += 10 - row += 1 - - qweight = qweight.astype(np.int32) - self.qweight = torch.from_numpy(qweight) - - zeros = zeros.numpy().astype(np.uint32) - qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) - i = 0 - col = 0 - while col < qzeros.shape[1]: - if self.bits in [2, 4, 8]: - for j in range(i, i + (32 // self.bits)): - qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) - i += 32 // self.bits - col += 1 - elif self.bits == 3: - for j in range(i, i + 10): - qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) - i += 10 - qzeros[:, col] |= zeros[:, i] << 30 - col += 1 - qzeros[:, col] |= (zeros[:, i] >> 2) & 1 - i += 1 - for j in range(i, i + 10): - qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) - i += 10 - qzeros[:, col] |= zeros[:, i] << 31 - col += 1 - qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 - i += 1 - for j in range(i, i + 10): - qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) - i += 10 - col += 1 - - qzeros = qzeros.astype(np.int32) - self.qzeros = torch.from_numpy(qzeros) - def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.outfeatures,) x = x.reshape(-1, x.shape[-1]) x_dtype = x.dtype - if ( - x.device.type == "cuda" - and self.gptqmodel_cuda_available - and (self.kernel_switch_threshold == 0 or x.shape[0] < self.kernel_switch_threshold) - ): - out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) - if self.bits == 2: - self.gptqmodel_cuda.vecquant2matmul( - x.float(), - self.qweight, - out, - self.scales.float(), - self.qzeros, - self.g_idx, - ) - elif self.bits == 3: - self.gptqmodel_cuda.vecquant3matmul( - x.float(), - self.qweight, - out, - self.scales.float(), - self.qzeros, - self.g_idx, - ) - elif self.bits == 4: - self.gptqmodel_cuda.vecquant4matmul( - x.float(), - self.qweight, - out, - self.scales.float(), - self.qzeros, - self.g_idx, - ) - elif self.bits == 8: - self.gptqmodel_cuda.vecquant8matmul( - x.float(), - self.qweight, - out, - self.scales.float(), - self.qzeros, - self.g_idx, - ) - else: - if self.wf.device != self.qzeros.device: - self.wf = self.wf.to(self.qzeros.device) - if self.bits in [2, 4, 8]: - zeros = torch.bitwise_right_shift( - torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), - self.wf.unsqueeze(0), - ).to(torch.int16 if self.bits == 8 else torch.int8) - zeros = torch.bitwise_and(zeros, (2**self.bits) - 1) + if x.device.type != "cuda": + raise NotImplementedError(f"Unable to use cuda kernel. x.device.type is {x.device.type}") - zeros = zeros.reshape(self.scales.shape) - - weight = torch.bitwise_right_shift( - torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), - self.wf.unsqueeze(-1), - ).to(torch.int16 if self.bits == 8 else torch.int8) - weight = torch.bitwise_and(weight, (2**self.bits) - 1) - elif self.bits == 3: - zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand( - -1, -1, -1, 12 - ) - zeros = zeros >> self.wf.unsqueeze(0) - zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4) - zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6) - zeros = zeros & 0x7 - zeros = torch.cat( - [zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], - dim=2, - ) - - zeros = zeros.reshape(self.scales.shape) + if not self.gptqmodel_cuda_available: + raise ValueError( + f"Trying to use the cuda backend, but could not import the C++/CUDA dependencies with the following error: {cuda_import_exception}" + ) - weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand( - -1, -1, 12, -1 - ) - weight = (weight >> self.wf.unsqueeze(-1)) & 0x7 - weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4) - weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6) - weight = weight & 0x7 - weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1) + if self.kernel_switch_threshold != 0 and x.shape[0] >= self.kernel_switch_threshold: + raise ValueError( + f"Trying to use the cuda backend, x.shape[0] is {x.shape[0]}, x.shape[0] cannot be greater than kernel_switch_threshold{self.kernel_switch_threshold}" + ) - weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) - num_itr = self.g_idx.shape[0] // x.shape[-1] - if num_itr == 1: - weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]) - else: - num_dim = self.g_idx.shape[0] // num_itr - weights = [] - for i in range(num_itr): - scale_i = self.scales[:, i * num_dim : (i + 1) * num_dim] - weight_i = weight[:, i * num_dim : (i + 1) * num_dim] - zeros_i = zeros[:, i * num_dim : (i + 1) * num_dim] - g_idx_i = self.g_idx[i * num_dim : (i + 1) * num_dim] - weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) - weights = torch.cat(weights, dim=1) - out = torch.matmul(x, weights) + out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + if self.bits == 2: + self.gptqmodel_cuda.vecquant2matmul( + x.float(), + self.qweight, + out, + self.scales.float(), + self.qzeros, + self.g_idx, + ) + elif self.bits == 3: + self.gptqmodel_cuda.vecquant3matmul( + x.float(), + self.qweight, + out, + self.scales.float(), + self.qzeros, + self.g_idx, + ) + elif self.bits == 4: + self.gptqmodel_cuda.vecquant4matmul( + x.float(), + self.qweight, + out, + self.scales.float(), + self.qzeros, + self.g_idx, + ) + elif self.bits == 8: + self.gptqmodel_cuda.vecquant8matmul( + x.float(), + self.qweight, + out, + self.scales.float(), + self.qzeros, + self.g_idx, + ) out = out.to(x_dtype) out = out.reshape(out_shape) out = out + self.bias if self.bias is not None else out diff --git a/gptqmodel/nn_modules/qlinear/qlinear_torch.py b/gptqmodel/nn_modules/qlinear/qlinear_torch.py new file mode 100644 index 000000000..1b6b19e9d --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/qlinear_torch.py @@ -0,0 +1,242 @@ +import math + +import numpy as np +import torch +import torch.nn as nn +import transformers + +from gptqmodel.models._const import DEVICE +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from gptqmodel.utils.logger import setup_logger + +logger = setup_logger() + +class TorchQuantLinear(BaseQuantLinear): + SUPPORTS_BITS = [2, 3, 4, 8] + SUPPORTS_DEVICES = [] # empty means all devices are supported. + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + infeatures: int, + outfeatures: int, + bias: bool, + weight_dtype=torch.float16, + **kwargs, + ): + super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, **kwargs) + + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.group_size = group_size if group_size != -1 else infeatures + self.maxq = 2**self.bits - 1 + + self.register_buffer( + "qweight", + torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32), + ) + self.register_buffer( + "qzeros", + torch.zeros( + ( + math.ceil(infeatures / self.group_size), + outfeatures // 32 * self.bits, + ), + dtype=torch.int32, + ), + ) + self.register_buffer( + "scales", + torch.zeros( + (math.ceil(infeatures / self.group_size), outfeatures), + dtype=weight_dtype, + ), + ) + self.register_buffer( + "g_idx", + torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), + ) + if bias: + self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype)) + else: + self.bias = None + + # is performed by unpacking the weights and using torch.matmul + if self.bits in [2, 4, 8]: + self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0) + elif self.bits == 3: + self.wf = torch.tensor( + [ + [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0], + [0, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31], + [0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0], + ], + dtype=torch.int32, + ).reshape(1, 3, 12) + + def post_init(self): + pass + + def pack(self, linear, scales, zeros, g_idx=None): + W = linear.weight.data.clone() + if isinstance(linear, nn.Conv2d): + W = W.flatten(1) + if isinstance(linear, transformers.pytorch_utils.Conv1D): + W = W.t() + + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().to(dtype=linear.weight.dtype) + if linear.bias is not None: + self.bias = linear.bias.clone().to(dtype=linear.weight.dtype) + + intweight = [] + for idx in range(self.infeatures): + intweight.append( + torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[ + :, None + ] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + + i = 0 + row = 0 + qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i)) + i += 10 + qweight[row] |= intweight[i] << 30 + row += 1 + qweight[row] |= (intweight[i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 1) + i += 10 + qweight[row] |= intweight[i] << 31 + row += 1 + qweight[row] |= (intweight[i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qweight[row] |= intweight[j] << (3 * (j - i) + 2) + i += 10 + row += 1 + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + elif self.bits == 3: + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i)) + i += 10 + qzeros[:, col] |= zeros[:, i] << 30 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 2) & 1 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1) + i += 10 + qzeros[:, col] |= zeros[:, i] << 31 + col += 1 + qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3 + i += 1 + for j in range(i, i + 10): + qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2) + i += 10 + col += 1 + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x: torch.Tensor): + out_shape = x.shape[:-1] + (self.outfeatures,) + x = x.reshape(-1, x.shape[-1]) + x_dtype = x.dtype + if self.wf.device != self.qzeros.device: + self.wf = self.wf.to(self.qzeros.device) + + if self.bits in [2, 4, 8]: + zeros = torch.bitwise_right_shift( + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), + self.wf.unsqueeze(0), + ).to(torch.int16 if self.bits == 8 else torch.int8) + zeros = torch.bitwise_and(zeros, (2**self.bits) - 1) + + zeros = zeros.reshape(self.scales.shape) + + weight = torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), + self.wf.unsqueeze(-1), + ).to(torch.int16 if self.bits == 8 else torch.int8) + weight = torch.bitwise_and(weight, (2**self.bits) - 1) + elif self.bits == 3: + zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1).expand( + -1, -1, -1, 12 + ) + zeros = zeros >> self.wf.unsqueeze(0) + zeros[:, :, 0, 10] = (zeros[:, :, 0, 10] & 0x3) | ((zeros[:, :, 1, 0] << 2) & 0x4) + zeros[:, :, 1, 11] = (zeros[:, :, 1, 11] & 0x1) | ((zeros[:, :, 2, 0] << 1) & 0x6) + zeros = zeros & 0x7 + zeros = torch.cat( + [zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], + dim=2, + ) + + zeros = zeros.reshape(self.scales.shape) + + weight = self.qweight.reshape(self.qweight.shape[0] // 3, 3, 1, self.qweight.shape[1]).expand( + -1, -1, 12, -1 + ) + weight = (weight >> self.wf.unsqueeze(-1)) & 0x7 + weight[:, 0, 10] = (weight[:, 0, 10] & 0x3) | ((weight[:, 1, 0] << 2) & 0x4) + weight[:, 1, 11] = (weight[:, 1, 11] & 0x1) | ((weight[:, 2, 0] << 1) & 0x6) + weight = weight & 0x7 + weight = torch.cat([weight[:, 0, :11], weight[:, 1, 1:12], weight[:, 2, 1:11]], dim=1) + + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + num_itr = self.g_idx.shape[0] // x.shape[-1] + if num_itr == 1: + weights = self.scales[self.g_idx.long()] * (weight - zeros[self.g_idx.long()]) + else: + num_dim = self.g_idx.shape[0] // num_itr + weights = [] + for i in range(num_itr): + scale_i = self.scales[:, i * num_dim : (i + 1) * num_dim] + weight_i = weight[:, i * num_dim : (i + 1) * num_dim] + zeros_i = zeros[:, i * num_dim : (i + 1) * num_dim] + g_idx_i = self.g_idx[i * num_dim : (i + 1) * num_dim] + weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()])) + weights = torch.cat(weights, dim=1) + out = torch.matmul(x, weights) + out = out.to(x_dtype) + out = out.reshape(out_shape) + out = out + self.bias if self.bias is not None else out + return out + + +__all__ = ["TorchQuantLinear"] diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index cd6a451fa..99632c204 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -10,6 +10,7 @@ from ..nn_modules.qlinear.qlinear_ipex import IPEXQuantLinear from ..nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear from ..nn_modules.qlinear.qlinear_tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear +from ..nn_modules.qlinear.qlinear_torch import TorchQuantLinear from ..quantization import FORMAT from ..utils.logger import setup_logger @@ -23,6 +24,7 @@ BACKEND.CUDA: [CudaQuantLinear], BACKEND.BITBLAS: [BitBLASQuantLinear], BACKEND.IPEX: [IPEXQuantLinear], + BACKEND.TORCH: [TorchQuantLinear], }) format_dict = { @@ -129,5 +131,7 @@ def select_quant_linear( raise ValueError("IPEX/CPU requires minimum avx512_vnni support.") return IPEXQuantLinear + elif backend == BACKEND.TORCH: + return TorchQuantLinear else: - return CudaQuantLinear + return TorchQuantLinear