ggml: aarch64: implement SVE kernels for q4_K_q8_K vector dot #11227

Merged (2 commits) on Jan 16, 2025
83 changes: 82 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu-quants.c
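For context, here is a minimal scalar sketch of the computation the new SVE kernel vectorizes. It is an illustration only, not code from this PR: the helper vec_dot_q4_K_q8_K_ref and its simplified signature (nb super-blocks, no strides) are hypothetical, and it assumes the standard ggml definitions (block_q4_K, block_q8_K, QK_K = 256, K_SCALE_SIZE, GGML_FP16_TO_FP32) available elsewhere in ggml-cpu-quants.c.

// Hypothetical scalar reference for the q4_K x q8_K dot product (illustration only).
// Each super-block holds QK_K = 256 quants in 8 sub-blocks of 32, with 6-bit
// scales/mins packed into x[i].scales.
static float vec_dot_q4_K_q8_K_ref(int nb, const block_q4_K * x, const block_q8_K * y) {
    const uint32_t kmask1 = 0x3f3f3f3f, kmask2 = 0x0f0f0f0f, kmask3 = 0x03030303;
    float sumf = 0.0f;
    for (int i = 0; i < nb; ++i) {
        const float d    = y[i].d * GGML_FP16_TO_FP32(x[i].d);     // combined super-block scale
        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);  // combined super-block min
        uint32_t utmp[4];
        memcpy(utmp, x[i].scales, K_SCALE_SIZE);                   // 12 bytes of packed scales/mins
        // same bit manipulation as the kernel below: extract 8 mins and 8 scales as bytes
        uint32_t mins32[2];
        mins32[0] = utmp[1] & kmask1;
        mins32[1] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        utmp[1]   = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[0]  &= kmask1;
        const uint8_t * scales = (const uint8_t *)utmp;    // 8 six-bit scales
        const uint8_t * mins   = (const uint8_t *)mins32;  // 8 six-bit mins
        // min correction: each min covers 32 q8 values, i.e. two 16-element bsums
        int32_t summins = 0;
        for (int k = 0; k < 8; ++k) {
            summins += mins[k] * (y[i].bsums[2*k+0] + y[i].bsums[2*k+1]);
        }
        // main dot product: 4 chunks of 64 quants, low nibbles then high nibbles
        const uint8_t * q4 = x[i].qs;
        const int8_t  * q8 = y[i].qs;
        int32_t sumdot = 0;
        for (int j = 0; j < QK_K/64; ++j) {
            int32_t lo = 0, hi = 0;
            for (int l = 0; l < 32; ++l) {
                lo += (q4[l] & 0xF) * q8[l];        // quants  0..31 of the chunk
                hi += (q4[l] >>  4) * q8[l + 32];   // quants 32..63 of the chunk
            }
            sumdot += scales[2*j+0] * lo + scales[2*j+1] * hi;
            q4 += 32;
            q8 += 64;
        }
        sumf += d * sumdot - dmin * summins;
    }
    return sumf;
}

The SVE kernel in the diff below computes the same result, replacing the inner 32-element loops with predicated loads plus svdot_s32 and accumulating the per-sub-block partial sums with svmla_n_s32_x.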
@@ -5573,7 +5573,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r

    uint32_t utmp[4];

#ifdef __ARM_FEATURE_SVE   // new SVE path; replaces the previous top-level #ifdef __ARM_NEON guard (the NEON path moves to the #elif below)
    float sumf = 0;
    for (int i = 0; i < nb; ++i) {    // one iteration per 256-quant super-block

        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);        // combined super-block scale
        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);  // combined super-block min

        // sums of 32 consecutive q8 values (pairs of 16-element bsums), used for the min correction
        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));

        memcpy(utmp, x[i].scales, K_SCALE_SIZE);

        // unpack the 8 six-bit mins into bytes
        uint32x2_t mins8 = { 0 };
        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);

        // leave the 8 six-bit scales as bytes in utmp[0..1]
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[0] &= kmask1;

        // subtract the min contribution: dmin * sum_k mins[k] * (sum of 32 q8 values)
        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
        sumf -= dmin * vaddvq_s32(prod);

        const uint8_t * scales = (const uint8_t *)utmp;

        const uint8_t * restrict q4 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;

        const int vector_length = ggml_cpu_get_sve_cnt()*8;   // SVE vector length in bits
        const svuint8_t m4b = svdup_n_u8(0xf);
        const svint32_t mzero = svdup_n_s32(0);
        svint32_t sumi1 = svdup_n_s32(0);
        svint32_t sumi1_1 = svdup_n_s32(0);
        svint32_t sumi1_2 = svdup_n_s32(0);
        svint32_t sumi2 = svdup_n_s32(0);
        svint32_t sumi2_1 = svdup_n_s32(0);
        svint32_t sumi2_2 = svdup_n_s32(0);
        switch (vector_length) {
            case 128:    // 128-bit SVE: 16-byte loads, two independent accumulators per scale
                {
                    for (int j = 0; j < QK_K/64; ++j) {
                        // low nibbles of the first/second 16 bytes -> quants 0..31 of the chunk, scale[2j]
                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);

                        // high nibbles -> quants 32..63 of the chunk, scale[2j+1]
                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
                        q4 += 32;
                    }
                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
                } break;
            case 256:
            case 512:    // 256/512-bit SVE: a fixed 32-byte predicate covers a whole 64-quant chunk per load
                {
                    for (int j = 0; j < QK_K/64; ++j) {
                        const svuint8_t q4bits = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
                        // low nibbles -> quants 0..31 of the chunk, scale[2j]
                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);

                        // high nibbles -> quants 32..63 of the chunk, scale[2j+1]
                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
                    }
                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
                } break;
            default:
                assert(false && "Unsupported vector length");
                break;
        }
    }
    *s = sumf;
#elif __ARM_NEON
    const uint8x16_t m4b = vdupq_n_u8(0xf);
    const int32x4_t mzero = vdupq_n_s32(0);

[The rest of the diff is unchanged context from the existing NEON implementation.]
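For readers new to SVE, the sketch below illustrates the idiom the 256/512-bit path relies on: a fixed-width predicate (SV_VL32, exactly 32 active byte lanes) keeps each load block-sized regardless of the hardware vector length, and svdot_s32 accumulates every group of four int8 products into one int32 lane. The helper dot_s8_blocks is hypothetical and not part of this PR.

// Hypothetical illustration of the fixed-predicate + svdot idiom used by the 256/512-bit path.
#include <arm_sve.h>
#include <stdint.h>

static int32_t dot_s8_blocks(const int8_t * a, const int8_t * b, int n /* multiple of 32 */) {
    const svbool_t pg8  = svptrue_pat_b8(SV_VL32);    // exactly 32 active int8 lanes
    const svbool_t pg32 = svptrue_pat_b32(SV_VL8);    // the matching 8 int32 lanes
    svint32_t acc = svdup_n_s32(0);
    for (int i = 0; i < n; i += 32) {
        const svint8_t va = svld1_s8(pg8, a + i);     // inactive lanes load as zero
        const svint8_t vb = svld1_s8(pg8, b + i);
        acc = svdot_s32(acc, va, vb);                 // 4-way widening dot per int32 lane
    }
    return (int32_t) svaddv_s32(pg32, acc);           // horizontal sum of the active lanes
}

On 128-bit SVE hardware a vector holds only 16 bytes, so the SV_VL32 pattern yields an empty predicate; that is why the kernel keeps a separate 128-bit case that splits each 32-byte chunk into two 16-byte loads and uses the paired accumulators sumi1_1/sumi1_2 and sumi2_1/sumi2_2.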