Commit da16e05

Add vectorized memory access to the int4 linear path to improve inference speed.
cgli authored and TylunasLi committed Apr 19, 2024
1 parent f56a048 commit da16e05
Showing 1 changed file with 52 additions and 5 deletions.
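Editorial background (not part of the original commit message): on builds compiled with CUDA_NO_TENSOR_CORE, the int4-to-fp16 dequantization kernel previously read and wrote one element per thread. With this change each thread fetches four packed int4 bytes with a single 32-bit load and writes eight half values with a single 128-bit store (ST128_FP16_COUNT = 8 because one 128-bit store carries eight fp16 values), so far fewer memory transactions are issued for the same data. A minimal sketch of the union trick follows; demo_half8 and SquareAsHalf8 are hypothetical names for illustration, not from the repository:

```cuda
#include <cuda_fp16.h>

// A 16-byte-aligned union gives two views of the same 128 bits:
// a uint4 for one vectorized load/store, and eight half lanes
// for per-element work in registers.
typedef union __align__(16) _demo_half_8 {
    uint4 in;      // 128-bit view, moved with a single vector instruction
    half out[8];   // per-lane fp16 view
    __device__ _demo_half_8() { }  // half's constructors would otherwise
                                   // delete the union's default constructor
} demo_half8;

// Hypothetical demo kernel: squares len halfs, 8 per thread.
// Assumes src/dst are 16-byte aligned and len is a multiple of 8.
__global__ void SquareAsHalf8(const half *src, half *dst, int len) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (idx * 8 + 7 < len) {
        demo_half8 buf;
        // one 128-bit load instead of eight 16-bit loads
        buf.in = reinterpret_cast<const uint4 *>(src)[idx];
        for (int i = 0; i < 8; i++) {
            float v = __half2float(buf.out[i]);
            buf.out[i] = __float2half(v * v);
        }
        // one 128-bit store instead of eight 16-bit stores
        reinterpret_cast<uint4 *>(dst)[idx] = buf.in;
    }
}
```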
57 changes: 52 additions & 5 deletions src/devices/cuda/fastllm-cuda.cu
@@ -43,6 +43,15 @@ typedef union __align__(16) _union_half_4 {
    }
} union_half4;

+typedef union __align__(16) _union_half_8 {
+    uint4 in;
+    half out[8];
+    half2 out2[4];
+    __device__ _union_half_8() {
+        // Do nothing
+    }
+} union_half8;

const size_t ST128_FP16_COUNT = 8;

static std::map<int, cublasHandle_t> s_fastllmCublasHandleMap;
@@ -134,13 +143,46 @@ __global__ void FastllmCudaInt4Group2HalfKernel(uint8_t* a, float *scales, float

__global__ void FastllmCudaInt42HalfKernel(uint8_t* a, float *scales, float *mins, half *b, int len, int per) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
+#ifdef CUDA_NO_TENSOR_CORE
+    float2 scalesBuffer;
+    float2 minBuffer;
+    int threshold = ST128_FP16_COUNT;
+    for (int index = idx * ST128_FP16_COUNT; index < len; index += (gridDim.x * blockDim.x) * ST128_FP16_COUNT) {
+        int startIdx = index / per;
+        int endIdx = (index + ST128_FP16_COUNT - 1) / per;
+        scalesBuffer.x = scalesBuffer.y = __ldg(scales + startIdx);
+        minBuffer.x = minBuffer.y = __ldg(mins + startIdx);
+        if (endIdx > startIdx) {
+            threshold = (idx + ST128_FP16_COUNT - 1) % per;
+            scalesBuffer.y = __ldg(scales + endIdx);
+            minBuffer.y = __ldg(mins + endIdx);
+        }
+        // Load
+        union_char4 aBuffer;
+        union_half8 bBuffer;
+        aBuffer.in = *reinterpret_cast<const uint32_t *>(a + index / 2);
+        // Process
+        for (int i = 0; i < ST128_FP16_COUNT / 2; i++) {
+            if (index + i * 2 + 1 < len) {
+                float scale = i * 2 < threshold ? scalesBuffer.x : scalesBuffer.y;
+                float min = i * 2 < threshold ? minBuffer.x : minBuffer.y;
+                bBuffer.out[i * 2] = __float2half(scale * (aBuffer.out[i] >> 4) + min);
+                bBuffer.out[i * 2 + 1] = __float2half(scale * (aBuffer.out[i] & 0xF) + min);
+            }
+            // if (a[index + i] != aBuffer.out[i] && index < 100)
+            //     printf("%d - %d : %d\n", index + i, a[index + i], aBuffer.out[i]);
+        }
+        reinterpret_cast<uint4 *>(b)[idx] = bBuffer.in;
+    }
+#else
    if (idx < len) {
        if (idx % 2 == 1) {
            b[idx] = __float2half(scales[idx / per] * (a[idx / 2] & 0xF) + mins[idx / per]);
        } else {
            b[idx] = __float2half(scales[idx / per] * (a[idx / 2] >> 4) + mins[idx / per]);
        }
    }
+#endif
}

__global__ void FastllmCudaHalf2FlotaKernel(half* a, float *b, int len) {
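Editorial aside (not in the commit): both branches of FastllmCudaInt42HalfKernel above compute the same per-group affine dequantization. For output index $j$ with group size $\text{per}$, scales $s$ and minimums $m$:

$$b_j = s_{\lfloor j/\text{per}\rfloor}\, q_j + m_{\lfloor j/\text{per}\rfloor},\qquad q_j = \begin{cases} a_{\lfloor j/2\rfloor} \gg 4 & j \text{ even (high nibble)} \\ a_{\lfloor j/2\rfloor} \mathbin{\&} \texttt{0xF} & j \text{ odd (low nibble)} \end{cases}$$

The vectorized branch evaluates this for eight consecutive $j$ per thread; scalesBuffer/minBuffer hold the $(s, m)$ pairs of the two groups a chunk can touch, and the threshold comparison picks between them.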
@@ -1837,12 +1879,12 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
                                                                 len);

    len = k * m;
-    FastllmCudaInt42HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t *) weight.cudaData,
-                                                                                     cudaScales,
-                                                                                     cudaMins,
-                                                                                     cudaFp16Weight, len, m);
-
+#ifdef CUDA_NO_TENSOR_CORE
+    int gridSize = (len - 1) / (threadPerBlock * 4) + 1;
+    FastllmCudaInt42HalfKernel <<< gridSize, threadPerBlock>>>((uint8_t *) weight.cudaData,
+                                                               cudaScales, cudaMins,
+                                                               cudaFp16Weight, len, m);

    status = cublasGemmEx(fastllmCublasHandle,
                          CUBLAS_OP_T, CUBLAS_OP_N,
                          k, n, m,
@@ -1852,6 +1894,11 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
                          cudaOutput, CType,
                          k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
+#else
+    FastllmCudaInt42HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t *) weight.cudaData,
+                                                                                     cudaScales,
+                                                                                     cudaMins,
+                                                                                     cudaFp16Weight, len, m);

    status = cublasGemmEx(fastllmCublasHandle,
                          CUBLAS_OP_T, CUBLAS_OP_N,
                          k, n, m,
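Editorial note (not in the commit): the grid-stride loop keeps the vectorized kernel correct for any grid size, so gridSize = (len - 1) / (threadPerBlock * 4) + 1 above merely launches roughly twice as many threads as 8-element chunks; surplus threads fail the index < len check and exit. Below is a hypothetical standalone check, not from the repository, that could be compiled together with this file under -DCUDA_NO_TENSOR_CORE and compares the vectorized path against the scalar formula (group size m chosen as a multiple of 8, the assumption the vectorized path is written for):

```cuda
#include <cmath>
#include <cstdio>
#include <cuda_fp16.h>

int main() {
    const int k = 2, m = 32, len = k * m;   // m plays the role of "per"
    uint8_t a[len / 2];
    float scales[k] = {0.1f, 0.2f}, mins[k] = {-0.8f, -1.6f};
    for (int i = 0; i < len / 2; i++) a[i] = (uint8_t)(i * 37);  // arbitrary packed int4 data

    uint8_t *da; float *ds, *dm; half *db;
    cudaMalloc(&da, sizeof(a));      cudaMalloc(&ds, sizeof(scales));
    cudaMalloc(&dm, sizeof(mins));   cudaMalloc(&db, len * sizeof(half));
    cudaMemcpy(da, a, sizeof(a), cudaMemcpyHostToDevice);
    cudaMemcpy(ds, scales, sizeof(scales), cudaMemcpyHostToDevice);
    cudaMemcpy(dm, mins, sizeof(mins), cudaMemcpyHostToDevice);

    const int threadPerBlock = 256;
    FastllmCudaInt42HalfKernel <<< (len - 1) / (threadPerBlock * 4) + 1, threadPerBlock >>>
        (da, ds, dm, db, len, m);

    half b[len];
    cudaMemcpy(b, db, len * sizeof(half), cudaMemcpyDeviceToHost);
    for (int j = 0; j < len; j++) {
        // scalar reference: high nibble for even j, low nibble for odd j
        int q = (j % 2 == 0) ? (a[j / 2] >> 4) : (a[j / 2] & 0xF);
        float ref = scales[j / m] * q + mins[j / m];
        if (fabsf(__half2float(b[j]) - ref) > 1e-2f)
            printf("mismatch at %d: %f vs %f\n", j, __half2float(b[j]), ref);
    }
    return 0;
}
```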
