Commit c7f45a3
Merge pull request #449 from TylunasLi/no_tensor_core
Use vectorized memory access to improve performance on older GPU architectures (without tensor cores)
ztxz16 authored Apr 26, 2024
2 parents 4736535 + da16e05 commit c7f45a3
Showing 1 changed file with 219 additions and 19 deletions: src/devices/cuda/fastllm-cuda.cu
@@ -21,6 +21,39 @@ void showError(cudaError_t result, char const* const message, const char* const
}
}


#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])

typedef union __align__(16) {
uint2 in;
uint8_t out[8];
} union_char8;

typedef union __align__(16) {
uint32_t in;
uint8_t out[4];
} union_char4;

typedef union __align__(16) _union_half_4 {
uint2 in;
half out[4];
half2 out2[2];
__device__ _union_half_4() {
// Do nothing
}
} union_half4;

typedef union __align__(16) _union_half_8 {
uint4 in;
half out[8];
half2 out2[4];
__device__ _union_half_8() {
// Do nothing
}
} union_half8;

const size_t ST128_FP16_COUNT = 8;
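
For context: the unions above plus FETCH_FLOAT4 are the building blocks of the 128-bit vectorized accesses used throughout this patch. One ld.128/st.128 instruction moves four floats, eight halves, or a whole vector of packed int8/int4 bytes, which is what recovers memory bandwidth on pre-tensor-core GPUs. A minimal standalone sketch of the idiom (the kernel name VectorCopy is illustrative and not part of this patch; pointers are assumed 16-byte aligned and len a multiple of 4):

__global__ void VectorCopy(const float *src, float *dst, int len) {
    // Each thread copies four consecutive floats with a single
    // 128-bit load and a single 128-bit store.
    int i = (threadIdx.x + blockIdx.x * blockDim.x) * 4;
    if (i + 3 < len) {
        float4 v = *reinterpret_cast<const float4 *>(src + i);
        *reinterpret_cast<float4 *>(dst + i) = v;
    }
}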

static std::map<int, cublasHandle_t> s_fastllmCublasHandleMap;
cublasHandle_t getFastllmCublasHandle() {
int id = -1;
@@ -59,10 +92,40 @@ __global__ void FastllmCudaFloat2HalfKernel(float* a, half *b, int len) {
}

__global__ void FastllmCudaInt82HalfKernel(uint8_t* a, float *scales, uint8_t *zeros, half *b, int len, int per) {
#ifdef CUDA_NO_TENSOR_CORE
float scalesBuffer[2];
uint8_t zerosBuffer[2];
int threshold = ST128_FP16_COUNT;
int index = (threadIdx.x + blockIdx.x * blockDim.x) * ST128_FP16_COUNT;
for (int idx = index; idx < len; idx += (gridDim.x * blockDim.x) * ST128_FP16_COUNT) {
int startIdx = idx / per;
int endIdx = (idx + ST128_FP16_COUNT - 1) / per;
scalesBuffer[1] = scalesBuffer[0] = scales[startIdx];
zerosBuffer[1] = zerosBuffer[0] = zeros[startIdx];
if (endIdx > startIdx) {
            threshold = per - idx % per;
scalesBuffer[1] = scales[endIdx];
zerosBuffer[1] = zeros[endIdx];
}
        // Load
union_char8 aBuffer[2];
half bBuffer[ST128_FP16_COUNT];
aBuffer[0].in = *reinterpret_cast<const uint2 *>(a + idx);
        // Process
for (int i=0; i<ST128_FP16_COUNT; i++) {
if (idx + i < len) {
int scaleIdx = i < threshold ? 0 : 1;
bBuffer[i] = __float2half(scalesBuffer[scaleIdx] * ((float)aBuffer[0].out[i] - zerosBuffer[scaleIdx]));
}
}
reinterpret_cast<uint4 *>(b)[idx / ST128_FP16_COUNT] = *reinterpret_cast<uint4 *>(bBuffer);
}
#else
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
b[idx] = __float2half(scales[idx / per] * ((float)a[idx] - zeros[idx / per]));
}
#endif
}
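
As a worked check of the group-boundary handling above (using the corrected threshold = per - idx % per): with per = 10 and idx = 8, the vector covers elements 8..15, which straddle quantization groups 0 and 1; threshold = 10 - 8 % 10 = 2, so lanes 0-1 dequantize with scales[0]/zeros[0] and lanes 2-7 with scales[1]/zeros[1]. When per is a multiple of ST128_FP16_COUNT the vector never straddles a group boundary and both buffer entries hold the same scale.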

__global__ void FastllmCudaInt4Group2HalfKernel(uint8_t* a, float *scales, float *mins, half *b, int len, int per,
@@ -80,13 +143,46 @@ __global__ void FastllmCudaInt4Group2HalfKernel(uint8_t* a, float *scales, float

__global__ void FastllmCudaInt42HalfKernel(uint8_t* a, float *scales, float *mins, half *b, int len, int per) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
#ifdef CUDA_NO_TENSOR_CORE
float2 scalesBuffer;
float2 minBuffer;
int threshold = ST128_FP16_COUNT;
for (int index = idx * ST128_FP16_COUNT; index < len; index += (gridDim.x * blockDim.x) * ST128_FP16_COUNT) {
int startIdx = index / per;
int endIdx = (index + ST128_FP16_COUNT - 1) / per;
scalesBuffer.x = scalesBuffer.y = __ldg(scales + startIdx);
minBuffer.x = minBuffer.y = __ldg(mins + startIdx);
if (endIdx > startIdx) {
            threshold = per - index % per;
scalesBuffer.y = __ldg(scales + endIdx);
minBuffer.y = __ldg(mins + endIdx);
}
        // Load
union_char4 aBuffer;
union_half8 bBuffer;
aBuffer.in = *reinterpret_cast<const uint32_t *>(a + index / 2);
        // Process
for (int i = 0; i < ST128_FP16_COUNT / 2; i++) {
if (index + i * 2 + 1 < len) {
float scale = i * 2 < threshold ? scalesBuffer.x : scalesBuffer.y;
float min = i * 2 < threshold ? minBuffer.x : minBuffer.y;
bBuffer.out[i * 2] = __float2half(scale * (aBuffer.out[i] >> 4) + min);
bBuffer.out[i * 2 + 1] = __float2half(scale * (aBuffer.out[i] & 0xF) + min);
}
}
        reinterpret_cast<uint4 *>(b)[index / ST128_FP16_COUNT] = bBuffer.in;
}
#else
if (idx < len) {
if (idx % 2 == 1) {
b[idx] = __float2half(scales[idx / per] * (a[idx / 2] & 0xF) + mins[idx / per]);
} else {
b[idx] = __float2half(scales[idx / per] * (a[idx / 2] >> 4) + mins[idx / per]);
}
}
#endif
}

__global__ void FastllmCudaHalf2FlotaKernel(half* a, float *b, int len) {
@@ -806,25 +902,51 @@ template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvFp32Fp16Kernel2(float *A, half *B, float *C, float *bias, int m, int k) {
__shared__ float sdata[THREAD_PER_BLOCK];
unsigned int tid = threadIdx.x;
const half zero = __float2half_rn(0.0);
float4 regA;
union_half4 regB;

    // 1. Compute
int st = blockIdx.x * PART;
int end = st + PART;
for (int p = st; p < end; p++) {
sdata[tid] = 0;
const half *baseB = B + p * m;
#ifdef CUDA_NO_TENSOR_CORE
#pragma unroll
for (int i = tid*4; i < m; i += THREAD_PER_BLOCK*4) {
regA = FETCH_FLOAT4(A[i]);
regB.in = *reinterpret_cast<const uint2 *>(baseB + i);
float sum = 0.0f;
if (i < m)
sum += regA.x * __low2float(regB.out2[0]);
if (i + 1 < m)
sum += regA.y * __high2float(regB.out2[0]);
if (i + 2 < m)
sum += regA.z * __low2float(regB.out2[1]);
if (i + 3 < m)
sum += regA.w * __high2float(regB.out2[1]);
sdata[tid] += sum;
}
#else
for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
sdata[tid] += A[i] * (float)B[p * m + i];
}
#endif
__syncthreads();
float diff = 0.0f;
for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) {
if (tid < s) {
float other = sdata[tid + s] - diff;
float sumTmp = sdata[tid] + other;
diff = (sumTmp - sdata[tid]) - other;
sdata[tid] = sumTmp;
}
__syncthreads();
}

if (tid == 0) {
            C[p] = sdata[0] + __ldg(bias + p);
}
__syncthreads();
}
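
The reduction above replaces the old strided butterfly with a compensated (Kahan-style) tree sum: each thread keeps a running diff holding the low-order bits lost by its previous partial add and re-injects them on the next round, which tightens the accumulated float error for large THREAD_PER_BLOCK. A minimal sketch of the same step as a standalone helper (kahanAdd is illustrative, not part of this patch):

__device__ inline void kahanAdd(float &sum, float &comp, float value) {
    float y = value - comp;   // re-inject the bits lost by the previous add
    float t = sum + y;        // low-order bits of y may be lost here
    comp = (t - sum) - y;     // recover exactly what this add dropped
    sum = t;                  // carry the compensated running sum
}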
Expand All @@ -843,25 +965,51 @@ __global__ void FastllmGemvInt8Kernel2(float *A, uint8_t *B, float *C,
}
__syncthreads();*/

float4 regA;
union_char4 regB;

    // 2. Compute
int st = blockIdx.x * PART;
int end = st + PART;
for (int p = st; p < end; p++) {
sdata[tid] = 0;
uint8_t zero = zeros[p];
const uint8_t *baseB = B + p * m;
#ifdef CUDA_NO_TENSOR_CORE
#pragma unroll
for (int i = tid*4; i < m; i += THREAD_PER_BLOCK*4) {
regA = FETCH_FLOAT4(A[i]);
regB.in = *reinterpret_cast<const uint32_t *>(baseB + i);
float sum = 0.0f;
if (i < m)
sum += regA.x * (float)(regB.out[0] - zero);
if (i + 1 < m)
sum += regA.y * (float)(regB.out[1] - zero);
if (i + 2 < m)
sum += regA.z * (float)(regB.out[2] - zero);
if (i + 3 < m)
sum += regA.w * (float)(regB.out[3] - zero);
sdata[tid] += sum;
}
#else
for (int i = tid; i < m; i += THREAD_PER_BLOCK) {
sdata[tid] += A[i] * (B[p * m + i] - zero);
}
#endif
__syncthreads();
float diff = 0.0f;
for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) {
if (tid < s) {
float other = sdata[tid + s] - diff;
float sumTmp = sdata[tid] + other;
diff = (sumTmp - sdata[tid]) - other;
sdata[tid] = sumTmp;
}
__syncthreads();
}

if (tid == 0) {
            C[p] = sdata[0] * __ldg(scales + p) + __ldg(bias + p);
}
__syncthreads();
}
@@ -1020,6 +1168,47 @@ __global__ void FastllmGemvInt4NoZeroKernel2(float *A, uint8_t *B, float *C,
}
}

template <int THREAD_PER_BLOCK, int PART>
__global__ void FastllmGemvInt4NoZeroKernel1(float *A, uint8_t *B, float *C,
float *bias, float *scales, float *mins,
int m, int k) {
__shared__ float sdata[THREAD_PER_BLOCK];
unsigned int tid = threadIdx.x;

    // 1. Compute
int st = blockIdx.x * PART;
int end = st + PART;
for (int p = st; p < end; p++) {
sdata[tid] = 0;
const uint8_t *baseB = B + p * m / 2;
float minv = __ldg(mins + p) / __ldg(scales + p);
for (int i = tid * 2; i < m / 2; i += THREAD_PER_BLOCK * 2) {
float4 aBuffer = FETCH_FLOAT4(A[i * 2]);
uint16_t bBuffer = *reinterpret_cast<const uint16_t *>(baseB + i);
sdata[tid] += aBuffer.x * (minv + ((bBuffer >> 4) & 15)) + aBuffer.y * (minv + (bBuffer & 15));
sdata[tid] += aBuffer.z * (minv + (bBuffer >> 12)) + aBuffer.w * (minv + ((bBuffer >> 8) & 15));
}
__syncthreads();

float diff = 0.0f;
for (unsigned int s = THREAD_PER_BLOCK/2; s > 0; s >>= 1) {
if (tid < s) {
float other = sdata[tid + s] - diff;
float sumTmp = sdata[tid] + other;
diff = (sumTmp - sdata[tid]) - other;
sdata[tid] = sumTmp;
}
__syncthreads();
}
if (tid == 0) {
C[p] = sdata[0] * scales[p] + bias[p];
}
__syncthreads();
}
}
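
FastllmGemvInt4NoZeroKernel1 above reads B two packed bytes (four weights) at a time: each byte holds two 4-bit weights, high nibble first, and mins[p] / scales[p] is pre-divided into minv so the multiply by scales[p] happens once, after the block reduction. A hedged sketch of the per-byte unpack under those assumptions (unpackInt4 is illustrative, not part of this patch):

__device__ inline float2 unpackInt4(uint8_t byte, float minv) {
    // High nibble is the even-indexed weight, low nibble the odd one;
    // the caller applies scales[p] once to the accumulated dot product.
    return make_float2(minv + (byte >> 4), minv + (byte & 15));
}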

template <int THREAD_PER_BLOCK>
__global__ void FastllmSplitBatchKernel(uint8_t *input, uint8_t **outputs, int outer, int channels, int inner) {
int bid = blockIdx.x / outer, oid = blockIdx.x % outer;
@@ -1416,12 +1605,13 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input, len);

len = k * m;

#ifdef CUDA_NO_TENSOR_CORE
int gridSize = (len - 1) / (threadPerBlock * ST128_FP16_COUNT) + 1;
FastllmCudaInt82HalfKernel <<< gridSize, threadPerBlock>>>((uint8_t*)weight.cudaData,
cudaScales,
cudaZeropoints,
cudaFp16Weight, len, m);

status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
Expand All @@ -1431,6 +1621,11 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#else
FastllmCudaInt82HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t*)weight.cudaData,
cudaScales,
cudaZeropoints,
cudaFp16Weight, len, m);

status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
@@ -1684,12 +1879,12 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
len);

len = k * m;

#ifdef CUDA_NO_TENSOR_CORE
int gridSize = (len - 1) / (threadPerBlock * 4) + 1;
FastllmCudaInt42HalfKernel <<< gridSize, threadPerBlock>>>((uint8_t *) weight.cudaData,
cudaScales, cudaMins,
cudaFp16Weight, len, m);

status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
Expand All @@ -1699,6 +1894,11 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
cudaOutput, CType,
k, ComputeType, static_cast<cublasGemmAlgo_t>(CUBLAS_GEMM_DEFAULT));
#else
FastllmCudaInt42HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((uint8_t *) weight.cudaData,
cudaScales,
cudaMins,
cudaFp16Weight, len, m);

status = cublasGemmEx(fastllmCublasHandle,
CUBLAS_OP_T, CUBLAS_OP_N,
k, n, m,
@@ -1730,7 +1930,7 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
#endif
} else {
for (int i = 0; i < n; i++) {
                FastllmGemvInt4NoZeroKernel1<256, 1> <<< k, 256 >>>(cudaInput + i * m,
(uint8_t *) weight.cudaData,
cudaOutput + i * k,
cudaBiasData,