Chunked processing for long text
黄宇扬 committed Jun 20, 2024
1 parent b5a3902 commit 28e79e3
Showing 3 changed files with 233 additions and 261 deletions.
27 changes: 13 additions & 14 deletions src/devices/cuda/fastllm-cuda.cu
@@ -128,7 +128,7 @@ template <int BN, int BM, int BK>
__global__ void HalfFC(
half * __restrict__ a, half * __restrict__ b, half * __restrict__ c,
const int N, const int M, const int K,
- half scale) {
+ half scale, const int base) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 // support tensor core
int tid = threadIdx.x;
int bx = blockIdx.x;
@@ -140,7 +140,7 @@ __global__ void HalfFC(
int wrap0 = wid >> 1;
int wrap1 = wid & 1;

- if (stN + BN <= stK) {
+ if (base + stN + BN <= stK) {
return;
}

@@ -199,7 +199,7 @@ __global__ void HalfFC(
__syncthreads();

for (int i = 0; i < BN; i++) {
- if (stN + i < stK + tid) {
+ if (base + stN + i < stK + tid) {
cur[i][tid] = (half)0;
}
}
@@ ... @@
#endif
}

- void GpuQK(half *q, half *k, half *qk, int qlen, int klen, int dim, float scale) {
+ void GpuQK(half *q, half *k, half *qk, int qlen, int klen, int dim, float scale, int base) {
const int BQ = 128, BK = 128, DIM = 128;
dim3 blockDim(128);
int BX = (qlen + BQ - 1) / BQ;
int BY = (klen + BK - 1) / BK;
dim3 gridDim(BX, BY);
- HalfFC <BQ, DIM, BK> <<<gridDim, blockDim>>> (q, k, qk, qlen, dim, klen, (half)scale);
+ HalfFC <BQ, DIM, BK> <<<gridDim, blockDim>>> (q, k, qk, qlen, dim, klen, (half)scale, base);
}
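The new `base` argument tells the kernel where the query block starts inside the full key sequence: query row `i` corresponds to absolute position `base + i`, so any key column beyond `base + i` is masked (written as zero), and a tile whose keys all lie beyond its queries exits early via the `base + stN + BN <= stK` check. Below is a minimal CPU sketch of what the masked QK product computes, with hypothetical names and plain `float` buffers rather than the kernel's `half` tiles:

```cpp
// CPU reference for the masked QK product that HalfFC/GpuQK computes on the GPU.
// Assumptions: q is [qlen, dim], k is [klen, dim], qk is [qlen, klen], all row-major;
// "base" is the offset of the first query row inside the full key sequence
// (base = klen - qlen for a chunk appended to an existing KV cache).
#include <cstddef>

void RefQKWithOffsetCausalMask(const float *q, const float *k, float *qk,
                               int qlen, int klen, int dim, float scale, int base) {
    for (int i = 0; i < qlen; i++) {
        for (int j = 0; j < klen; j++) {
            if (j > base + i) {
                // key j lies in the future of query (base + i); the kernel writes (half)0 here
                qk[(size_t)i * klen + j] = 0.0f;
                continue;
            }
            float sum = 0.0f;
            for (int d = 0; d < dim; d++) {
                sum += q[(size_t)i * dim + d] * k[(size_t)j * dim + d];
            }
            qk[(size_t)i * klen + j] = sum * scale;
        }
    }
}
```

For ordinary prefill `base` is 0 and this reduces to the usual causal triangle; for a chunk appended to an existing cache, `base = klen - qlen` shifts the diagonal so each new query still sees every earlier cached key.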

template <int THREAD_PER_BLOCK>
@@ -786,9 +786,9 @@ __global__ void FastllmSoftmaxKernelInner1(half* input, half *output, int outer,
}

template <int THREAD_PER_BLOCK, typename T>
- __global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels) {
+ __global__ void FastllmSoftmaxKernelInner1WithCausalMask(T* input, T *output, int outer, int channels, int base) {
int o = blockIdx.x;
- FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, o + 1);
+ FastllmSoftmaxKernelInner1Func <THREAD_PER_BLOCK> (input + o * channels, output + o * channels, o + base + 1);
}
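Accordingly, the causal-mask softmax now normalizes row `o` over its first `o + base + 1` scores instead of `o + 1`: that row's query sits at absolute position `o + base` and may attend to every key up to and including itself. A small single-row reference sketch (hypothetical names, `float` for clarity); in the actual kernels the masked tail is already zero from the preceding mask/QK step, which the explicit zeroing below stands in for:

```cpp
// Reference for the per-row softmax length used by
// FastllmSoftmaxKernelInner1WithCausalMask: row o normalizes o + base + 1 scores.
#include <algorithm>
#include <cmath>

void RefSoftmaxRowWithCausalMask(float *row, int channels, int o, int base) {
    int valid = std::min(o + base + 1, channels);  // keys visible to this query row
    float maxVal = row[0];
    for (int j = 1; j < valid; j++) maxVal = std::max(maxVal, row[j]);
    float sum = 0.0f;
    for (int j = 0; j < valid; j++) {
        row[j] = std::exp(row[j] - maxVal);
        sum += row[j];
    }
    for (int j = 0; j < valid; j++) row[j] /= sum;
    for (int j = valid; j < channels; j++) row[j] = 0.0f;  // masked tail contributes nothing to P·V
}
```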

template <int THREAD_PER_BLOCK>
@@ -3005,7 +3005,7 @@ bool FastllmCudaAttention(const fastllm::Data &q, const fastllm::Data &k, const
return true;
}

- if (q1 >= 1024) {
+ if (q1 >= 1024 || (q1 > 1 && q1 != k1)) {
float *qk = (float *) FastllmCudaMalloc(q1 * k1 * sizeof(float));
float beta = 0, one = 1;
auto fastllmCublasHandle = getFastllmCublasHandle();
@@ ... @@

if (batch == 1 && maskd == nullptr) {
CausalMask<256, float> <<<q1, 256>>>(qk, 0, q1, k1);
- FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1);
+ FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1);
} else {
if (maskd) {
SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, -10000, q1 * k1);
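The guard widens from `q1 >= 1024` to also catch `q1 > 1 && q1 != k1`, i.e. chunked prefill, where a block of `q1` new queries attends to `k1` cached keys with `k1 > q1`; the causal diagonal is then shifted by `base = k1 - q1`, which is the value passed to the softmax here (and to `GpuQK` in the half-precision path below). A tiny illustrative check of the per-row visible-key count, using the 8192/2048 chunk sizes introduced in `basellm.cpp`:

```cpp
// Visible-key count per query row in a chunk: row o attends to o + (k1 - q1) + 1 keys.
// Example: k1 = 10240 cached keys, q1 = 2048 new queries -> base = 8192:
//   row 0    sees 8193 keys  (positions 0..8192),
//   row 2047 sees 10240 keys (the whole prefix, ending at itself).
#include <cassert>

int VisibleKeys(int o, int q1, int k1) {
    int base = k1 - q1;   // offset of the chunk inside the full sequence
    return o + base + 1;
}

int main() {
    assert(VisibleKeys(0, 2048, 10240) == 8193);
    assert(VisibleKeys(2047, 2048, 10240) == 10240);
    return 0;
}
```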
@@ -3136,8 +3136,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
int maskStride = (mask.dims.size() == 3 ? mask.strides[0] : mask.Count(0));

half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale);
- 
- if (q1 >= 1024) {
+ if (q1 >= 1024 || (q1 > 1 && q1 != k1)) {
int alignQ1 = q1, alignK1 = k1;
bool useFastAttn = getCudaInfos()->hasTensorCore && batch == 1 && (q2 == 128 && v2 == 128);
if (useFastAttn) {
@@ ... @@
//DeviceSync();
//auto st = std::chrono::system_clock::now();
if (useFastAttn) {
- GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale);
- FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1);
+ GpuQK(qd + i * q.Count(1), kd + (i / group) * k.Count(1), qk, alignQ1, alignK1, q2, scale, k1 - q1);
+ FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, alignK1, k1 - q1);
status = cublasHgemmStridedBatched(fastllmCublasHandle,
CUBLAS_OP_N, CUBLAS_OP_N,
v2, q1, alignK1, &one,
@@ -3193,7 +3192,7 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co

if (batch == 1 && maskd == nullptr) {
CausalMask<256, half> <<<q1, 256>>>(qk, __float2half_rn(0), q1, k1);
- FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1);
+ FastllmSoftmaxKernelInner1WithCausalMask<128> <<< q1, 128 >>>(qk, qk, q1, k1, k1 - q1);
} else {
SimpleMask<256> <<< (q1 * k1 / 256) + 1, 256>>>(qk, maskd + (i / (q0 / batch)) * maskStride, __float2half_rn(-10000), q1 * k1);
int outer = q1;
14 changes: 14 additions & 0 deletions src/models/basellm.cpp
@@ -769,6 +769,20 @@ auto st = std::chrono::system_clock::now();
ret = model->ForwardBatch(seqLens.size(), inputIds, attentionMasks,
positionIds, seqLens, pastKeyValues, generationConfigs,
tokensManager, &logits);
} else {
+     if (seqLens[0] > 8192) {
+         int len = seqLens[0];
+         int first = 8192, part = 2048;
+         for (int st = 0; st < len; ) {
+             int curLen = std::min(st == 0 ? first : part, len - st);
+             Data curInput, curPositionIds;
+             Split(inputIds, 1, st, st + curLen, curInput);
+             Split(*positionIds[0], 1, st, st + curLen, curPositionIds);
+ 
+             ret = std::vector <int> {model->Forward(curInput, Data(), curPositionIds,
+                 *pastKeyValue1, generationConfigs[0], tokensManager, logits[0])};
+             st += curLen;
+         }
+     } else {
ret = std::vector <int> {model->Forward(inputIds,
attentionMasks[0] == nullptr ? Data() : *attentionMasks[0],
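The new branch in `basellm.cpp` is where the chunking happens: a prompt longer than 8192 tokens is fed through `Forward` as a first chunk of 8192 tokens followed by 2048-token chunks, so the KV cache (`pastKeyValue1`) accumulates across iterations and only the final chunk's sampled token ends up in `ret`; capping later chunks at 2048 queries presumably also keeps the `q1 × k1` score matrix allocated in the CUDA attention path bounded as `k1` grows. Below is a self-contained sketch of just the chunk schedule (hypothetical names, not the fastllm API); a 12000-token prompt, for example, is split into [0, 8192), [8192, 10240), [10240, 12000):

```cpp
// Minimal sketch of the chunk schedule used above: a first chunk of `first` tokens,
// then `part`-token chunks until the prompt is consumed.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

std::vector<std::pair<int, int>> ChunkSchedule(int len, int first = 8192, int part = 2048) {
    std::vector<std::pair<int, int>> chunks;   // [start, end) token ranges
    for (int st = 0; st < len; ) {
        int curLen = std::min(st == 0 ? first : part, len - st);
        chunks.emplace_back(st, st + curLen);
        st += curLen;
    }
    return chunks;
}

int main() {
    for (auto [st, en] : ChunkSchedule(12000)) {
        std::printf("[%d, %d)\n", st, en);     // prints [0, 8192) [8192, 10240) [10240, 12000)
    }
    return 0;
}
```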