diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py
index f9d33f4aa88265..7f4f9cd038be4b 100644
--- a/examples/server/tests/unit/test_completion.py
+++ b/examples/server/tests/unit/test_completion.py
@@ -78,7 +78,7 @@ def test_completion_stream_vs_non_stream():
 @pytest.mark.parametrize("n_slots", [1, 2])
 def test_consistent_result_same_seed(n_slots: int):
     global server
-    server.n_slots = 1
+    server.n_slots = n_slots
     server.start()
     last_res = None
     for _ in range(4):
@@ -115,7 +115,7 @@ def test_different_result_different_seed(n_slots: int):
 @pytest.mark.parametrize("temperature", [0.0, 1.0])
 def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
     global server
-    server.n_batch = 1
+    server.n_batch = n_batch
     server.start()
     last_res = None
     for _ in range(4):
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index ea17d6077e7cff..329bc585b41346 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -7420,14 +7420,14 @@ static void ggml_compute_forward_mul_mat(
         if (src1_cont) {
             for (int64_t i13 = 0; i13 < ne13; i13++)
                 for (int64_t i12 = 0; i12 < ne12; i12++)
-                    if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                    if (!llamafile_sgemm(params,
+                                         ne01, ne11, ne00/ggml_blck_size(src0->type),
                                          (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                          nb01/ggml_type_size(src0->type),
                                          (const char *)src1->data + i12*nb12 + i13*nb13,
                                          nb11/ggml_type_size(src1->type),
                                          (char *)dst->data + i12*nb2 + i13*nb3,
                                          nb1/ggml_type_size(dst->type),
-                                         ith, nth,
                                          src0->type,
                                          src1->type,
                                          dst->type))
@@ -7472,14 +7472,14 @@ UseGgmlGemm1:;
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
-                                     ith, nth,
                                      src0->type,
                                      vec_dot_type,
                                      dst->type))
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 196de415564954..cacb093a7e1a26 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -53,6 +53,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"
 
+#include <atomic>
+
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
 #else
@@ -297,12 +299,11 @@ static int64_t BLOCK_SIZE(size_t m) {
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int64_t k,
+    tinyBLAS(const ggml_compute_params * params, int64_t k,
              const TA *A, int64_t lda,
              const TB *B, int64_t ldb,
-             TC *C, int64_t ldc,
-             int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+             TC *C, int64_t ldc)
+        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
     }
 
     bool matmul(int64_t m, int64_t n) {
@@ -310,6 +311,16 @@ class tinyBLAS {
             return false;
         // compute RN/RM for only tile with size RN&RN-1/RM&RM-1
 #if VECTOR_REGISTERS == 32
+        /*
+        if (m % 4 == 0 && n == 1) {
+            mnpack<1, 1, 2>(m, 1, 1);
+            return true;
+        }
+        if (m % 8 == 0 && n == 2) {
+            mnpack<8, 2, 1>(m, 2, 1);
+            return true;
+        }
+        */
         if (m % 8 == 0 && n < 4) {
             mnpack<8, 3, 1>(m, n, n);
             return true;
         }
@@ -362,6 +373,25 @@ class tinyBLAS {
         }
     }
 
+    template <int64_t RM>
+    inline void gemv_bloc(int64_t ii) {
+        D Cv[RM] = {};
+        for (int64_t i = 0; i < RM; ++i) {
+            for (int j=0; j < KN; j++) {
+                Cv[i][j] = 0;
+            }
+        }
+        for (int64_t l = 0; l < k; l += KN) {
+            V Bv = load<V>(B + l);
+            for (int64_t i = 0; i < RM; ++i) {
+                V Av = load<V>(A + lda * (ii + i) + l);
+                Cv[i] = madd(Av, Bv, Cv[i]);
+            }
+        }
+        for (int64_t i = 0; i < RM; ++i)
+            C[ii + i] = hsum(Cv[i]);
+    }
+
     template <int64_t RN, int64_t RM>
     inline void gemm_bloc(int64_t ii, int64_t jj) {
         D Cv[RN][RM] = {};
@@ -399,6 +429,52 @@ class tinyBLAS {
     template <int64_t RM, int64_t RN, int64_t BM>
     NOINLINE void gemm(int64_t m, int64_t n) {
         GGML_ASSERT(m % (RM * BM) == 0);
+        // const int64_t ytiles = m / (RM * BM);
+        const int64_t xtiles = (n + RN -1) / RN;
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
+
+        static std::atomic<int64_t> current_chunk;
+        if (params->ith == 0) {
+            GGML_ASSERT((xtiles * RN - n) >= 0);
+            GGML_ASSERT((xtiles * RN - n) < RN);
+
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
+        }
+        ggml_barrier(params->threadpool);
+        int64_t ii = params->ith * RM * BM;
+
+/*
+        if (n==1) {
+            while (ii < m) {
+                for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
+                    gemv_bloc<RM>(ii + bi);
+                }
+                ii = std::atomic_fetch_add_explicit(&current_chunk, 1, std::memory_order_relaxed) * RM * BM;
+            }
+        } else {
+*/
+        while (ii < m) {
+            for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
+                int64_t jj = 0;
+                for (; jj<jj_RN; jj+=RN) {
+                    gemm_bloc<RN, RM>(ii + bi, jj);
+                }
+                if constexpr (RN > 1) {
+                    for (; jj<n; jj+=RN-1) {
+                        gemm_bloc<RN-1, RM>(ii + bi, jj);
+                    }
+                }
+                GGML_ASSERT(jj == n);
+            }
+            ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
+        }
+        //}
+        ggml_barrier(params->threadpool);
+
+#if 0
+        const int ith = params->ith;
+        const int nth = params->nth;
         const int64_t ytiles = m / (RM * BM);
         const int64_t xtiles = (n + RN -1) / RN;
         const int64_t jj_RN = (xtiles - (xtiles * RN - n));
@@ -421,8 +497,10 @@ class tinyBLAS {
                 }
             }
         }
+#endif
     }
 
+    const ggml_compute_params * params;
     const TA *const A;
     const TB *const B;
     TC *const C;
@@ -430,8 +508,6 @@ class tinyBLAS {
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
-    const int ith;
-    const int nth;
 };
 
 //////////////////////////////////////////////////////////////////////////////////////////
@@ -1635,8 +1711,9 @@ class tinyBLAS_PPC {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int Atype, int Btype, int Ctype) {
 
     assert(m >= 0);
     assert(n >= 0);
@@ -1644,8 +1721,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     assert(lda >= k);
     assert(ldb >= k);
     assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
+    assert(params->nth > 0);
+    assert(params->ith < params->nth);
 
     // only enable sgemm for prompt processing
     if (n < 2)
@@ -1660,27 +1737,24 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         if (Btype != GGML_TYPE_F32)
             return false;
 #if defined(__AVX512F__)
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
-           (float *)C, ldc,
-           ith, nth};
+           (float *)C, ldc};
        return tb.matmul(m, n);
 #elif defined(__AVX__) || defined(__AVX2__)
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
-           (float *)C, ldc,
-           ith, nth};
+           (float *)C, ldc};
        return tb.matmul(m, n);
 #elif defined(__ARM_NEON)
        if (n < 4)
            return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
            k, (const float *)A, lda,
            (const float *)B, ldb,
-           (float *)C, ldc,
-           ith, nth};
+           (float *)C, ldc};
        return tb.matmul(m, n);
 #elif defined(__MMA__)
        if (k % 8)
@@ -1689,7 +1763,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                             k, (const float *)A, lda,
                             (const float *)B, ldb,
                             (float *)C, ldc,
-                            ith, nth};
+                            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1700,29 +1774,26 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_BF16: {
 #if defined(__AVX512BF16__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                 (const ggml_bf16_t *)A, lda,
                 (const ggml_bf16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__AVX512F__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                 (const ggml_bf16_t *)A, lda,
                 (const ggml_bf16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__AVX2__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                 (const ggml_bf16_t *)A, lda,
                 (const ggml_bf16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #endif
@@ -1730,41 +1801,54 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     }
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
+/*
+        if (Btype == GGML_TYPE_F32) {
+            tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ params, k,
+                (const ggml_fp16_t *)A, lda,
+                (const float *)B, ldb,
+                (float *)C, ldc};
+            return tb.matmul(m, n);
+        }
+*/
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                 (const ggml_fp16_t *)A, lda,
                 (const ggml_fp16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                 (const ggml_fp16_t *)A, lda,
                 (const ggml_fp16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
+        //if (Btype == GGML_TYPE_F32) {
+        //    tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ k,
+        //        (const ggml_fp16_t *)A, lda,
+        //        (const float *)B, ldb,
+        //        (float *)C, ldc,
+        //        params->ith, params->nth};
+        //    return tb.matmul(m, n);
+        //}
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
             return false;
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                 k, (const ggml_fp16_t *)A, lda,
                 (const ggml_fp16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (Btype == GGML_TYPE_F32) {
-            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
                 k, (const ggml_fp16_t *)A, lda,
                 (const float *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #endif
@@ -1779,7 +1863,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                         k, (const block_q8_0 *)A, lda,
                         (const block_q8_0 *)B, ldb,
                         (float *)C, ldc,
-                        ith, nth};
+                        params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1787,7 +1871,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                         k, (const block_q8_0 *)A, lda,
                         (const block_q8_0 *)B, ldb,
                         (float *)C, ldc,
-                        ith, nth};
+                        params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1803,7 +1887,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                         k, (const block_q4_0 *)A, lda,
                         (const block_q8_0 *)B, ldb,
                         (float *)C, ldc,
-                        ith, nth};
+                        params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1811,7 +1895,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                         k, (const block_q4_0 *)A, lda,
                         (const block_q8_0 *)B, ldb,
                         (float *)C, ldc,
-                        ith, nth};
+                        params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1827,7 +1911,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                         k, (const block_q5_0 *)A, lda,
                         (const block_q8_0 *)B, ldb,
                         (float *)C, ldc,
-                        ith, nth};
+                        params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1843,7 +1927,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                         k, (const block_iq4_nl *)A, lda,
                         (const block_q8_0 *)B, ldb,
                         (float *)C, ldc,
-                        ith, nth};
+                        params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1855,6 +1939,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         return false;
     }
 
+    (void)params;
     (void)m;
     (void)n;
     (void)k;
@@ -1864,8 +1949,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldb;
     (void)C;
     (void)ldc;
-    (void)ith;
-    (void)nth;
     (void)Atype;
     (void)Btype;
     (void)Ctype;
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.h b/ggml/src/ggml-cpu/llamafile/sgemm.h
index caf6dd5567b3ad..3d2909515242a2 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.h
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.h
@@ -5,8 +5,8 @@ extern "C" {
 #endif
 
-bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
-                     const void *, int64_t, void *, int64_t, int, int,
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t,
+                     const void *, int64_t, const void *, int64_t, void *, int64_t,
                      int, int, int);
 
 #ifdef __cplusplus
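
Note on the scheduling change in gemm(): instead of a fixed split of the tile grid by ith/nth, every thread now starts on chunk ith and claims further chunks from a shared std::atomic<int64_t> counter that thread 0 seeds to nth under ggml_barrier(), so faster threads absorb any leftover work. A minimal standalone sketch of the same idea (worker, process_chunk, and nchunks are illustrative names, not part of the patch):

// Standalone sketch of the chunk-claiming loop used in gemm() above.
// Each of nth workers first processes chunk ith, then claims further
// chunks from a shared counter that starts at nth, so no chunk is
// processed twice and faster threads pick up the remaining work.
#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

static std::atomic<int64_t> current_chunk;

// process_chunk stands in for the RM*BM-row tile computed per chunk.
static void worker(int ith, int nth, int64_t nchunks,
                   void (*process_chunk)(int64_t)) {
    (void)nth;
    for (int64_t c = ith; c < nchunks;
         c = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1,
                                            std::memory_order_relaxed)) {
        process_chunk(c);
    }
}

int main() {
    const int     nth     = 4;
    const int64_t nchunks = 64;
    // Thread 0 seeds the counter to nth before any claims happen (the PR
    // does this under ggml_barrier; plain thread creation stands in here).
    std::atomic_store_explicit(&current_chunk, (int64_t)nth,
                               std::memory_order_relaxed);
    std::vector<std::thread> pool;
    for (int ith = 0; ith < nth; ++ith)
        pool.emplace_back(worker, ith, nth, nchunks,
                          +[](int64_t) { /* compute one tile */ });
    for (auto & t : pool) t.join();
}

The relaxed memory order is sufficient for the counter itself because the chunks are independent; in the patch the surrounding ggml_barrier() calls are what order the writes to C against later readers.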
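For reference, the (currently commented-out) gemv_bloc path for n == 1 keeps one accumulator per row, reuses a single load of B across the RM rows of each k-slice, and reduces each accumulator with hsum() at the end. A scalar sketch of that pattern, with std::array standing in for the SIMD types D/V and local load/madd/hsum stand-ins for the tinyBLAS helpers:

#include <array>
#include <cstdint>
#include <cstdio>

constexpr int KN = 4;              // SIMD width stand-in (e.g. 16 for __m512)
using V = std::array<float, KN>;   // vector-register stand-in

static V load(const float * p) {   // stand-in for tinyBLAS load<V>()
    V v;
    for (int j = 0; j < KN; ++j) v[j] = p[j];
    return v;
}
static V madd(V a, V b, V c) {     // multiply-add, lane by lane
    for (int j = 0; j < KN; ++j) c[j] += a[j] * b[j];
    return c;
}
static float hsum(V v) {           // horizontal sum of all lanes
    float s = 0.0f;
    for (int j = 0; j < KN; ++j) s += v[j];
    return s;
}

// One RM-row block of C = A * b, mirroring gemv_bloc: Bv is loaded once
// per k-slice and reused across the RM rows; k must be a multiple of KN.
template <int64_t RM>
void gemv_bloc_ref(const float * A, int64_t lda, const float * B, float * C,
                   int64_t k, int64_t ii) {
    V Cv[RM] = {};
    for (int64_t l = 0; l < k; l += KN) {
        V Bv = load(B + l);
        for (int64_t i = 0; i < RM; ++i)
            Cv[i] = madd(load(A + lda * (ii + i) + l), Bv, Cv[i]);
    }
    for (int64_t i = 0; i < RM; ++i)
        C[ii + i] = hsum(Cv[i]);
}

int main() {
    float A[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};  // m=2, k=4, lda=4
    float B[4]    = {1, 1, 1, 1};
    float C[2]    = {};
    gemv_bloc_ref<2>(&A[0][0], 4, B, C, 4, 0);
    printf("%g %g\n", C[0], C[1]);  // expect: 10 26
}

In the vectorized version each Cv[i] holds KN partial sums per row, which is why the horizontal reduction is deferred to the very last step.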