
vulkan: dynamic subgroup size for the remaining k quants #10745

Merged
merged 2 commits into ggerganov:master from netrunnereve:vulkan2 on Dec 10, 2024

Conversation

netrunnereve (Collaborator)

The remaining K-quants now support variable subgroup sizes, with each superblock being calculated by 16 threads. See #10536 for my original Q6_K implementation.
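For illustration, here's a minimal GLSL sketch of the thread mapping (not the actual shader from this PR: the plain-float buffers and the clustered-add reduction are stand-ins for the example, since the real shaders dequantize packed K-quant blocks and choose their reduction per device):

```glsl
#version 450
#extension GL_KHR_shader_subgroup_clustered : enable

// Workgroup size arrives as a specialization constant so the host can
// pick it to match the device's reported subgroup size.
layout(local_size_x_id = 0) in;

// Illustrative bindings: the real shader reads packed K-quant blocks,
// not plain floats.
layout(std430, binding = 0) readonly buffer A { float a[]; };
layout(std430, binding = 1) readonly buffer B { float b[]; };
layout(std430, binding = 2) writeonly buffer D { float d[]; };

const uint QK_K = 256;            // values per K-quant superblock
const uint BLOCK_THREADS = 16;    // threads cooperating on one superblock
const uint VALS_PER_THREAD = QK_K / BLOCK_THREADS;

void main() {
    const uint lane = gl_LocalInvocationID.x % BLOCK_THREADS;
    const uint blk  = gl_GlobalInvocationID.x / BLOCK_THREADS;

    float sum = 0.0;
    for (uint i = 0; i < VALS_PER_THREAD; ++i) {
        const uint idx = blk * QK_K + lane * VALS_PER_THREAD + i;
        sum += a[idx] * b[idx];   // real shader dequantizes a[] here
    }

    // Sum the 16 partial results. A clustered add stays correct whether
    // the hardware subgroup is 16, 32, or 64 lanes wide.
    sum = subgroupClusteredAdd(sum, BLOCK_THREADS);
    if (lane == 0) {
        d[blk] = sum;
    }
}
```

The host queries the subgroup size once at pipeline creation and feeds it in through the specialization constant, which is the general idea that lets the same SPIR-V run well across 16/32/64-lane hardware.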

All tests were done on my RX 470.

PR:

  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   2556 runs -   395.42 us/run - 117.44 MFLOP/run - 297.00 GFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   2556 runs -   533.33 us/run - 117.44 MFLOP/run - 220.20 GFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   3408 runs -   381.58 us/run - 117.44 MFLOP/run - 307.78 GFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   2556 runs -   479.21 us/run - 117.44 MFLOP/run - 245.07 GFLOPS
| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 8B Q2_K - Medium | 2.95 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 120.50 ± 0.18 |
| llama 8B Q2_K - Medium | 2.95 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 17.66 ± 0.07 |
| llama 8B Q3_K - Medium | 3.74 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 112.05 ± 0.17 |
| llama 8B Q3_K - Medium | 3.74 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 15.74 ± 0.02 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 127.11 ± 0.54 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 20.58 ± 0.00 |
| llama 8B Q5_K - Small | 5.21 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 118.19 ± 0.30 |
| llama 8B Q5_K - Small | 5.21 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 16.04 ± 0.01 |

Master:

  MUL_MAT(type_a=q2_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   1704 runs -   721.46 us/run - 117.44 MFLOP/run - 162.78 GFLOPS
  MUL_MAT(type_a=q3_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   1704 runs -   976.82 us/run - 117.44 MFLOP/run - 120.23 GFLOPS
  MUL_MAT(type_a=q4_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   2556 runs -   515.76 us/run - 117.44 MFLOP/run - 227.70 GFLOPS
  MUL_MAT(type_a=q5_K,type_b=f32,m=4096,n=1,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3]):                   1704 runs -   876.81 us/run - 117.44 MFLOP/run - 133.94 GFLOPS
| model | size | params | backend | ngl | threads | test | t/s |
| --- | --- | --- | --- | --- | --- | --- | --- |
| llama 8B Q2_K - Medium | 2.95 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 120.22 ± 0.21 |
| llama 8B Q2_K - Medium | 2.95 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 9.57 ± 0.01 |
| llama 8B Q3_K - Medium | 3.74 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 112.21 ± 0.02 |
| llama 8B Q3_K - Medium | 3.74 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 9.39 ± 0.00 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 127.43 ± 0.28 |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 14.12 ± 0.00 |
| llama 8B Q5_K - Small | 5.21 GiB | 8.03 B | Vulkan | 100 | 8 | pp512 | 118.34 ± 0.38 |
| llama 8B Q5_K - Small | 5.21 GiB | 8.03 B | Vulkan | 100 | 8 | tg128 | 8.97 ± 0.00 |

I also tried to make Q6_K process multiple rows per subgroup like how it's done in mul_mat_vec, but it actually made inference slightly slower. I've left those changes in c2aa654 if anyone wants to play with it.
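For anyone curious, the multi-row variant amounts to a mapping like this hypothetical sketch (not the code in c2aa654; NUM_ROWS, the fixed workgroup size, and the output buffer are made up for illustration):

```glsl
#version 450

layout(local_size_x = 64) in;      // e.g. one 64-lane subgroup per workgroup

const uint BLOCK_THREADS = 16;     // threads per superblock, as above
const uint NUM_ROWS = 4;           // 64 / 16: four matrix rows per subgroup

layout(std430, binding = 0) writeonly buffer D { float d[]; };

void main() {
    const uint tid  = gl_LocalInvocationID.x;
    const uint lane = tid % BLOCK_THREADS;      // lane within a row's group
    const uint row  = gl_WorkGroupID.x * NUM_ROWS + tid / BLOCK_THREADS;

    // ... each 16-thread group walks its own row's superblocks and reduces,
    // exactly as in the single-row case ...
    if (lane == 0) {
        d[row] = 0.0;   // placeholder for that row's reduced dot product
    }
}
```

Fewer workgroups get launched per row this way, so on a GPU that is already well occupied it can come out slower, which would be consistent with the slowdown observed here.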

netrunnereve requested a review from 0cc4m on December 10, 2024 03:08
github-actions bot added the Vulkan (issues specific to the Vulkan backend) and ggml (changes relating to the ggml tensor library for machine learning) labels on Dec 10, 2024
jeffbolznv (Collaborator)

The changes look good to me, but I haven't had a chance to test them yet. I'll try to do it today.

0cc4m (Collaborator) left a comment

Looks good, and I also see a significant improvement in tg on the Radeon Pro VII.

jeffbolznv (Collaborator)

Looks perf neutral on NVIDIA, as expected.

0cc4m merged commit dafae66 into ggerganov:master on Dec 10, 2024
47 checks passed
netrunnereve deleted the vulkan2 branch on December 11, 2024 01:47
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Dec 20, 2024
vulkan: dynamic subgroup size for the remaining k quants (#10745)

* q5_k

q4_k

q3_k

q2_k

q6_k multi row example

* revert as multi row isnt faster for k quants