From bdbf5256eff991a07f8d171251768fef31344bd4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?=
Date: Tue, 4 Jun 2024 15:11:23 +0800
Subject: [PATCH] DeepSeek V2 MoE structure acceleration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/devices/cpu/cpudevice.h   |  20 +
 include/fastllm.h                 |   4 +
 include/models/deepseekv2.h       |   2 +
 src/devices/cpu/cpudevice.cpp     | 692 ++++++++++++++++++++++++------
 src/devices/tfacc/tfaccdevice.cpp |  83 +---
 src/fastllm.cpp                   |  13 +
 src/models/deepseekv2.cpp         | 129 +++---
 7 files changed, 694 insertions(+), 249 deletions(-)

diff --git a/include/devices/cpu/cpudevice.h b/include/devices/cpu/cpudevice.h
index 960b32d6..6a061e3d 100644
--- a/include/devices/cpu/cpudevice.h
+++ b/include/devices/cpu/cpudevice.h
@@ -9,6 +9,21 @@
 #include "alivethreadpool.h"
 
 namespace fastllm {
+    struct MultiThreadOnlineQuantizationOp : MultiThreadBaseOp {
+        float *input;
+        uint8_t *output;
+        LowBitConfig *configs;
+        int n, m, group, groupCnt;
+        float *inputSums, *iscales, *izeros;
+
+        MultiThreadOnlineQuantizationOp (float *input, uint8_t *output, LowBitConfig *configs, int n, int m, int group, int groupCnt,
+                                         float *inputSums, float *iscales, float *izeros) :
+                input(input), output(output), configs(configs), n(n), m(m), group(group), groupCnt(groupCnt),
+                inputSums(inputSums), iscales(iscales), izeros(izeros) {}
+
+        void Run();
+    };
+
     class CpuDevice : BaseDevice {
     public:
         CpuDevice ();
@@ -36,6 +51,11 @@ namespace fastllm {
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
     };
 
+    class CpuMergeMOE : BaseOperator {
+    protected:
+        void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
+    };
+
     class CpuEmbedding : BaseOperator {
         void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
         void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
diff --git a/include/fastllm.h b/include/fastllm.h
index 52be5645..fc052e86 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -488,6 +488,10 @@ namespace fastllm {
     void CopyKVCache(Data &oldCache, Data &newCache,
                      int oldBsStart, int newBsStart, int bs, int offset);
 
+    bool CanRunMergeMOE();
+    void MergeMOE(const Data &input, const Data &logits, std::vector <Data*> weights, std::vector <Data*> biass,
+                  float routeScale, float sharedScale, int topk, Data &output);
+
     void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
                    int group, float scale, int attentionType);
 
diff --git a/include/models/deepseekv2.h b/include/models/deepseekv2.h
index 04eded93..a37a3f59 100644
--- a/include/models/deepseekv2.h
+++ b/include/models/deepseekv2.h
@@ -96,6 +96,8 @@ namespace fastllm {
         std::string rope_scaling_type;
 
         bool mergeSwiglu = false;
+        std::vector <std::vector <Data*> > weights;
+        std::vector <std::vector <Data*> > biass;
     };
 }
 
diff --git a/src/devices/cpu/cpudevice.cpp b/src/devices/cpu/cpudevice.cpp
index 20220a01..7e4f9ed2 100644
--- a/src/devices/cpu/cpudevice.cpp
+++ b/src/devices/cpu/cpudevice.cpp
@@ -24,6 +24,7 @@ namespace fastllm {
         this->ops["ToFloat16"] = (BaseOperator*)(new CpuToFloat16());
         this->ops["ToFloat32"] = (BaseOperator*)(new CpuToFloat32());
         this->ops["Attention"] = (BaseOperator*)(new CpuAttention());
+        this->ops["MergeMOE"] = (BaseOperator*)(new CpuMergeMOE());
         this->ops["CopyKVCache"] = (BaseOperator*)(new CpuCopyKVCacheOp());
         this->ops["Embedding"] = (BaseOperator*)(new CpuEmbedding());
         this->ops["LayerNorm"] = (BaseOperator*)(new CpuLayerNormOp());
@@ -440,6 +441,460 @@ namespace fastllm {
         }
     }
 
+    void OnlineQuantization(float *inputData, std::vector <uint8_t> &uinput, std::vector <LowBitConfig> &inputConfigs,
+                            int n, int m, int group, int groupCnt,
+                            std::vector <float> &inputSums, std::vector <float> &iscales, std::vector <float> &izeros) {
+        inputConfigs.resize(n * group);
+        uinput.resize(n * m);
+        inputSums.resize(n * group);
+        iscales.resize(n * group);
+        izeros.resize(n * group);
+
+        if (n > 1) {
+            auto pool = GetAlivePool();
+            int threadNum = pool->threads.size();
+            int per = n / threadNum;
+            int cur = 0;
+            std::vector <fastllm::MultiThreadBaseOp*> ops;
+            for (int i = 0; i < threadNum; i++) {
+                int end = (i == threadNum - 1 ? n : cur + per + (cur + per * (threadNum - i) < n));
+                ops.push_back(new MultiThreadOnlineQuantizationOp(
+                        inputData + cur * m, uinput.data() + cur * m, inputConfigs.data() + cur * group,
+                        end - cur, m, group, groupCnt,
+                        inputSums.data() + cur * group, iscales.data() + cur * group, izeros.data() + cur * group));
+                cur = end;
+            }
+            for (int i = 0; i < threadNum; i++) {
+                pool->PushOp(i, ops[i]);
+            }
+            for (int i = 0; i < threadNum; i++) {
+                pool->Wait(i);
+                delete ops[i];
+            }
+        } else {
+            MultiThreadOnlineQuantizationOp(inputData, uinput.data(), inputConfigs.data(), n, m, group, groupCnt,
+                                            inputSums.data(), iscales.data(), izeros.data()).Run();
+        }
+    }
+
+    struct MultiThreadSwigluOp : MultiThreadBaseOp {
+        float *input, *output;
+        int mid, len, n, inputStride, outputStride;
+
+        MultiThreadSwigluOp (float *input, int mid, int len, float *output,
+                             int n, int inputStride, int outputStride) :
+                input(input), mid(mid), len(len), output(output),
+                n(n), inputStride(inputStride), outputStride(outputStride) {}
+
+        void Run() {
+            for (int o = 0; o < n; o++) {
+                float *cur = (float*)input + o * inputStride;
+                float *out = (float*)output + o * outputStride;
+                int i = 0;
+#ifdef __aarch64__
+                float32x4_t c1 = vdupq_n_f32(1.0f);
+                for (; i + 3 < len; i += 4) {
+                    float32x4_t vx = vld1q_f32(cur + i);
+                    float32x4_t vy = vld1q_f32(cur + i + mid);
+                    vx = vdivq_f32(vx, vaddq_f32(c1, exp_ps(vnegq_f32(vx))));
+                    vy = vmulq_f32(vx, vy);
+                    vst1q_f32(out + i, vy);
+                }
+#endif
+                for (; i < len; i++) {
+                    float x = cur[i], y = cur[i + mid];
+                    out[i] = (x / (1.0 + expf(-x))) * y;
+                }
+            }
+        }
+    };
+
+    struct MultiThreadLinearInt4GroupOp : MultiThreadBaseOp {
+        uint8_t *a, *b;
+        int32_t *c;
+        int n, m, k, kstride;
+        int *weightSums;
+        float *weightMins;
+        float *scales;
+        float *bias;
+        float *iscales, *izeros;
+        float *inputSums;
+        int group, groupCnt;
+
+        MultiThreadLinearInt4GroupOp(
+                uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int kstride,
+                int *weightSums, float *weightMins, float *scales, float *bias,
+                float *iscales, float *izeros, float *inputSums, int group, int groupCnt
+        ) :
+                a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride),
+                weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias),
+                iscales(iscales), izeros(izeros), inputSums(inputSums), group(group), groupCnt(groupCnt) {}
+
+        void Run() {
+            std::vector <float> values;
+            values.resize(group);
+
+            int block = 0;
+            for (; block < n; block++) {
+                uint8_t *weightWalk = b;
+                uint8_t *inputStart = a + block * m;
+
+                for (int i = 0; i < k; i++) {
+                    std::fill(values.begin(), values.end(), 0.0f);
+                    uint8_t *inputWalk = inputStart;
+                    float sum = 0.0;
+
+                    for (int g = 0; g < group; g++) {
+                        int st = g * groupCnt, end = std::min(m, (g + 1) * groupCnt);
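+                        // Raw per-group u4*u8 dot product is accumulated below; the per-group
+                        // scale / zero-point corrections are folded in after the group loop.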
+                        float &value = values[g];
+                        int j = st;
+#ifdef __ARM_FEATURE_DOTPROD
+                        uint8x8_t maskHigh = vdup_n_u8(0xF0);
+                        uint8x8_t maskLow = vdup_n_u8(0xF);
+                        uint32x2_t sum0 = {0, 0};
+
+                        for (; j + 15 < end; j += 16) {
+                            uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
+                            uint8x8x2_t in = vld2_u8(inputWalk + j);
+                            uint8x8_t va = vand_u8(ori, maskLow);
+                            uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
+                            sum0 = vdot_u32(sum0, va, in.val[1]);
+                            sum0 = vdot_u32(sum0, vb, in.val[0]);
+                        }
+                        value += sum0[0] + sum0[1];
+#elif defined(__aarch64__)
+                        uint8x8_t maskHigh = vdup_n_u8(0xF0);
+                        uint8x8_t maskLow = vdup_n_u8(0xF);
+                        uint32x4_t sum0 = {0, 0, 0, 0};
+
+                        for (; j + 15 < end; j += 16) {
+                            uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
+                            uint8x8x2_t in = vld2_u8(inputWalk + j);
+                            uint8x8_t va = vand_u8(ori, maskLow);
+                            uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
+                            sum0 = vpadalq_u16(sum0, vmull_u8(va, in.val[1]));
+                            sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
+                        }
+                        value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
+#elif defined(__AVX2__)
+                        value += DotU4U8(weightWalk + (i * m + st) / 2, inputWalk + st, end - st);
+                        j += (end - st);
+#endif
+                        for (; j + 1 < end; j += 2) {
+                            int id = (i * m + j) / 2;
+                            value += (weightWalk[id] >> 4) * inputWalk[j];
+                            value += (weightWalk[id] & 0xF) * inputWalk[j + 1];
+                        }
+                    }
+
+                    int g = 0;
+#ifdef __aarch64__
+                    float32x4_t vSum = vdupq_n_f32(0.0f);
+                    float32x4_t vGroupCnt = vdupq_n_f32(groupCnt);
+                    for (; g + 3 < group; g += 4) {
+                        int iid = block * group + g;
+                        int gid = i * group + g;
+                        float32x4_t vValue = vld1q_f32(values.data() + g);
+                        float32x4_t vWeightSum = vcvtq_f32_s32(vld1q_s32(weightSums + gid));
+                        float32x4_t vWeightMin = vld1q_f32(weightMins + gid);
+                        float32x4_t vScale = vld1q_f32(scales + gid);
+                        float32x4_t vIzero = vld1q_f32(izeros + iid);
+                        float32x4_t vIscale = vld1q_f32(iscales + iid);
+                        float32x4_t vInputSum = vld1q_f32(inputSums + iid);
+                        float32x4_t vMiddle = vsubq_f32(vInputSum, vmulq_f32(vIzero, vGroupCnt));
+                        vValue = vsubq_f32(vValue, vmulq_f32(vWeightSum, vIzero));
+                        vSum = vaddq_f32(vSum, vmulq_f32(vScale, vmulq_f32(vIscale, vValue)));
+                        vSum = vaddq_f32(vSum, vmulq_f32(vWeightMin, vmulq_f32(vMiddle, vIscale)));
+                    }
+                    sum += vSum[0] + vSum[1] + vSum[2] + vSum[3];
+#endif
+                    for (; g < group; g++) {
+                        int iid = block * group + g;
+                        int gid = i * group + g;
+                        int value = values[g];
+                        value -= weightSums[gid] * izeros[iid];
+                        sum += scales[gid] * iscales[iid] * value +
+                               weightMins[gid] * (inputSums[iid] - izeros[iid] * groupCnt) * iscales[iid];
+                    }
+
+                    if (group * groupCnt > m) {
+                        int iid = block * group + group - 1;
+                        int gid = i * group + group - 1;
+                        sum += weightMins[gid] * izeros[iid] * (group * groupCnt - m) * iscales[iid];
+                    }
+
+                    ((float*)c)[block * kstride + i] = sum + (bias == nullptr ? 0.0 : bias[i]);
+                }
+            }
+        }
+    };
+
+    void MultiplyInt4GroupMultiThread(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k,
+                                      int *weightSums, float *weightMins, float *scales, float *bias,
+                                      std::vector <LowBitConfig> &configs, int threadNum, int group, int groupCnt);
+    void MultiplyInt4GroupMultiThreadLaunch(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k,
+                                            int *weightSums, float *weightMins, float *scales, float *bias,
+                                            std::vector <float> &inputSums, std::vector <float> &iscales, std::vector <float> &izeros,
+                                            std::vector <LowBitConfig> &configs, int startTid, int threadNum, int group, int groupCnt,
+                                            std::vector <fastllm::MultiThreadBaseOp*> &ops, AliveThreadPool *pool);
+
+    void CpuMergeMOE::Run(const std::string &opType, const fastllm::DataDict &datas,
+                          const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
+        Data &input = *(datas.find("input")->second);
+        Data &output = *(datas.find("output")->second);
+        Data &logits = *(datas.find("logits")->second);
+        Data **weights = (Data**)(datas.find("weights")->second);
+        Data **biass = (Data**)(datas.find("biass")->second);
+        int topk = intParams.find("topk") != intParams.end() ? intParams.find("topk")->second : 1;
+        float sharedScale = floatParams.find("sharedScale") != floatParams.end() ? floatParams.find("sharedScale")->second : 1.0f;
+        float routeScale = floatParams.find("routeScale") != floatParams.end() ? floatParams.find("routeScale")->second : 1.0f;
+        output.Allocate();
+
+        if (input.dataType == DataType::FLOAT32 && weights[0]->dataType == DataType::INT4_GROUP && input.dims[0] == 1) {
+            int dimsLen = logits.dims.size();
+            int outer = logits.Count(0) / logits.Count(dimsLen - 1);
+            int channels = logits.dims[dimsLen - 1];
+
+            // Greedy top-k over the router logits; v holds (expert index + 1, weight),
+            // with index 0 reserved for the shared expert appended at the end.
+            std::vector <std::pair <float, int> > oriV;
+            for (int j = 0; j < channels; j++) {
+                oriV.push_back(std::make_pair(-((float*)logits.cpuData)[j], j));
+            }
+            sort(oriV.begin(), oriV.end());
+
+            std::vector <std::pair <int, float> > v;
+            for (int j = 0; j < topk; j++) {
+                v.push_back(std::make_pair(oriV[j].second + 1, -oriV[j].first * routeScale));
+            }
+            v.push_back(std::make_pair(0, sharedScale));
+            float *inputData = (float *) input.cpuData;
+
+            int n = input.dims[0], m = input.dims[1];
+            int group = weights[0]->group, groupCnt = weights[0]->groupCnt;
+            if (weights[0]->dataType != DataType::INT4_GROUP) {
+                group = 1;
+                groupCnt = m;
+            }
+
+            std::vector <LowBitConfig> inputConfigs;
+            std::vector <uint8_t> uinput;
+            std::vector <float> inputSums;
+            std::vector <float> iscales, izeros;
+
+            OnlineQuantization((float*)input.cpuData, uinput, inputConfigs, n, m, group, groupCnt,
+                               inputSums, iscales, izeros);
+            std::vector <float*> middles;
+            std::vector <float*> results;
+            for (int j = 0; j < v.size(); j++) {
+                int idx = v[j].first;
+                weights[idx * 2]->CalcWeightSum();
+                weights[idx * 2 + 1]->CalcWeightSum();
+                middles.push_back(new float[weights[idx * 2]->dims[0]]);
+                results.push_back(new float[weights[idx * 2 + 1]->dims[0]]);
+            }
+
+            output.Allocate(0.0f);
+            std::vector <fastllm::MultiThreadBaseOp*> ops;
+            auto *pool = GetAlivePool();
+            int threads = pool->threads.size();
+            ops.resize(threads);
+
+            std::vector <std::vector <LowBitConfig> > inputConfigsDown;
+            std::vector <std::vector <uint8_t> > uinputsDown;
+            std::vector <std::vector <float> > inputSumsDown;
+            std::vector <std::vector <float> > iscalesDown, izerosDown;
+            inputConfigsDown.resize(v.size());
+            uinputsDown.resize(v.size());
+            inputSumsDown.resize(v.size());
+            iscalesDown.resize(v.size());
+            izerosDown.resize(v.size());
+
+            for (int st = 0; st < v.size(); st++) {
+                int k = weights[v[st].first * 2]->dims[0];
+                int end = st, selSum = 1; // selSum * k outputs are processed in total
+
+                int curSum = 1;
+                for (int l = st + 1; l < v.size(); l++) {
+                    int curK = weights[v[l].first * 2]->dims[0];
+                    if (curK % k != 0) {
+                        break;
+                    }
+                    curSum += (curK / k);
+                    if (threads % curSum == 0) {
+                        end = l;
+                        selSum = curSum;
+                    }
+                }
+                int base = threads / selSum;
+
+                int threadSt = 0;
+                for (int l = st; l <= end; l++) {
+                    int idx = v[l].first;
+                    Data *weight = weights[idx * 2];
+                    uint8_t *weightData = (uint8_t *) weight->cpuData;
+                    float *outputData = middles[l];
+                    float *biasData = nullptr;
+                    int curK = weight->dims[0];
+                    int curThread = (curK / k) * base;
+                    MultiplyInt4GroupMultiThreadLaunch(uinput.data(), weightData, (int32_t *) outputData, n, m, curK,
+                                                       weight->weightSum.data(), weight->mins.data(), weight->scales.data(), biasData,
+                                                       inputSums, iscales, izeros,
+                                                       inputConfigs, threadSt, curThread, group, groupCnt, ops, pool);
+                    threadSt += curThread;
+                }
+
+                for (int j = 0; j < ops.size(); j++) {
+                    pool->Wait(j);
+                    delete ops[j];
+                }
+
+                // swiglu
+                threadSt = 0;
+                for (int l = st; l <= end; l++) {
+                    int idx = v[l].first;
+                    int spatial = weights[idx * 2]->dims[0], mid = spatial / 2;
+                    float *outputData = middles[l];
+                    int curK = weights[idx * 2]->dims[0];
+                    int curThread = (curK / k) * base;
+                    int per = mid / curThread;
+                    int cur = 0;
+                    for (int i = 0; i < curThread; i++) {
+                        int end = (i == curThread - 1 ? mid : cur + per + (cur + per * (curThread - i) < mid));
+                        ops[threadSt + i] = (new fastllm::MultiThreadSwigluOp(outputData + cur, mid, end - cur, outputData + cur,
+                                                                              n, spatial, spatial));
+                        cur = end;
+                    }
+                    for (int i = 0; i < curThread; i++) {
+                        pool->PushOp(threadSt + i, ops[threadSt + i]);
+                    }
+                    threadSt += curThread;
+                }
+                for (int j = 0; j < ops.size(); j++) {
+                    pool->Wait(j);
+                    delete ops[j];
+                }
+
+                for (int l = st; l <= end; l++) {
+                    int idx = v[l].first;
+                    int mid = weights[idx * 2]->dims[0] / 2;
+                    Data *weightDown = weights[idx * 2 + 1];
+                    int groupDown = weightDown->group, groupCntDown = weightDown->groupCnt;
+                    auto &inputConfigs = inputConfigsDown[l];
+                    auto &inputSums = inputSumsDown[l];
+                    auto &iscales = iscalesDown[l];
+                    auto &izeros = izerosDown[l];
+                    auto &uinputDown = uinputsDown[l];
+                    inputConfigs.resize(n * groupDown);
+                    uinputDown.resize(n * mid);
+                    inputSums.resize(n * groupDown);
+                    iscales.resize(n * groupDown);
+                    izeros.resize(n * groupDown);
+
+                    ops[l - st] = new MultiThreadOnlineQuantizationOp(
+                            middles[l], uinputDown.data(), inputConfigs.data(),
+                            n, mid, groupDown, groupCntDown,
+                            inputSums.data(), iscales.data(), izeros.data());
+                    pool->PushOp(l - st, ops[l - st]);
+                }
+
+                for (int l = st; l <= end; l++) {
+                    pool->Wait(l - st);
+                    delete ops[l - st];
+                }
+
+                threadSt = 0;
+                for (int l = st; l <= end; l++) {
+                    int idx = v[l].first;
+                    int mid = weights[idx * 2]->dims[0] / 2;
+                    int curK = weights[idx * 2]->dims[0];
+                    Data *weightDown = weights[idx * 2 + 1];
+                    int groupDown = weightDown->group, groupCntDown = weightDown->groupCnt;
+                    auto &inputConfigs = inputConfigsDown[l];
+                    auto &inputSums = inputSumsDown[l];
+                    auto &iscales = iscalesDown[l];
+                    auto &izeros = izerosDown[l];
+                    auto &uinputDown = uinputsDown[l];
+                    int curThread = (curK / k) * base;
+                    MultiplyInt4GroupMultiThreadLaunch(uinputDown.data(), (uint8_t*)weightDown->cpuData, (int32_t *) results[l], 1, mid, m,
+                                                       weightDown->weightSum.data(), weightDown->mins.data(), weightDown->scales.data(), nullptr,
+                                                       inputSums, iscales, izeros,
+                                                       inputConfigs, threadSt, curThread, groupDown, groupCntDown, ops, pool);
+                    threadSt += curThread;
+                }
+
+                for (int j = 0; j < ops.size(); j++) {
+                    pool->Wait(j);
+                    delete ops[j];
+                }
+
+                st = end;
+            }
+
+            for (int j = 0; j < v.size(); j++) {
+                float value = v[j].second;
+                float *fLastOutput = (float*)output.cpuData;
+                float *curOutput = (float*)results[j];
+                for (int k = 0; k < m; k++) {
+                    fLastOutput[k] += curOutput[k] * value;
+                }
+            }
+
+            // Release the per-expert scratch buffers allocated above.
+            for (int j = 0; j < v.size(); j++) {
+                delete[] middles[j];
+                delete[] results[j];
+            }
+        } else {
+            // normal
+            Data gate, attenPart, moePart, w1, w2, w3;
+            TopK(logits, gate, topk);
+            gate.ToDevice(DataDevice::CPU);
+            float *gateData = (float*)gate.cpuData;
+
+            if (input.dims[0] == 1) {
+                output.Allocate(0.0f);
+                for (int j = 0; j < topk; j++) {
+                    int idx = (int)(gateData[j * 2] + 1e-1);
+                    float value = gateData[j * 2 + 1] * routeScale;
+
+                    Linear(input, *weights[(idx + 1) * 2], Data(), w3);
+                    Swiglu(w3, w1);
+                    Linear(w1, *weights[(idx + 1) * 2 + 1], Data(), w2);
+                    AddTo(output, w2, value);
+                }
+
+                Linear(input, *weights[0], Data(), w3);
+                Swiglu(w3, w1);
+                Linear(w1, *weights[1], Data(), w2);
+                AddTo(output, w2, sharedScale);
+            } else {
+                Data moeFinal = Data();
+                moeFinal.Resize({0, input.dims[1]});
+                moeFinal.Expansion(input.dims);
+                for (int b = 0; b < input.dims[0]; b++) {
+                    Data *currentData = &input;
+                    Split(input, 0, b, b + 1, attenPart);
+                    currentData = &attenPart;
+                    moePart.Resize(currentData->dims);
+                    moePart.Allocate(0.0f);
+
+                    for (int j = 0; j < topk; j++) {
+                        int idx = (int)(gateData[(b * topk + j) * 2] + 1e-1);
+                        float value = gateData[(b * topk + j) * 2 + 1] * routeScale;
+
+                        Linear(*currentData, *weights[(idx + 1) * 2], Data(), w3);
+                        Swiglu(w3, w1);
+                        Linear(w1, *weights[(idx + 1) * 2 + 1], Data(), w2);
+                        AddTo(moePart, w2, value);
+                    }
+
+                    Linear(*currentData, *weights[0], Data(), w3);
+                    Swiglu(w3, w1);
+                    Linear(w1, *weights[1], Data(), w2);
+                    AddTo(moePart, w2, sharedScale);
+
+                    CatDirect(moeFinal, moePart, 0);
+                }
+                memcpy(output.cpuData, moeFinal.cpuData, output.GetBytes());
+            }
+        }
+    }
+
     void CpuCopyKVCacheOp::Reshape(const std::string &opType, const fastllm::DataDict &datas,
                                    const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
         return;
@@ -1368,126 +1823,6 @@
         }
     };
 
-    struct MultiThreadLinearInt4GroupOp : MultiThreadBaseOp {
-        uint8_t *a, *b;
-        int32_t *c;
-        int n, m, k, kstride;
-        int *weightSums;
-        float *weightMins;
-        float *scales;
-        float *bias;
-        float *iscales, *izeros;
-        float *inputSums;
-        int group, groupCnt;
-
-        MultiThreadLinearInt4GroupOp(
-                uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int kstride,
-                int *weightSums, float *weightMins, float *scales, float *bias,
-                float *iscales, float *izeros, float *inputSums, int group, int groupCnt
-        ) :
-                a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride),
-                weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias),
-                iscales(iscales), izeros(izeros), inputSums(inputSums), group(group), groupCnt(groupCnt) {}
-
-        void Run() {
-            std::vector <float> values;
-            values.resize(group);
-
-            int block = 0;
-            for (; block < n; block++) {
-                uint8_t *weightWalk = b;
-                uint8_t *inputStart = a + block * m;
-
-                for (int i = 0; i < k; i++) {
-                    std::fill(values.begin(), values.end(), 0.0f);
-                    uint8_t *inputWalk = inputStart;
-                    float sum = 0.0;
-
-                    for (int g = 0; g < group; g++) {
-                        int st = g * groupCnt, end = std::min(m, (g + 1) * groupCnt);
-                        float &value = values[g];
-                        int j = st;
-#ifdef __ARM_FEATURE_DOTPROD
-                        uint8x8_t maskHigh = vdup_n_u8(0xF0);
-                        uint8x8_t maskLow = vdup_n_u8(0xF);
-                        uint32x2_t sum0 = {0, 0};
-
-                        for (; j + 15 < end; j += 16) {
-                            uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
-                            uint8x8x2_t in = vld2_u8(inputWalk + j);
-                            uint8x8_t va = vand_u8(ori, maskLow);
-                            uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
-                            sum0 = vdot_u32(sum0, va, in.val[1]);
-                            sum0 = vdot_u32(sum0, vb, in.val[0]);
-                        }
-                        value += sum0[0] + sum0[1];
-#elif defined(__aarch64__)
-                        uint8x8_t maskHigh = vdup_n_u8(0xF0);
-                        uint8x8_t maskLow = vdup_n_u8(0xF);
-                        uint32x4_t sum0 = {0, 0, 0, 0};
-
-                        for (; j + 15 < end; j += 16) {
-                            uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
-                            uint8x8x2_t in = vld2_u8(inputWalk + j);
-                            uint8x8_t va = vand_u8(ori, maskLow);
-                            uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
-                            sum0 = vpadalq_u16(sum0, vmull_u8(va, in.val[1]));
-                            sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
-                        }
-                        value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
-#elif defined(__AVX2__)
-                        value += DotU4U8(weightWalk + (i * m + st) / 2, inputWalk + st, end - st);
-                        j += (end - st);
-#endif
-                        for (; j + 1 < end; j += 2) {
-                            int id = (i * m + j) / 2;
-                            value += (weightWalk[id] >> 4) * inputWalk[j];
-                            value += (weightWalk[id] & 0xF) * inputWalk[j + 1];
-                        }
-                    }
-
-                    int g = 0;
-#ifdef __aarch64__
-                    float32x4_t vSum = vdupq_n_f32(0.0f);
-                    float32x4_t vGroupCnt = vdupq_n_f32(groupCnt);
-                    for (; g + 3 < group; g += 4) {
-                        int iid = block * group + g;
-                        int gid = i * group + g;
-                        float32x4_t vValue = vld1q_f32(values.data() + g);
-                        float32x4_t vWeightSum = vcvtq_f32_s32(vld1q_s32(weightSums + gid));
-                        float32x4_t vWeightMin = vld1q_f32(weightMins + gid);
-                        float32x4_t vScale = vld1q_f32(scales + gid);
-                        float32x4_t vIzero = vld1q_f32(izeros + iid);
-                        float32x4_t vIscale = vld1q_f32(iscales + iid);
-                        float32x4_t vInputSum = vld1q_f32(inputSums + iid);
-                        float32x4_t vMiddle = vsubq_f32(vInputSum, vmulq_f32(vIzero, vGroupCnt));
-                        vValue = vsubq_f32(vValue, vmulq_f32(vWeightSum, vIzero));
-                        vSum = vaddq_f32(vSum, vmulq_f32(vScale, vmulq_f32(vIscale, vValue)));
-                        vSum = vaddq_f32(vSum, vmulq_f32(vWeightMin, vmulq_f32(vMiddle, vIscale)));
-                    }
-                    sum += vSum[0] + vSum[1] + vSum[2] + vSum[3];
-#endif
-                    for (; g < group; g++) {
-                        int iid = block * group + g;
-                        int gid = i * group + g;
-                        int value = values[g];
-                        value -= weightSums[gid] * izeros[iid];
-                        sum += scales[gid] * iscales[iid] * value +
-                               weightMins[gid] * (inputSums[iid] - izeros[iid] * groupCnt) * iscales[iid];
-                    }
-
-                    if (group * groupCnt > m) {
-                        int iid = block * group + group - 1;
-                        int gid = i * group + group - 1;
-                        sum += weightMins[gid] * izeros[iid] * (group * groupCnt - m) * iscales[iid];
-                    }
-
-                    ((float*)c)[block * kstride + i] = sum + (bias == nullptr ? 0.0 : bias[i]);
-                }
-            }
-        }
-    };
-
     struct MultiThreadLinearInt4NoZeroOp : MultiThreadBaseOp {
         uint8_t *a, *b;
         int32_t *c;
@@ -1731,6 +2066,29 @@
         }
     }
 
+    //a = [n, m], b = [k, m], c = aT(b') = [n, k]
+    void MultiplyInt4GroupMultiThreadLaunch(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k,
+                                            int *weightSums, float *weightMins, float *scales, float *bias,
+                                            std::vector <float> &inputSums, std::vector <float> &iscales, std::vector <float> &izeros,
+                                            std::vector <LowBitConfig> &configs, int startTid, int threadNum, int group, int groupCnt,
+                                            std::vector <fastllm::MultiThreadBaseOp*> &ops,
+                                            AliveThreadPool *pool) {
+        int per = k / threadNum;
+        int cur = 0;
+
+        for (int i = 0; i < threadNum; i++) {
+            int end = (i == threadNum - 1 ? k : cur + per + (cur + per * (threadNum - i) < k));
+            ops[startTid + i] = new MultiThreadLinearInt4GroupOp(a, b + cur * m / 2, c + cur, n, m, end - cur, k,
+                                                                 weightSums + cur * group, weightMins + cur * group, scales + cur * group,
+                                                                 (bias == nullptr ? (float *) nullptr : bias + cur), iscales.data(), izeros.data(),
+                                                                 inputSums.data(), group, groupCnt);
+            cur = end;
+        }
+        for (int i = 0; i < threadNum; i++) {
+            pool->PushOp(startTid + i, ops[startTid + i]);
+        }
+    }
+
     //a = [n, m], b = [k, m], c = aT(b') = [n, k]
     void MultiplyInt4GroupMultiThread(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k,
                                       int *weightSums, float *weightMins, float *scales, float *bias,
                                       std::vector <LowBitConfig> &configs, int threadNum, int group, int groupCnt) {
@@ -1780,6 +2138,100 @@
         }
     }
 
+    void GetArrayMinMax(float *a, int len, float &minValue, float &maxValue) {
+        int j = 0;
+        minValue = 1e100;
+        maxValue = -1e100;
+#ifdef __aarch64__
+        float32x4_t mins = vdupq_n_f32(1e100);
+        float32x4_t maxs = vdupq_n_f32(-1e100);
+        for (; j + 3 < len; j += 4) {
+            float32x4_t v = vld1q_f32(a + j);
+            mins = vminq_f32(mins, v);
+            maxs = vmaxq_f32(maxs, v);
+        }
+        for (int l = 0; l < 4; l++) {
+            minValue = std::min(minValue, mins[l]);
+            maxValue = std::max(maxValue, maxs[l]);
+        }
+#endif
+        for (; j < len; j++) {
+            minValue = std::min(minValue, a[j]);
+            maxValue = std::max(maxValue, a[j]);
+        }
+    }
+
+    void QuantizationAll(float *fValue, uint8_t *uValue, int len, LowBitConfig *config) {
+        float scale = config->scale;
+        float zeroPoint = config->zeroPoint;
+        int j = 0;
+#ifdef __aarch64__
+        float32x4_t scales = vdupq_n_f32(scale);
+        float32x4_t zeros = vdupq_n_f32(zeroPoint + 0.5);
+        int32x4_t maxds = vcombine_s32(vcreate_s32(0x000000ff000000ff), vcreate_s32(0x000000ff000000ff));
+        int32x4_t minds = vcombine_s32(vcreate_s32(0x0000000000000000), vcreate_s32(0x0000000000000000));
+        for (; j + 7 < len; j += 8) {
+            float32x4_t fin1 = vld1q_f32(fValue + j);
+            float32x4_t fin2 = vld1q_f32(fValue + j + 4);
+            fin1 = vaddq_f32(vdivq_f32(fin1, scales), zeros);
+            fin2 = vaddq_f32(vdivq_f32(fin2, scales), zeros);
+            int32x4_t out1 = vcvtq_s32_f32(fin1);
+            int32x4_t out2 = vcvtq_s32_f32(fin2);
+            out1 = vmaxq_s32(out1, minds);
+            out1 = vminq_s32(out1, maxds);
+            out2 = vmaxq_s32(out2, minds);
+            out2 = vminq_s32(out2, maxds);
+            uint16x8_t out3 = vpaddq_u16(vreinterpretq_u16_s32(out1), vreinterpretq_u16_s32(out2));
+            uint8x8_t out = vmovn_u16(out3);
+            vst1_u8(uValue + j, out);
+        }
+#endif
+        for (; j < len; j++) {
+            uValue[j] = (uint8_t) (std::min(255., (double) std::max(fValue[j] / scale + zeroPoint + 0.5, 0.0)));
+        }
+    }
+
+    void MultiThreadOnlineQuantizationOp::Run() {
+        for (int i = 0; i < n; i++) {
+            float *cur = input + i * m;
+            uint8_t *u = output + i * m;
+            for (int g = 0; g < group; g++) {
+                int st = g * groupCnt;
+                int end = std::min(m, (g + 1) * groupCnt);
+                float minValue = 1e9, maxValue = -1e9;
+                GetArrayMinMax(input + i * m + st, end - st, minValue, maxValue);
+                configs[i * group + g] = (LowBitConfig(minValue, maxValue, 8, 0));
+                QuantizationAll(cur + st, u + st, end - st, &configs[i * group + g]);
+            }
+        }
+#ifdef __AVX__
+        uint8_t *temp = new uint8_t[32];
+        for (int i = 0; i < n; i++) {
+            for (int j = 0; j + 31 < m; j += 32) {
+                memcpy(temp, output + i * m + j, 32);
+                for (int k = 0; k < 16; k++) {
+                    output[i * m + j + k] = temp[k * 2 + 1];
+                    output[i * m + j + k + 16] = temp[k * 2];
+                }
+            }
+        }
+        delete[] temp;
+#endif
+        if (inputSums != nullptr) {
+            for (int i = 0; i < n; i++) {
+                for (int g = 0; g < group; g++) {
+                    iscales[i * group + g] = configs[i * group + g].scale;
+                    izeros[i * group + g] = configs[i * group + g].zeroPoint;
+                    int sum = 0;
+                    for (int j = g * groupCnt; j < (g + 1) * groupCnt && j < m; j++) {
+                        sum += output[i * m + j];
+                    }
+                    inputSums[i * group + g] = sum;
+                }
+            }
+        }
+    }
+
     bool CpuLinearOp::CanRun(const std::string &opType, const fastllm::DataDict &datas,
                             const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
        if (intParams.find("exType") != intParams.end()) {
diff --git a/src/devices/tfacc/tfaccdevice.cpp b/src/devices/tfacc/tfaccdevice.cpp
index 16268bdd..d5501657 100644
--- a/src/devices/tfacc/tfaccdevice.cpp
+++ b/src/devices/tfacc/tfaccdevice.cpp
@@ -6,6 +6,7 @@
 #include
 #include "devices/tfacc/tfaccdevice.h"
+#include "devices/cpu/cpudevice.h"
 #include "devices/tfacc/fastllm-tfacc.h"
 #include "devices/cpu/alivethreadpool.h"
 
@@ -23,84 +24,6 @@
 #include "utils.h"
 
 namespace fastllm {
-    void GetArrayMinMax(float *a, int len, float &minValue, float &maxValue) {
-        int j = 0;
-        minValue = 1e100;
-        maxValue = -1e100;
-#ifdef __aarch64__
-        float32x4_t mins = vdupq_n_f32(1e100);
-        float32x4_t maxs = vdupq_n_f32(-1e100);
-        for (; j + 3 < len; j += 4) {
-            float32x4_t v = vld1q_f32(a + j);
-            mins = vminq_f32(mins, v);
-            maxs = vmaxq_f32(maxs, v);
-        }
-        for (int l = 0; l < 4; l++) {
-            minValue = std::min(minValue, mins[l]);
-            maxValue = std::max(maxValue, maxs[l]);
-        }
-#endif
-        for (; j < len; j++) {
-            minValue = std::min(minValue, a[j]);
-            maxValue = std::max(maxValue, a[j]);
-        }
-    }
-
-    void QuantizationAll(float *fValue, uint8_t *uValue, int len, LowBitConfig *config) {
-        float scale = config->scale;
-        float zeroPoint = config->zeroPoint;
-        int j = 0;
-#ifdef __aarch64__
-        float32x4_t scales = vdupq_n_f32(scale);
-        float32x4_t zeros = vdupq_n_f32(zeroPoint + 0.5);
-        int32x4_t maxds = vcombine_s32(vcreate_s32(0x000000ff000000ff), vcreate_s32(0x000000ff000000ff));
-        int32x4_t minds = vcombine_s32(vcreate_s32(0x0000000000000000), vcreate_s32(0x0000000000000000));
-        for (; j + 7 < len; j += 8) {
-            float32x4_t fin1 = vld1q_f32(fValue + j);
-            float32x4_t fin2 = vld1q_f32(fValue + j + 4);
-            fin1 = vaddq_f32(vdivq_f32(fin1, scales), zeros);
-            fin2 = vaddq_f32(vdivq_f32(fin2, scales), zeros);
-            int32x4_t out1 = vcvtq_s32_f32(fin1);
-            int32x4_t out2 = vcvtq_s32_f32(fin2);
-            out1 = vmaxq_s32(out1, minds);
-            out1 = vminq_s32(out1, maxds);
-            out2 = vmaxq_s32(out2, minds);
-            out2 = vminq_s32(out2, maxds);
-            uint16x8_t out3 = vpaddq_u16(vreinterpretq_u16_s32(out1), vreinterpretq_u16_s32(out2));
-            uint8x8_t out = vmovn_u16(out3);
-            vst1_u8(uValue + j, out);
-        }
-#endif
-        for (; j < len; j++) {
-            uValue[j] = (uint8_t) (std::min(255., (double) std::max(fValue[j] / scale + zeroPoint + 0.5, 0.0)));
-        }
-    }
-
-    struct MultiThreadOnlineQuantizationOp : MultiThreadBaseOp {
-        float *input;
-        uint8_t *output;
-        LowBitConfig *configs;
-        int n, m, group, groupCnt;
-
-        MultiThreadOnlineQuantizationOp (float *input, uint8_t *output, LowBitConfig *configs, int n, int m, int group, int groupCnt) :
-                input(input), output(output), configs(configs), n(n), m(m), group(group), groupCnt(groupCnt) {} ;
-
-        void Run() {
-            for (int i = 0; i < n; i++) {
-                float *cur = input + i * m;
-                uint8_t *u = output + i * m;
-                for (int g = 0; g < group; g++) {
-                    int st = g * groupCnt;
-                    int end = std::min(m, (g + 1) * groupCnt);
-                    float minValue = 1e9, maxValue = -1e9;
-                    GetArrayMinMax(input + i * m + st, end - st, minValue, maxValue);
-                    configs[i * group + g] = (LowBitConfig(minValue, maxValue, 8, 0));
-                    QuantizationAll(cur + st, u + st, end - st, &configs[i * group + g]);
-                }
-            }
-        }
-    };
-
     static TfaccClient tfaccClient;
 
     TfaccDevice::TfaccDevice() {
@@ -235,7 +158,7 @@
                int end = (i == threadNum - 1 ? n : cur + per + (cur + per * (threadNum - i) < n));
                ops.push_back(new MultiThreadOnlineQuantizationOp(
                        inputData + cur * m, uinput.data() + cur * m, inputConfigs.data() + cur * group,
-                       end - cur, m, group, groupCnt));
+                       end - cur, m, group, groupCnt, nullptr, nullptr, nullptr));
                cur = end;
            }
            for (int i = 0; i < threadNum; i++) {
@@ -246,7 +169,7 @@
                delete ops[i];
            }
        } else {
-           MultiThreadOnlineQuantizationOp(inputData, uinput.data(), inputConfigs.data(), n, m, group, groupCnt).Run();
+           MultiThreadOnlineQuantizationOp(inputData, uinput.data(), inputConfigs.data(), n, m, group, groupCnt, nullptr, nullptr, nullptr).Run();
        }
 
        if (weight.dataType == DataType::INT4) {
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 7ce8cd01..3f964c36 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -2278,6 +2278,19 @@
         });
     }
 
+    bool CanRunMergeMOE() {
+        return curExecutor->CanRunOnFirstDevice("MergeMOE", {}, {}, {});
+    }
+
+    void MergeMOE(const Data &input, const Data &logits, std::vector <Data*> weights, std::vector <Data*> biass,
+                  float routeScale, float sharedScale, int topk, Data &output) {
+        curExecutor->Run("MergeMOE", {
+                {"input", (Data*)&input}, {"logits", (Data*)&logits},
+                {"weights", (Data*)weights.data()}, {"biass", (Data*)biass.data()},
+                {"output", (Data*)&output}
+        }, {{"sharedScale", sharedScale}, {"routeScale", routeScale}}, {{"topk", topk}});
+    }
+
     void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
                    int group, float scale, int attentionType) {
         int maskType = 0; // 0: causal mask
diff --git a/src/models/deepseekv2.cpp b/src/models/deepseekv2.cpp
index a3e80f24..4ec94b51 100644
--- a/src/models/deepseekv2.cpp
+++ b/src/models/deepseekv2.cpp
@@ -272,6 +272,24 @@
         Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
         int seqlen = hiddenStates.dims[1];
+
+        // Cache per-layer weight pointers once, in the layout MergeMOE expects:
+        // [shared gateup, shared down, expert0 gateup, expert0 down, ...].
+        if (weights.size() == 0) {
+            weights.resize(block_cnt);
+            biass.resize(block_cnt);
+            for (int i = 0; i < block_cnt; i++) {
+                weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"]);
+                weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.down_proj.weight"]);
+                biass[i].push_back(nullptr);
+                biass[i].push_back(nullptr);
+                for (int j = 0; j < this->num_experts; j++) {
+                    weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".gateup_proj.weight"]);
+                    weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".down_proj.weight"]);
+                    biass[i].push_back(nullptr);
+                    biass[i].push_back(nullptr);
+                }
+            }
+        }
+
         for (int i = 0; i < block_cnt; i++) {
             ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
             RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
@@ -419,69 +437,81 @@
         attenInput.Reshape({batch * len, attenInput.dims[2]});
         Linear(attenInput, weight[gateWeightName], Data(), routerLogits);
         Softmax(routerLogits, routerLogits, -1);
-        TopK(routerLogits, gate, this->num_experts_per_tok);
-        moeFinal = Data();
-        moeFinal.Resize({0, attenInput.dims[1]});
-        moeFinal.Expansion(attenInput.dims);
-
-        gate.ToDevice(DataDevice::CPU);
-        float *gateData = (float*)gate.cpuData;
-        /// TODO: this is a greedy top-k; group-limited top-k still needs to be implemented
-        for (int b = 0; b < batch * len; b++) {
-            Data *currentData = &attenInput;
-            if (batch * len != 1) {
-                Split(attenInput, 0, b, b + 1, attenPart);
-                currentData = &attenPart;
-            }
-            moePart.Resize(currentData->dims);
-            moePart.Allocate(0.0f);
 
-            for (int j = 0; j < this->num_experts_per_tok; j++) {
-                int idx = (int)(gateData[(b * this->num_experts_per_tok + j) * 2] + 1e-1);
-                float value = gateData[(b * this->num_experts_per_tok + j) * 2 + 1];
-                value *= routed_scaling_factor;
+        if (this->mergeSwiglu && CanRunMergeMOE()) {
+            MergeMOE (
+                attenInput, routerLogits,
+                weights[i], biass[i],
+                this->routed_scaling_factor, 1.0f,
+                this->num_experts_per_tok,
+                moeFinal
+            );
+        } else {
+            TopK(routerLogits, gate, this->num_experts_per_tok);
+            gate.ToDevice(DataDevice::CPU);
+            float *gateData = (float*)gate.cpuData;
+            /// TODO: this is a greedy top-k; group-limited top-k still needs to be implemented
+
+            moeFinal = Data();
+            moeFinal.Resize({0, attenInput.dims[1]});
+            moeFinal.Expansion(attenInput.dims);
+
+            for (int b = 0; b < batch * len; b++) {
+                Data *currentData = &attenInput;
+                if (batch * len != 1) {
+                    Split(attenInput, 0, b, b + 1, attenPart);
+                    currentData = &attenPart;
+                }
+                moePart.Resize(currentData->dims);
+                moePart.Allocate(0.0f);
+
+                for (int j = 0; j < this->num_experts_per_tok; j++) {
+                    int idx = (int)(gateData[(b * this->num_experts_per_tok + j) * 2] + 1e-1);
+                    float value = gateData[(b * this->num_experts_per_tok + j) * 2 + 1];
+                    value *= routed_scaling_factor;
+                    if (this->mergeSwiglu) {
+                        if (CanRunLinearEx(LinearExType::ExSwiglu)) {
+                            LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
+                        } else {
+                            Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w3);
+                            Swiglu(w3, w1);
+                        }
+                    } else {
+                        if (CanRunLinearEx(LinearExType::ExSilu)) {
+                            LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
+                        } else {
+                            Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1);
+                            Silu(w1, w1);
+                        }
+                        Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".up_proj.weight"], Data(), w3);
+                        MulTo(w1, w3);
+                    }
+                    Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".down_proj.weight"], Data(), w2);
+                    AddTo(moePart, w2, value);
+                }
+
+                if (this->mergeSwiglu) {
                     if (CanRunLinearEx(LinearExType::ExSwiglu)) {
-                        LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
+                        LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
                    } else {
-                        Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w3);
+                        Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w3);
                        Swiglu(w3, w1);
                    }
                } else {
                    if (CanRunLinearEx(LinearExType::ExSilu)) {
-                        LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
+                        LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
                    } else {
-                        Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1);
+                        Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1);
                        Silu(w1, w1);
                    }
-                    Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".up_proj.weight"], Data(), w3);
+                        Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.up_proj.weight"], Data(), w3);
                    MulTo(w1, w3);
                }
-                Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".down_proj.weight"], Data(), w2);
-                AddTo(moePart, w2, value);
-            }
+                    Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.down_proj.weight"], Data(), w2);
+                    AddTo(moePart, w2);
 
-            if (this->mergeSwiglu) {
-                if (CanRunLinearEx(LinearExType::ExSwiglu)) {
-                    LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
-                } else {
-                    Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w3);
-                    Swiglu(w3, w1);
-                }
-            } else {
-                if (CanRunLinearEx(LinearExType::ExSilu)) {
-                    LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
-                } else {
-                    Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1);
-                    Silu(w1, w1);
-                }
-                Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.up_proj.weight"], Data(), w3);
-                MulTo(w1, w3);
+                CatDirect(moeFinal, moePart, 0);
            }
-            Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.down_proj.weight"], Data(), w2);
-            AddTo(moePart, w2);
-
-            CatDirect(moeFinal, moePart, 0);
        }
 
        moeFinal.Reshape(hiddenStates.dims);
@@ -530,11 +560,12 @@
             }
         }
     }
+/*
     if (sinDataPtr != &sinData)
         delete sinDataPtr;
     if (cosDataPtr != &cosData)
         delete cosDataPtr;
-
+*/
     return lastRet;
 }
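
For reference, a minimal sketch (not part of the patch) of how the fused path is driven, mirroring the deepseekv2.cpp call site above. `attenInput`, `routerLogits`, `moeFinal`, the layer index `i`, and the cached per-layer `weights` / `biass` vectors are assumed to be set up exactly as in the patch.

    // Hedged usage sketch. weights[i] is laid out as [shared gateup, shared down,
    // expert0 gateup, expert0 down, ...], which matches the +1 offset that
    // CpuMergeMOE applies to routed expert indices (slot 0 is the shared expert).
    if (this->mergeSwiglu && CanRunMergeMOE()) {
        // Fused path: quantizes attenInput once, runs gateup matmul + swiglu + down
        // matmul for the shared expert and the top-k routed experts in one op, and
        // accumulates routeScale- / sharedScale-weighted results into moeFinal.
        MergeMOE(attenInput, routerLogits, weights[i], biass[i],
                 this->routed_scaling_factor /* routeScale */, 1.0f /* sharedScale */,
                 this->num_experts_per_tok /* topk */, moeFinal);
    } else {
        // Fallback: explicit TopK + per-expert Linear / Swiglu / Linear loop, as before.
    }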