From c2aa654ad681fe4877a146ba1cb19e4378c51288 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:00:58 -0500 Subject: [PATCH 1/2] q5_k q4_k q3_k q2_k q6_k multi row example --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 36 ++--- .../vulkan-shaders/mul_mat_vec_base.comp | 2 - .../vulkan-shaders/mul_mat_vec_q2_k.comp | 27 ++-- .../vulkan-shaders/mul_mat_vec_q3_k.comp | 27 ++-- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 26 ++-- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 21 ++- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 147 ++++++++++-------- 7 files changed, 159 insertions(+), 127 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5d9eba9833fa3..adf183551dc6a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -44,12 +44,6 @@ #define MAX_VK_BUFFERS 256 -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 1 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ @@ -1792,11 +1786,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -1806,11 +1800,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -1820,11 +1814,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); // dequant shaders diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp index 2ec1af5c75542..3894fca8260a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp @@ -2,8 +2,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_8bit_storage : require -#define K_QUANTS_PER_ITERATION 2 - #ifdef MUL_MAT_ID #define EXPERT_COUNT 8 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index fcf02210ea6c7..1a5350d99ea14 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -3,9 +3,11 @@ #include "mul_mat_vec_base.comp" -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE tmp[32]; +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; @@ -20,22 +22,25 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 8; - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - step*v_im; // 0...15 or 0...7 - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 + const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; const uint s_offset = 8*v_im; const uint y_offset = 128*v_im + l0; FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y_idx = i * QUANT_K + y_offset; f16vec2 d = data_a[ib0 + i].d; @@ -71,7 +76,7 @@ void main() { FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + [[unroll]] for (int l = 0; l < 2; ++l) { sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3), fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3), @@ -96,7 +101,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 723fadde0d52b..b19c3811136bb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -3,9 +3,11 @@ #include "mul_mat_vec_base.comp" -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE tmp[32]; +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; @@ -20,17 +22,20 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; - const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + const uint step = 8; - const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const uint v_in = tid - step*v_im; // 0...15 or 0...7 + const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - step*v_im; // 0...15 or 0...7 const uint8_t m = uint8_t(1 << (4 * v_im)); - const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 + const uint l0 = 2*v_in; // 0...15 const uint q_offset = 32*v_im + l0; const uint y_offset = 128*v_im + l0; @@ -38,7 +43,7 @@ void main() { const uint s_shift = 4 * v_im; - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y_idx = i * QUANT_K + y_offset; const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); @@ -66,7 +71,7 @@ void main() { u8vec2 s10 = unpack8(s10_16); FLOAT_TYPE sum = FLOAT_TYPE(0.0); - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { + [[unroll]] for (int l = 0; l < 2; ++l) { sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), @@ -83,7 +88,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 5846f2e86f3e3..b86d28589c64c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -4,11 +4,12 @@ #include "mul_mat_vec_base.comp" -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE tmp[32]; +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; -// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads void main() { const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; @@ -22,14 +23,17 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; - const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + const uint step = 4; - const uint il = tid/step; // 0...3 - const uint ir = tid - step*il; // 0...7 or 0...3 - const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + const uint il = itid/step; // 0...3 + const uint ir = itid - step*il; // 0...7 or 0...3 + const uint n = 4; const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; @@ -40,7 +44,7 @@ void main() { FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; @@ -115,7 +119,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index b455cbd31ec91..fd243cf9161fc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -4,9 +4,11 @@ #include "mul_mat_vec_base.comp" -layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE tmp[32]; +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; @@ -21,11 +23,14 @@ void main() { const uint num_blocks_per_row = p.ncols / QUANT_K; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; - const uint tid = gl_LocalInvocationID.x/2; // 0...31 or 0...16 - const uint ix = gl_LocalInvocationID.x%2; // 0 or 0, 1 + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...16 + const uint ix = tid/16; - const uint il = tid/4; // 0...3 - const uint ir = tid - 4*il; // 0...7 or 0...3 + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...7 or 0...3 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_in = il % 2; @@ -36,7 +41,7 @@ void main() { FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) { + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { const uint y1_idx = i * QUANT_K + y_offset; const uint y2_idx = y1_idx + 128; @@ -143,7 +148,7 @@ void main() { // sum up partial sums and write back result barrier(); - [[unroll]] for (uint s = 16; s > 0; s >>= 1) { + [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 760aff85499f4..1da7d7eb05e92 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -7,21 +7,15 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; -shared FLOAT_TYPE tmp[BLOCK_SIZE]; - -void main() { - const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; - - if (row >= p.stride_d) { - return; - } +shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; - const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; @@ -42,69 +36,96 @@ void main() { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - FLOAT_TYPE scales[4]; - scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); - scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); - scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); - scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); - - uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); - uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); - - uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; - uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; - uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; - uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); - uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; - uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; - uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; - uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; - - uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; - uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; - uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; - uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; - - uvec4 q0 = uvec4(unpack8(q0_u32)); - uvec4 q1 = uvec4(unpack8(q1_u32)); - uvec4 q2 = uvec4(unpack8(q2_u32)); - uvec4 q3 = uvec4(unpack8(q3_u32)); - - B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4]; - B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8]; - B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16]; - B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24]; - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), - fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), - fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), - fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); + FLOAT_TYPE temp[NUM_ROWS]; + + for (uint i = 0; i < NUM_ROWS; ++i) { + temp[i] = FLOAT_TYPE(0); + } + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + FLOAT_TYPE scales[4]; + scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); + scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); + scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); + scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); + + uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; + uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + uvec4 q0 = uvec4(unpack8(q0_u32)); + uvec4 q1 = uvec4(unpack8(q1_u32)); + uvec4 q2 = uvec4(unpack8(q2_u32)); + uvec4 q3 = uvec4(unpack8(q3_u32)); + + B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4]; + B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8]; + B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16]; + B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 4; ++l) { + sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), + fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), + fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), + fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); + } + temp[n] += sum * d; } - temp += sum * d; } - tmp[gl_LocalInvocationID.x] = temp; // sum up partial sums and write back result - + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[n][tid] = temp[n]; + } barrier(); - [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) { + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[n][tid] += tmpsh[n][tid + s]; + } } barrier(); } if (tid == 0) { - data_d[d_offset + row] = D_TYPE(tmp[0]); + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); + } + } +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); } } From 8ee6beea08cfa11323ad6e265572d443c846ab80 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Mon, 9 Dec 2024 21:35:13 -0500 Subject: [PATCH 2/2] revert as multi row isnt faster for k quants --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +- .../vulkan-shaders/mul_mat_vec_q6_k.comp | 147 ++++++++---------- 2 files changed, 66 insertions(+), 87 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index adf183551dc6a..9850a9533fd8d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1790,7 +1790,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -1804,7 +1804,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); @@ -1818,7 +1818,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {subgroup_size_16, 2}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1, true); // dequant shaders diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index 1da7d7eb05e92..760aff85499f4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -7,15 +7,21 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (constant_id = 0) const uint BLOCK_SIZE = 32; -layout (constant_id = 1) const uint NUM_ROWS = 1; -shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE]; +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z; + + if (row >= p.stride_d) { + return; + } -void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { uint a_offset, b_offset, d_offset; get_offsets(a_offset, b_offset, d_offset); const uint num_blocks_per_row = p.ncols / QUANT_K; + const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; // 16 threads are used to process each block const uint it_size = gl_WorkGroupSize.x/16; @@ -36,96 +42,69 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint s_offset = 8*v_im + is; const uint y_offset = 128*v_im + l0; - FLOAT_TYPE temp[NUM_ROWS]; - - for (uint i = 0; i < NUM_ROWS; ++i) { - temp[i] = FLOAT_TYPE(0); - } - - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; - - [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { - const uint y_idx = i * QUANT_K + y_offset; - - const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); - - FLOAT_TYPE scales[4]; - scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); - scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); - scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); - scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); - - uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); - uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); - - uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; - uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; - uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; - uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; - - uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); - uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; - uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; - uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; - uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; - - uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; - uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; - uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; - uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; - - uvec4 q0 = uvec4(unpack8(q0_u32)); - uvec4 q1 = uvec4(unpack8(q1_u32)); - uvec4 q2 = uvec4(unpack8(q2_u32)); - uvec4 q3 = uvec4(unpack8(q3_u32)); - - B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4]; - B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8]; - B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16]; - B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24]; - - FLOAT_TYPE sum = FLOAT_TYPE(0.0); - [[unroll]] for (int l = 0; l < 4; ++l) { - sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), - fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), - fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), - fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); - } - temp[n] += sum * d; + FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { + const uint y_idx = i * QUANT_K + y_offset; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + FLOAT_TYPE scales[4]; + scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); + scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); + scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); + scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); + + uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; + uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + uvec4 q0 = uvec4(unpack8(q0_u32)); + uvec4 q1 = uvec4(unpack8(q1_u32)); + uvec4 q2 = uvec4(unpack8(q2_u32)); + uvec4 q3 = uvec4(unpack8(q3_u32)); + + B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4]; + B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8]; + B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16]; + B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24]; + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 4; ++l) { + sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), + fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), + fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), + fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); } + temp += sum * d; } + tmp[gl_LocalInvocationID.x] = temp; // sum up partial sums and write back result - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] = temp[n]; - } + barrier(); - [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) { if (tid < s) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - tmpsh[n][tid] += tmpsh[n][tid + s]; - } + tmp[tid] += tmp[tid + s]; } barrier(); } if (tid == 0) { - [[unroll]] for (uint n = 0; n < num_rows; ++n) { - data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]); - } - } -} - -void main() { - const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); - - // do NUM_ROWS at a time, unless there aren't enough remaining rows - if (first_row + NUM_ROWS <= p.stride_d) { - compute_outputs(first_row, NUM_ROWS); - } else { - if (first_row >= p.stride_d) { - return; - } - compute_outputs(first_row, p.stride_d - first_row); + data_d[d_offset + row] = D_TYPE(tmp[0]); } }