
Commit

Speed up the MoE structure of DeepSeek-V2
黄宇扬 committed Jun 4, 2024
1 parent f54d990 commit bdbf525
Showing 7 changed files with 694 additions and 249 deletions.
20 changes: 20 additions & 0 deletions include/devices/cpu/cpudevice.h
@@ -9,6 +9,21 @@
#include "alivethreadpool.h"

namespace fastllm {
struct MultiThreadOnlineQuantizationOp : MultiThreadBaseOp {
float *input;
uint8_t *output;
LowBitConfig *configs;
int n, m, group, groupCnt;
float *inputSums, *iscales, *izeros;

MultiThreadOnlineQuantizationOp (float *input, uint8_t *output, LowBitConfig *configs, int n, int m, int group, int groupCnt,
float *inputSums, float *iscales, float *izeros) :
input(input), output(output), configs(configs), n(n), m(m), group(group), groupCnt(groupCnt),
inputSums(inputSums), iscales(iscales), izeros(izeros) {}

void Run();
};
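The body of Run() lands in src/devices/cpu/cpudevice.cpp, whose diff is not rendered below. Judging from the version this commit removes from tfaccdevice.cpp, it presumably performs per-row, per-group asymmetric int8 quantization, now also filling the new inputSums / iscales / izeros buffers so the int8 GEMM can fold the zero-points back out. A minimal sketch under that assumption (the bookkeeping for the three new outputs is illustrative, not the committed code):

void MultiThreadOnlineQuantizationOp::Run() {
    for (int i = 0; i < n; i++) {
        float *cur = input + i * m;      // row i of the fp32 activations
        uint8_t *u = output + i * m;     // row i of the quantized output
        for (int g = 0; g < group; g++) {
            int st = g * groupCnt, end = std::min(m, (g + 1) * groupCnt);
            float minValue = 1e9, maxValue = -1e9;
            GetArrayMinMax(cur + st, end - st, minValue, maxValue);
            configs[i * group + g] = LowBitConfig(minValue, maxValue, 8, 0);
            QuantizationAll(cur + st, u + st, end - st, &configs[i * group + g]);
            if (iscales != nullptr && izeros != nullptr) {
                // New outputs: per-group scale and zero-point, exported flat
                iscales[i * group + g] = configs[i * group + g].scale;
                izeros[i * group + g] = configs[i * group + g].zeroPoint;
            }
            if (inputSums != nullptr) {
                // Per-group sum of the quantized values, presumably consumed
                // when dequantizing the int8 GEMM result
                float sum = 0.0f;
                for (int k = st; k < end; k++) sum += u[k];
                inputSums[i * group + g] = sum;
            }
        }
    }
}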

class CpuDevice : BaseDevice {
public:
CpuDevice ();
@@ -36,6 +51,11 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuMergeMOE : BaseOperator {
protected:
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuEmbedding : BaseOperator {
void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
4 changes: 4 additions & 0 deletions include/fastllm.h
@@ -488,6 +488,10 @@ namespace fastllm {

void CopyKVCache(Data &oldCache, Data &newCache, int oldBsStart, int newBsStart, int bs, int offset);

bool CanRunMergeMOE();
void MergeMOE(const Data &input, const Data &logits, std::vector <Data*> weights, std::vector <Data*> biass,
float routeScale, float sharedScale, int topk, Data &output);
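Judging from the call site in deepseekv2.cpp further down, logits carries the already-softmaxed router outputs and weights/biass the projections of the shared plus routed experts; per token x the op then presumably computes

    output(x) = sharedScale * SharedExpert(x) + routeScale * Σ_{e ∈ TopK(logits(x), topk)} logits_e(x) * Expert_e(x)

with sharedScale = 1.0f and routeScale = routed_scaling_factor in the DeepSeek-V2 code (a reading of the fallback path below, not of the unrendered kernel itself).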

void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
int group, float scale, int attentionType);

2 changes: 2 additions & 0 deletions include/models/deepseekv2.h
@@ -96,6 +96,8 @@ namespace fastllm {
std::string rope_scaling_type;

bool mergeSwiglu = false;
std::vector <std::vector <Data*> > weights;
std::vector <std::vector <Data*> > biass;
};
}

692 changes: 572 additions & 120 deletions src/devices/cpu/cpudevice.cpp

Large diffs are not rendered by default.
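(Judging from the rest of the commit, this unrendered diff receives the GetArrayMinMax / QuantizationAll helpers and the MultiThreadOnlineQuantizationOp body moved out of tfaccdevice.cpp below, plus the implementation of the new CpuMergeMOE operator declared in cpudevice.h above.)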

83 changes: 3 additions & 80 deletions src/devices/tfacc/tfaccdevice.cpp
@@ -6,6 +6,7 @@
#include <fcntl.h>

#include "devices/tfacc/tfaccdevice.h"
#include "devices/cpu/cpudevice.h"
#include "devices/tfacc/fastllm-tfacc.h"
#include "devices/cpu/alivethreadpool.h"

@@ -23,84 +24,6 @@
#include "utils.h"

namespace fastllm {
void GetArrayMinMax(float *a, int len, float &minValue, float &maxValue) {
int j = 0;
minValue = 1e100;
maxValue = -1e100;
#ifdef __aarch64__
float32x4_t mins = vdupq_n_f32(1e100);
float32x4_t maxs = vdupq_n_f32(-1e100);
for (; j + 3 < len; j += 4) {
float32x4_t v = vld1q_f32(a + j);
mins = vminq_f32(mins, v);
maxs = vmaxq_f32(maxs, v);
}
for (int l = 0; l < 4; l++) {
minValue = std::min(minValue, mins[l]);
maxValue = std::max(maxValue, maxs[l]);
}
#endif
for (; j < len; j++) {
minValue = std::min(minValue, a[j]);
maxValue = std::max(maxValue, a[j]);
}
}

void QuantizationAll(float *fValue, uint8_t *uValue, int len, LowBitConfig *config) {
float scale = config->scale;
float zeroPoint = config->zeroPoint;
int j = 0;
#ifdef __aarch64__
float32x4_t scales = vdupq_n_f32(scale);
float32x4_t zeros = vdupq_n_f32(zeroPoint + 0.5);
int32x4_t maxds = vcombine_s32(vcreate_s32(0x000000ff000000ff), vcreate_s32(0x000000ff000000ff));
int32x4_t minds = vcombine_s32(vcreate_s32(0x0000000000000000), vcreate_s32(0x0000000000000000));
for (; j + 7 < len; j += 8) {
float32x4_t fin1 = vld1q_f32(fValue + j);
float32x4_t fin2 = vld1q_f32(fValue + j + 4);
fin1 = vaddq_f32(vdivq_f32(fin1, scales), zeros);
fin2 = vaddq_f32(vdivq_f32(fin2, scales), zeros);
int32x4_t out1 = vcvtq_s32_f32(fin1);
int32x4_t out2 = vcvtq_s32_f32(fin2);
out1 = vmaxq_s32(out1, minds);
out1 = vminq_s32(out1, maxds);
out2 = vmaxq_s32(out2, minds);
out2 = vminq_s32(out2, maxds);
uint16x8_t out3 = vpaddq_u16(vreinterpretq_u16_s32(out1), vreinterpretq_u16_s32(out2));
uint8x8_t out = vmovn_u16(out3);
vst1_u8(uValue + j, out);
}
#endif
for (; j < len; j++) {
uValue[j] = (uint8_t) (std::min(255., (double) std::max(fValue[j] / scale + zeroPoint + 0.5, 0.0)));
}
}

struct MultiThreadOnlineQuantizationOp : MultiThreadBaseOp {
float *input;
uint8_t *output;
LowBitConfig *configs;
int n, m, group, groupCnt;

MultiThreadOnlineQuantizationOp (float *input, uint8_t *output, LowBitConfig *configs, int n, int m, int group, int groupCnt) :
input(input), output(output), configs(configs), n(n), m(m), group(group), groupCnt(groupCnt) {}

void Run() {
for (int i = 0; i < n; i++) {
float *cur = input + i * m;
uint8_t *u = output + i * m;
for (int g = 0; g < group; g++) {
int st = g * groupCnt;
int end = std::min(m, (g + 1) * groupCnt);
float minValue = 1e9, maxValue = -1e9;
GetArrayMinMax(input + i * m + st, end - st, minValue, maxValue);
configs[i * group + g] = (LowBitConfig(minValue, maxValue, 8, 0));
QuantizationAll(cur + st, u + st, end - st, &configs[i * group + g]);
}
}
}
};

static TfaccClient tfaccClient;

TfaccDevice::TfaccDevice() {
@@ -235,7 +158,7 @@ namespace fastllm {
int end = (i == threadNum - 1 ? n : cur + per + (cur + per * (threadNum - i) < n));
ops.push_back(new MultiThreadOnlineQuantizationOp(
inputData + cur * m, uinput.data() + cur * m, inputConfigs.data() + cur * group,
end - cur, m, group, groupCnt));
end - cur, m, group, groupCnt, nullptr, nullptr, nullptr));
cur = end;
}
for (int i = 0; i < threadNum; i++) {
@@ -246,7 +169,7 @@
delete ops[i];
}
} else {
MultiThreadOnlineQuantizationOp(inputData, uinput.data(), inputConfigs.data(), n, m, group, groupCnt).Run();
MultiThreadOnlineQuantizationOp(inputData, uinput.data(), inputConfigs.data(), n, m, group, groupCnt, nullptr, nullptr, nullptr).Run();
}

if (weight.dataType == DataType::INT4) {
13 changes: 13 additions & 0 deletions src/fastllm.cpp
@@ -2278,6 +2278,19 @@ namespace fastllm {
});
}

bool CanRunMergeMOE() {
return curExecutor->CanRunOnFirstDevice("MergeMOE", {}, {}, {});
}

void MergeMOE(const Data &input, const Data &logits, std::vector <Data*> weights, std::vector <Data*> biass,
float routeScale, float sharedScale, int topk, Data &output) {
curExecutor->Run("MergeMOE", {
{"input", (Data*)&input}, {"logits", (Data*)&logits},
{"weights", (Data*)weights.data()}, {"biass", (Data*)biass.data()},
{"output", (Data*)&output}
}, {{"sharedScale", sharedScale}, {"routeScale", routeScale}}, {{"topk", topk}});
}
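Note the (Data*)weights.data() cast: weights is a std::vector<Data*>, so what travels through the DataDict is really a Data** in disguise, and likewise for biass. The receiving CpuMergeMOE::Run (in the unrendered cpudevice.cpp diff) presumably casts back before indexing; a hypothetical sketch of that side:

    Data **weights = (Data**)(datas.find("weights")->second);  // per-expert weight pointers
    Data **biass = (Data**)(datas.find("biass")->second);      // matching biases, possibly nullptr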

void Attention(const Data &q, const Data &k, const Data &v, const Data &mask, Data &output,
int group, float scale, int attentionType) {
int maskType = 0; // 0: causal mask
129 changes: 80 additions & 49 deletions src/models/deepseekv2.cpp
@@ -272,6 +272,24 @@ namespace fastllm {

Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
int seqlen = hiddenStates.dims[1];

if (weights.size() == 0) {
weights.resize(block_cnt);
biass.resize(block_cnt);
for (int i = 0; i < block_cnt; i++) {
weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"]);
weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.down_proj.weight"]);
biass[i].push_back(nullptr);
biass[i].push_back(nullptr);
for (int j = 0; j < this->num_experts; j++) {
weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".gateup_proj.weight"]);
weights[i].push_back(&weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(j) + ".down_proj.weight"]);
biass[i].push_back(nullptr);
biass[i].push_back(nullptr);
}
}
}
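Once this lazily built cache is filled, weights[i] holds 2 + 2 * num_experts pointers for layer i: the shared expert's merged gateup_proj and down_proj at indices 0 and 1, then routed expert j's pair at indices 2 + 2 * j and 3 + 2 * j; every biass slot stays nullptr because these projections carry no bias. Since only merged gateup_proj weights are cached, the MergeMOE fast path below is gated on mergeSwiglu as well as CanRunMergeMOE().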

for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
@@ -419,69 +437,81 @@
attenInput.Reshape({batch * len, attenInput.dims[2]});
Linear(attenInput, weight[gateWeightName], Data(), routerLogits);
Softmax(routerLogits, routerLogits, -1);
if (this->mergeSwiglu && CanRunMergeMOE()) {
    MergeMOE (
        attenInput, routerLogits,
        weights[i], biass[i],
        this->routed_scaling_factor, 1.0f,
        this->num_experts_per_tok,
        moeFinal
    );
} else {
    TopK(routerLogits, gate, this->num_experts_per_tok);
    gate.ToDevice(DataDevice::CPU);
    float *gateData = (float*)gate.cpuData;
    /// TODO: this is a greedy top-k; group-limited top-k still needs to be implemented

    moeFinal = Data();
    moeFinal.Resize({0, attenInput.dims[1]});
    moeFinal.Expansion(attenInput.dims);

    for (int b = 0; b < batch * len; b++) {
        Data *currentData = &attenInput;
        if (batch * len != 1) {
            Split(attenInput, 0, b, b + 1, attenPart);
            currentData = &attenPart;
        }
        moePart.Resize(currentData->dims);
        moePart.Allocate(0.0f);

        for (int j = 0; j < this->num_experts_per_tok; j++) {
            int idx = (int)(gateData[(b * this->num_experts_per_tok + j) * 2] + 1e-1);
            float value = gateData[(b * this->num_experts_per_tok + j) * 2 + 1];
            value *= routed_scaling_factor;
            if (this->mergeSwiglu) {
                if (CanRunLinearEx(LinearExType::ExSwiglu)) {
                    LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
                } else {
                    Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gateup_proj.weight"], Data(), w3);
                    Swiglu(w3, w1);
                }
            } else {
                if (CanRunLinearEx(LinearExType::ExSilu)) {
                    LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
                } else {
                    Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".gate_proj.weight"], Data(), w1);
                    Silu(w1, w1);
                }
                Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".up_proj.weight"], Data(), w3);
                MulTo(w1, w3);
            }
            Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.experts." + std::to_string(idx) + ".down_proj.weight"], Data(), w2);
            AddTo(moePart, w2, value);
        }

        if (this->mergeSwiglu) {
            if (CanRunLinearEx(LinearExType::ExSwiglu)) {
                LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w1, LinearExType::ExSwiglu);
            } else {
                Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gateup_proj.weight"], Data(), w3);
                Swiglu(w3, w1);
            }
        } else {
            if (CanRunLinearEx(LinearExType::ExSilu)) {
                LinearEx(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
            } else {
                Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.gate_proj.weight"], Data(), w1);
                Silu(w1, w1);
            }
            Linear(*currentData, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.up_proj.weight"], Data(), w3);
            MulTo(w1, w3);
        }
        Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.shared_experts.down_proj.weight"], Data(), w2);
        AddTo(moePart, w2);

        CatDirect(moeFinal, moePart, 0);
    }
}
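Both branches leave the per-token MoE outputs concatenated into moeFinal, so the change is in dispatch cost rather than results: the fallback splits the batch one token at a time and issues separate Linear/LinearEx calls per selected expert, while MergeMOE hands the whole batch plus the routing logits to one fused op, which (judging by the inputSums / iscales / izeros outputs added to the online-quantization op above) can quantize the activations once and reuse them across every selected expert.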

moeFinal.Reshape(hiddenStates.dims);
@@ -530,11 +560,12 @@
}
}
}
/*
if (sinDataPtr != &sinData)
    delete sinDataPtr;
if (cosDataPtr != &cosData)
    delete cosDataPtr;
*/
return lastRet;
}

