vulkan: multi-row k quants #10846
Merged
Changes from all commits (6 commits, all by netrunnereve):

c656d92  multi row k quant shaders!
62dc170  merge master
7bbd9cb  better row selection
63c27eb  more row choices
fa70739  readjust row selection
a3aea08  rm_kq=2 by default
@@ -6,21 +6,15 @@
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+layout (constant_id = 1) const uint NUM_ROWS = 1;
 
-shared FLOAT_TYPE tmp[BLOCK_SIZE];
+shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 
-void main() {
-    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
-
-    if (row >= p.stride_d) {
-        return;
-    }
-
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
     uint a_offset, b_offset, d_offset;
     get_offsets(a_offset, b_offset, d_offset);
 
     const uint num_blocks_per_row = p.ncols / QUANT_K;
-    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
     // 16 threads are used to process each block
     const uint it_size = gl_WorkGroupSize.x/16;
@@ -38,15 +32,15 @@ void main() {
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
-    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
+    FLOAT_TYPE temp[NUM_ROWS];
+
+    [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+        temp[i] = FLOAT_TYPE(0);
+    }
 
     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
         const uint y_idx = i * QUANT_K + y_offset;
 
-        f16vec2 d = data_a[ib0 + i].d;
-        const FLOAT_TYPE dall = d.x;
-        const FLOAT_TYPE dmin = d.y;
-
         B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
         B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
         B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
@@ -56,58 +50,84 @@ void main() {
         B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
         B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
 
-        uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
-        uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
-
-        uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
-        uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
-        uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
-        uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
-
-        uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
-        uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
-        uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
-        uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
-
-        uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
-        uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
-        uvec2 qs0 = uvec2(unpack8(qs0_u16));
-        uvec2 qs16 = uvec2(unpack8(qs16_u16));
-
-        FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
-        FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        [[unroll]] for (int l = 0; l < 2; ++l) {
-            sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
-                   fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
-                   fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
-                   fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
-                   fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
-                   fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
-                   fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
-            sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
-                   fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
-                   fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
-                   fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
-                   fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
-                   fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
-                   fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
-                   fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+            f16vec2 d = data_a[ib0 + i].d;
+            const FLOAT_TYPE dall = d.x;
+            const FLOAT_TYPE dmin = d.y;
+
+            uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
+            uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
+
+            uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
+            uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
+            uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
+            uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;
+
+            uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
+            uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
+            uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
+            uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));
+
+            uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
+            uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
+            uvec2 qs0 = uvec2(unpack8(qs0_u16));
+            uvec2 qs16 = uvec2(unpack8(qs16_u16));
+
+            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
+            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
+            [[unroll]] for (int l = 0; l < 2; ++l) {
+                sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
+                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
+                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
+                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
+                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
+                       fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
+                       fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
+                       fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
+                       fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
+                       fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
+                       fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
+                       fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
+            }
+            temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
         }
-        temp = fma(dall, sum1, fma(-dmin, sum2, temp));
     }
 
-    tmp[gl_LocalInvocationID.x] = temp;
-
     // sum up partial sums and write back result
+    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+        tmpsh[n][tid] = temp[n];
+    }
     barrier();
-    [[unroll]] for (uint s = gl_WorkGroupSize.x/2; s > 0; s >>= 1) {
+    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
         if (tid < s) {
-            tmp[tid] += tmp[tid + s];
+            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+                tmpsh[n][tid] += tmpsh[n][tid + s];
+            }
         }
         barrier();
     }
     if (tid == 0) {
-        data_d[d_offset + row] = D_TYPE(tmp[0]);
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
+        }
     }
 }
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    // do NUM_ROWS at a time, unless there aren't enough remaining rows
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
+    }
+}
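A note on the new dispatch logic, derived from the main() shown above: each workgroup now covers NUM_ROWS consecutive output rows, so only the last workgroup can land on a partial slice. With NUM_ROWS = 2 (which the final commit's rm_kq=2 default appears to select) and p.stride_d = 101 rows, workgroups 0 through 49 take the fast branch and compute two full rows each, workgroup 50 falls through to compute_outputs(100, 1) for the lone remaining row, and any workgroup starting at or beyond p.stride_d returns immediately.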
Review discussion

An inline review comment was left on the new f16vec2 d = data_a[ib0 + i].d; line: "This needs to be"
@netrunnereve I completely missed this, and because we included all of the arithmetic_type extensions at the top, so did glslc: you can't use float16 variables anywhere unless the device supports them. I only found this now because I saw validation issues about use of the float16 extension on a device that does not support it. It might be better to include only the arithmetic type extensions that are actually used; then this kind of issue would show up during shader compilation.
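A minimal sketch of that suggestion, assuming the host injects a FLOAT16_SUPPORTED define (an illustrative name, not an existing flag) only after checking VkPhysicalDeviceShaderFloat16Int8Features::shaderFloat16. With the extension omitted on unsupported devices, a stray float16_t or f16vec2 fails at glslc time instead of surfacing as a runtime validation issue:

#if defined(FLOAT16_SUPPORTED)
// Device reports shaderFloat16: safe to require the extension and use
// 16-bit arithmetic types.
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#define FLOAT_TYPE float16_t
#else
// Extension deliberately not included: any accidental use of float16_t
// (or f16vec2, etc.) is now a shader-compilation error.
#define FLOAT_TYPE float
#endif

If FLOAT_TYPE is already injected as a compile-time define by the shader generator, the same switch could live there instead of in the shader source.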
Sure, let me fix this as part of #11081.
Never mind, you're already handling this in #11161.