deepseekv2 int4 acceleration
黄宇扬 committed Jun 4, 2024
1 parent bdbf525 commit 5adbec0
Showing 2 changed files with 155 additions and 139 deletions.
292 changes: 154 additions & 138 deletions src/devices/cpu/cpudevice.cpp
@@ -508,6 +508,141 @@ namespace fastllm {
}
}
};

struct MultiThreadLinearInt4NoZeroOp : MultiThreadBaseOp {
uint8_t *a, *b;
int32_t *c;
int n, m, k, kstride;
int *weightSums;
float *weightMins, *scales, *bias;
LowBitConfig *config;
float *inputSums;

MultiThreadLinearInt4NoZeroOp(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int kstride,
int *weightSums, float *weightMins, float *scales, float *bias, LowBitConfig *config,
float *inputSums) :
a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride),
weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias), config(config), inputSums(inputSums) {}

#ifdef __ARM_FEATURE_DOTPROD
inline static void RunSomeBlock(uint8_t *weightWalk, uint8_t *inputStart, int32_t *c,
int curBlock, uint32x2_t *sum, uint8x8x2_t *vi,
int block, int k, int m, int kstride) {
uint8x8_t maskHigh = vdup_n_u8(0xF0);
uint8x8_t maskLow = vdup_n_u8(0xF);
for (int i = 0; i < k; i++) {
std::vector <int> values = std::vector <int> (curBlock, 0);
uint8_t *inputWalk = inputStart;
int j = 0;

for (int j = 0; j < curBlock; j++) {
sum[j][0] = sum[j][1] = 0;
}
for (; j + 15 < m; j += 16) {
for (int x = 0; x < curBlock; x++) {
vi[x] = vld2_u8(inputWalk + j + m * x);
}
uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
uint8x8_t va = vand_u8(ori, maskLow);
uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
for (int x = 0; x < curBlock; x++) {
sum[x] = vdot_u32(sum[x], va, vi[x].val[1]);
sum[x] = vdot_u32(sum[x], vb, vi[x].val[0]);
}
}
for (int x = 0; x < curBlock; x++) {
values[x] += sum[x][0] + sum[x][1];
}

for (; j + 1 < m; j += 2) {
int id = (i * m + j) / 2;
for (int x = 0; x < curBlock; x++) {
values[x] += (weightWalk[id] >> 4) * inputWalk[j + x * m];
values[x] += (weightWalk[id] & 0xF) * inputWalk[j + 1 + x * m];
}
}

for (int x = 0; x < curBlock; x++) {
c[(block + x) * kstride + i] = values[x];
}
}
}
#endif
void Run() {
#ifdef __ARM_FEATURE_DOTPROD
#define RUNBLOCK(x) for (; block + (x - 1) < n; block += (x)) RunSomeBlock(b, a + block * m, c, (x), sum, vi, block, k, m, kstride);
int block = 0;
uint32x2_t sum[16];
uint8x8x2_t vi[16];
RUNBLOCK(16);
RUNBLOCK(8);RUNBLOCK(7);RUNBLOCK(6);RUNBLOCK(5);
RUNBLOCK(4);RUNBLOCK(3);RUNBLOCK(2);RUNBLOCK(1);
#undef RUNBLOCK
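// (illustration, not part of the commit) RUNBLOCK(x) consumes the remaining
// rows of a in chunks of x: e.g. with n = 37 the dispatch above runs
// RunSomeBlock on two chunks of 16 rows and then one chunk of 5, so every
// input row is processed exactly once while the sum[]/vi[] registers are
// reused across chunk sizes.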
#else
int block = 0;

for (; block < n; block++) {
uint8_t *weightWalk = b;
uint8_t *inputStart = a + block * m;

for (int i = 0; i < k; i++) {
int value = 0;
uint8_t *inputWalk = inputStart;
int j = 0;
#ifdef __ARM_FEATURE_DOTPROD
uint8x8_t maskHigh = vdup_n_u8(0xF0);
uint8x8_t maskLow = vdup_n_u8(0xF);
uint32x2_t sum0 = {0, 0};

for (; j + 15 < m; j += 16) {
uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
uint8x8x2_t in = vld2_u8(inputWalk + j);
uint8x8_t va = vand_u8(ori, maskLow);
uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
sum0 = vdot_u32(sum0, va, in.val[1]);
sum0 = vdot_u32(sum0, vb, in.val[0]);
}
value += sum0[0] + sum0[1];
#elif defined(__aarch64__)
uint8x8_t maskHigh = vdup_n_u8(0xF0);
uint8x8_t maskLow = vdup_n_u8(0xF);
uint32x4_t sum0 = {0, 0, 0, 0};

for (; j + 15 < m; j += 16) {
uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
uint8x8x2_t in = vld2_u8(inputWalk + j);
uint8x8_t va = vand_u8(ori, maskLow);
uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
sum0 = vpadalq_u16(sum0, vmull_u8(va, in.val[1]));
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif

for (; j + 1 < m; j += 2) {
int id = (i * m + j) / 2;
value += (weightWalk[id] >> 4) * inputWalk[j];
value += (weightWalk[id] & 0xF) * inputWalk[j + 1];
}

c[block * kstride + i] = value;
}
}
#endif
for (int block = 0; block < n; block++) {
for (int i = 0; i < k; i++) {
int value = c[block * kstride + i];
value -= weightSums[i] * config[block].zeroPoint;
((float*)c)[block * kstride + i] = scales[i] * config[block].scale * value +
weightMins[i] * ((float)inputSums[block] - (int)config[block].zeroPoint * m) * config[block].scale +
(bias == nullptr ? 0.0 : bias[i]);
}
}
}
};
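For reference, a minimal scalar sketch of the dequantization identity that the final loop above applies, assuming the INT4_NOZERO scheme stores each weight as w = min + scale * q (q in [0, 15], no per-weight zero point) and each quantized input as x = iscale * (u - izero); the helper name and signature below are illustrative, not part of this commit:

#include <cstdint>

// dotQU     = sum_j q[i][j] * u[j]   (the int32 value written into c)
// weightSum = sum_j q[i][j]          (weightSums[i])
// inputSum  = sum_j u[j]             (inputSums[block])
static float DequantInt4NoZero(int32_t dotQU, int weightSum, float inputSum,
                               float wmin, float wscale,
                               float iscale, int izero, int m, float bias) {
    // sum_j w[j] * x[j]
    //   = iscale * ( wscale * (dotQU    - izero * weightSum)
    //              + wmin   * (inputSum - izero * m) ) + bias
    return wscale * iscale * (dotQU - weightSum * izero)
         + wmin * (inputSum - (float)izero * m) * iscale
         + bias;
}

This mirrors the expression computed per output channel at the end of Run(), with value playing the role of dotQU, config[block].scale and config[block].zeroPoint the roles of iscale and izero, and scales[i] / weightMins[i] the roles of wscale / wmin.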

struct MultiThreadLinearInt4GroupOp : MultiThreadBaseOp {
uint8_t *a, *b;
@@ -652,7 +787,9 @@ namespace fastllm {
float routeScale = floatParams.find("routeScale") != floatParams.end() ? floatParams.find("routeScale")->second : 1.0f;
output.Allocate();

if (input.dataType == DataType::FLOAT32 && weights[0]->dataType == DataType::INT4_GROUP && input.dims[0] == 1) {
if (input.dataType == DataType::FLOAT32 &&
(weights[0]->dataType == DataType::INT4_GROUP || weights[0]->dataType == DataType::INT4_NOZERO)
&& input.dims[0] == 1) {
int dimsLen = logits.dims.size();
int outer = logits.Count(0) / logits.Count(dimsLen - 1);
int channels = logits.dims[dimsLen - 1];
@@ -780,6 +917,10 @@ namespace fastllm {
int mid = weights[idx * 2]->dims[0] / 2;
Data *weightDown = weights[idx * 2 + 1];
int groupDown = weightDown->group, groupCntDown = weightDown->groupCnt;
if (weightDown->dataType != DataType::INT4_GROUP) {
groupDown = 1;
groupCntDown = mid;
}
auto &inputConfigs = inputConfigsDown[l];
auto &inputSums = inputSumsDown[l];
auto &iscales = iscalesDown[l];
@@ -816,6 +957,10 @@ namespace fastllm {
auto &izeros = izerosDown[l];
auto &uinputDown = uinputsDown[l];
int curThread = (curK / k) * base;
if (weightDown->dataType != DataType::INT4_GROUP) {
groupDown = 1;
groupCntDown = mid;
}
MultiplyInt4GroupMultiThreadLaunch(uinputDown.data(), (uint8_t*)weightDown->cpuData, (int32_t *) results[l], 1, mid, m,
weightDown->weightSum.data(), weightDown->mins.data(), weightDown->scales.data(), nullptr,
inputSums, iscales, izeros,
@@ -1823,141 +1968,6 @@ namespace fastllm {
}
};

struct MultiThreadLinearInt4NoZeroOp : MultiThreadBaseOp {
uint8_t *a, *b;
int32_t *c;
int n, m, k, kstride;
int *weightSums;
float *weightMins, *scales, *bias;
LowBitConfig *config;
int *inputSums;

MultiThreadLinearInt4NoZeroOp(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int kstride,
int *weightSums, float *weightMins, float *scales, float *bias, LowBitConfig *config,
int *inputSums) :
a(a), b(b), c(c), n(n), m(m), k(k), kstride(kstride),
weightSums(weightSums), weightMins(weightMins), scales(scales), bias(bias), config(config), inputSums(inputSums) {}

#ifdef __ARM_FEATURE_DOTPROD
inline static void RunSomeBlock(uint8_t *weightWalk, uint8_t *inputStart, int32_t *c,
int curBlock, uint32x2_t *sum, uint8x8x2_t *vi,
int block, int k, int m, int kstride) {
uint8x8_t maskHigh = vdup_n_u8(0xF0);
uint8x8_t maskLow = vdup_n_u8(0xF);
for (int i = 0; i < k; i++) {
std::vector <int> values = std::vector <int> (curBlock, 0);
uint8_t *inputWalk = inputStart;
int j = 0;

for (int j = 0; j < curBlock; j++) {
sum[j][0] = sum[j][1] = 0;
}
for (; j + 15 < m; j += 16) {
for (int x = 0; x < curBlock; x++) {
vi[x] = vld2_u8(inputWalk + j + m * x);
}
uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
uint8x8_t va = vand_u8(ori, maskLow);
uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
for (int x = 0; x < curBlock; x++) {
sum[x] = vdot_u32(sum[x], va, vi[x].val[1]);
sum[x] = vdot_u32(sum[x], vb, vi[x].val[0]);
}
}
for (int x = 0; x < curBlock; x++) {
values[x] += sum[x][0] + sum[x][1];
}

for (; j + 1 < m; j += 2) {
int id = (i * m + j) / 2;
for (int x = 0; x < curBlock; x++) {
values[x] += (weightWalk[id] >> 4) * inputWalk[j + x * m];
values[x] += (weightWalk[id] & 0xF) * inputWalk[j + 1 + x * m];
}
}

for (int x = 0; x < curBlock; x++) {
c[(block + x) * kstride + i] = values[x];
}
}
}
#endif
void Run() {
#ifdef __ARM_FEATURE_DOTPROD
#define RUNBLOCK(x) for (; block + (x - 1) < n; block += (x)) RunSomeBlock(b, a + block * m, c, (x), sum, vi, block, k, m, kstride);
int block = 0;
uint32x2_t sum[16];
uint8x8x2_t vi[16];
RUNBLOCK(16);
RUNBLOCK(8);RUNBLOCK(7);RUNBLOCK(6);RUNBLOCK(5);
RUNBLOCK(4);RUNBLOCK(3);RUNBLOCK(2);RUNBLOCK(1);
#undef RUNBLOCK
#else
int block = 0;

for (; block < n; block++) {
uint8_t *weightWalk = b;
uint8_t *inputStart = a + block * m;

for (int i = 0; i < k; i++) {
int value = 0;
uint8_t *inputWalk = inputStart;
int j = 0;
#ifdef __ARM_FEATURE_DOTPROD
uint8x8_t maskHigh = vdup_n_u8(0xF0);
uint8x8_t maskLow = vdup_n_u8(0xF);
uint32x2_t sum0 = {0, 0};

for (; j + 15 < m; j += 16) {
uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
uint8x8x2_t in = vld2_u8(inputWalk + j);
uint8x8_t va = vand_u8(ori, maskLow);
uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
sum0 = vdot_u32(sum0, va, in.val[1]);
sum0 = vdot_u32(sum0, vb, in.val[0]);
}
value += sum0[0] + sum0[1];
#elif defined(__aarch64__)
uint8x8_t maskHigh = vdup_n_u8(0xF0);
uint8x8_t maskLow = vdup_n_u8(0xF);
uint32x4_t sum0 = {0, 0, 0, 0};

for (; j + 15 < m; j += 16) {
uint8x8_t ori = vld1_u8(weightWalk + (i * m + j) / 2);
uint8x8x2_t in = vld2_u8(inputWalk + j);
uint8x8_t va = vand_u8(ori, maskLow);
uint8x8_t vb = vshr_n_u8(vand_u8(ori, maskHigh), 4);
sum0 = vpadalq_u16(sum0, vmull_u8(va, in.val[1]));
sum0 = vpadalq_u16(sum0, vmull_u8(vb, in.val[0]));
}
value += sum0[0] + sum0[1] + sum0[2] + sum0[3];
#elif defined(__AVX2__)
value += DotU4U8(weightWalk + i * m / 2, inputWalk, m);
j += m;
#endif

for (; j + 1 < m; j += 2) {
int id = (i * m + j) / 2;
value += (weightWalk[id] >> 4) * inputWalk[j];
value += (weightWalk[id] & 0xF) * inputWalk[j + 1];
}

c[block * kstride + i] = value;
}
}
#endif
for (int block = 0; block < n; block++) {
for (int i = 0; i < k; i++) {
int value = c[block * kstride + i];
value -= weightSums[i] * config[block].zeroPoint;
((float*)c)[block * kstride + i] = scales[i] * config[block].scale * value +
weightMins[i] * ((float)inputSums[block] - (int)config[block].zeroPoint * m) * config[block].scale +
(bias == nullptr ? 0.0 : bias[i]);
}
}
}
};

//a = [n, m], b = [k, m], c = aT(b') = [n, k]
void MultiplyMultiThread(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k, int threadNum) {
auto *pool = GetAlivePool();
@@ -2029,7 +2039,7 @@ namespace fastllm {
void MultiplyInt4NoZeroMultiThread(uint8_t *a, uint8_t *b, int32_t *c, int n, int m, int k,
int *weightSums, float *weightMins, float *scales, float *bias,
std::vector <LowBitConfig> &configs, int threadNum) {
std::vector <int> inputSums;
std::vector <float> inputSums;
for (int i = 0; i < n; i++) {
int sum = 0;
for (int j = 0; j < m; j++) {
@@ -2078,10 +2088,16 @@ namespace fastllm {

for (int i = 0; i < threadNum; i++) {
int end = (i == threadNum - 1 ? k : cur + per + (cur + per * (threadNum - i) < k));
ops[startTid + i] = new MultiThreadLinearInt4GroupOp(a, b + cur * m / 2, c + cur, n, m, end - cur, k,
if (group > 1) {
ops[startTid + i] = new MultiThreadLinearInt4GroupOp(a, b + cur * m / 2, c + cur, n, m, end - cur, k,
weightSums + cur * group, weightMins + cur * group, scales + cur * group,
(bias == nullptr ? (float *) nullptr : bias + cur), iscales.data(), izeros.data(),
inputSums.data(), group, groupCnt);
} else {
ops[startTid + i] = new MultiThreadLinearInt4NoZeroOp(a, b + cur * m / 2, c + cur, n, m, end - cur, k,
weightSums + cur * group, weightMins + cur * group, scales + cur * group,
(bias == nullptr ? (float *) nullptr : bias + cur), configs.data(), inputSums.data());
}
cur = end;
}
for (int i = 0; i < threadNum; i++) {
2 changes: 1 addition & 1 deletion src/models/deepseekv2.cpp
@@ -378,7 +378,7 @@ namespace fastllm {
pastValue.ToDevice(DataDevice::CUDA);
}

int unitLen = 64;
int unitLen = 128;
#ifdef USE_CUDA
unitLen = 128;
#endif
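The deepseekv2.cpp hunk raises the CPU-side unitLen from 64 to 128, the value the #ifdef USE_CUDA branch already sets. A minimal sketch of how such a granularity is typically used, assuming unitLen is the chunk size for reserving pastKey/pastValue capacity (the helper below is hypothetical, not code from this commit):

// Round a requested cache length up to a multiple of unitLen, so CPU and
// CUDA builds now both grow the KV cache in blocks of 128 entries.
int RoundUpToUnit(int needed, int unitLen = 128) {
    return ((needed + unitLen - 1) / unitLen) * unitLen;
}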
