From 74af6426ea3fcceebbfe4a9cd1273e5275b0f5e0 Mon Sep 17 00:00:00 2001
From: Molly Sophia
Date: Sat, 28 Dec 2024 19:16:29 +0800
Subject: [PATCH] Fix wkv test & add gla test

Signed-off-by: Molly Sophia
---
 tests/test-backend-ops.cpp | 42 ++++++++++++++++++++++++++++++++++----
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index ccdd3fb57a504..b4b9bd14825bb 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -1659,17 +1659,46 @@ struct test_rwkv_wkv6 : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t n_tokens = n_seq_tokens * n_seqs;
-        ggml_tensor * r   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
-        ggml_tensor * k   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ head_size, 1, head_count, n_tokens }.data());
-        ggml_tensor * v   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * r   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * tf  = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size, head_count }.data());
-        ggml_tensor * td  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ 1, head_size, head_count, n_tokens }.data());
+        ggml_tensor * td  = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
         ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
         ggml_tensor * out = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, s);
         return out;
     }
 };
 
+// GGML_OP_GATED_LINEAR_ATTN
+struct test_gla : public test_case {
+    const ggml_type type;
+
+    const int64_t head_count;
+    const int64_t head_size;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, head_count, head_size, n_seq_tokens, n_seqs);
+    }
+
+    test_gla(ggml_type type = GGML_TYPE_F32,
+            int64_t head_count = 32, int64_t head_size = 64, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), head_count(head_count), head_size(head_size), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int64_t n_tokens = n_seq_tokens * n_seqs;
+        ggml_tensor * q   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * k   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * v   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * g   = ggml_new_tensor(ctx, type, 3, std::vector<int64_t>{ head_size, head_count, n_tokens }.data());
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 2, std::vector<int64_t>{ head_size * head_size * head_count, n_seqs }.data());
+        ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, pow(head_size, -0.5));
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -3626,6 +3655,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 1, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 1));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
+    test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
+
     for (int i = 1; i < 9; ++i) {
         test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16,   GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));
         test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0,  GGML_TYPE_F32, 16,  i, 256, { 1,  1}, {1, 1}));
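
Usage note (not part of the patch): the new test_gla case builds q/k/v/g as [head_size, head_count, n_tokens] tensors plus a flattened [head_size * head_size * head_count, n_seqs] recurrent state, and scales by head_size^-0.5. Below is a minimal standalone sketch of the same call, assuming the ggml_gated_linear_attn operator from the companion ggml change is available and the CPU backend is linked (on recent ggml revisions ggml_graph_compute_with_ctx is declared in ggml-cpu.h); the inputs are zero-filled here, where test-backend-ops would fill them with random data.

// gla_sketch.cpp - standalone GLA graph mirroring test_gla::build_graph.
// Assumptions: ggml_gated_linear_attn exists (companion ggml patch) and
// the CPU backend provides ggml_graph_compute_with_ctx via ggml-cpu.h.
#include "ggml.h"
#include "ggml-cpu.h"
#include <cmath>
#include <cstring>
#include <cstdio>

int main() {
    const int64_t head_size = 64, head_count = 32, n_tokens = 32, n_seqs = 1;

    struct ggml_init_params params = {
        /*.mem_size   =*/ 64 * 1024 * 1024, // plenty for these tensors + graph
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // Same shapes as the test case: q/k/v/g are [head_size, head_count, n_tokens],
    // the recurrent state s is [head_size * head_size * head_count, n_seqs].
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
    ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
    ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
    ggml_tensor * g = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, head_count, n_tokens);
    ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size * head_size * head_count, n_seqs);

    // Zero-fill the inputs; the test harness uses random data instead.
    for (ggml_tensor * t : { q, k, v, g, s }) {
        memset(t->data, 0, ggml_nbytes(t));
    }

    // Same argument order and scale as the test: (k, v, q, g, state, head_size^-0.5).
    ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, 1.0f / sqrtf((float) head_size));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 4);

    printf("GLA output elements: %lld\n", (long long) ggml_nelements(out));
    ggml_free(ctx);
    return 0;
}

The four test_gla registrations above vary only n_seq_tokens and n_seqs (1x1, 32x1, 32x4, 128x4), which exercises both the single-token decode path and batched multi-sequence prefill against the same reference implementation.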