add FP8 support to gguf/llama #10055

Draft · Djip007 wants to merge 3 commits into master from feature/fp8

Conversation

@Djip007 (Contributor) commented Oct 26, 2024

After many tests with FP8, this is the first part of adding FP8 support to llama.cpp.

I made some choices. I implement support for 4 FP8 formats:

  • E5M2 & E4M3: for use with FP8-distributed models.
    These are not meant for creating quantized models from FP16/BF16, but for models distributed in FP8, like Llama-3.1-405B-FP8, so I only add the elements needed to run them.

  • E4M3_Q & E3M4_Q: for gguf-quantized models.
    With these we can quantize models; I tested with Mistral-Nemo, Mistral-7B and Llama-3.1-8B.
    The weights use a simple quantization process like Q8_0, but with a block size of 256 (Q8_0 uses blocks of 32); a sketch of this block layout follows the list.
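
A minimal sketch of what such a block-quantized FP8 layout could look like (illustration only, not the PR's actual struct or function names; the F32 per-block scale and the helper name float_to_e4m3 are assumptions):

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical E4M3_Q block: 256 FP8 weights sharing one scale,
// analogous to Q8_0 but with a larger block.
#define QK_FP8 256

struct block_e4m3_q {
    float   d;            // per-block scale (assumed stored as F32)
    uint8_t qs[QK_FP8];   // 256 E4M3-encoded weights
};

// FP32 -> E4M3 conversion, assumed to be provided elsewhere
// (the PR keeps this kind of helper in ggml-fp8.cpp).
uint8_t float_to_e4m3(float x);

// Quantize one block: pick the scale so the largest magnitude maps
// near the E4M3 maximum finite value (448).
void quantize_block_e4m3_q(const float * src, block_e4m3_q * dst) {
    float amax = 0.0f;
    for (int i = 0; i < QK_FP8; ++i) {
        amax = std::fmax(amax, std::fabs(src[i]));
    }
    const float d  = amax / 448.0f;
    const float id = d != 0.0f ? 1.0f/d : 0.0f;
    dst->d = d;
    for (int i = 0; i < QK_FP8; ++i) {
        dst->qs[i] = float_to_e4m3(src[i] * id);
    }
}
```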

For now everything is in place for the CPU backend as a reference for any arch, so it only uses pure C++. I use openmp-simd for a little speed-up.

I add 2 files, gguf-fp8.h and gguf-fp8.cpp, to keep changes to the existing sources low and make merging simple.
For now I updated the Makefile build (I still need to figure out what to do with CMake).

I prefer to use C++ templates for FP8 so I can support the 4 formats with little code.

So for now any comment is welcome.
The next step is to make a true merge request, before adding faster compute:

  • add FP8 type
  • add FP8 matmul fast OP
  • ...

For those curious, I have a high-speed version that I use for simple tests here for zen4, and there is more discussion here.

@github-actions bot added the script (Script related), examples, and ggml (changes relating to the ggml tensor library for machine learning) labels Oct 26, 2024
@Djip007 (Contributor, Author) commented Oct 26, 2024

Well... I have many things to correct!

@github-actions bot added the build (Compilation issues) label Oct 26, 2024
@Djip007 force-pushed the feature/fp8 branch 6 times, most recently from 0723ac5 to 08b6344 on October 26, 2024 06:57
@github-actions bot added the testing (Everything test related) label Oct 26, 2024
@Djip007 force-pushed the feature/fp8 branch 6 times, most recently from 60ece70 to 3faf670 on October 27, 2024 00:48
@Djip007 (Contributor, Author) commented Oct 27, 2024

OK, the build is now in good shape.
What do you think?

@ggerganov (Owner) left a comment

The native FP8 types are interesting since AFAIK there is hardware support for these.

What are the advantages of the FP8_Q types, for example compared to the existing Q8_0 type?

Makefile Outdated
ggml/src/ggml-fp8.cpp \
ggml/src/ggml-fp8.h \
ggml/src/ggml-common.h
$(CXX) $(CXXFLAGS) -std=c++17 -c $< -o $@
@ggerganov (Owner):

Should target -std=c++11

@Djip007 (Contributor, Author):

I use "static constexpr" in this file that is not allowed with c++11 (and may be more).
And because it is the default with gcc, I develop/test the FP8 template with it.
I see you have it use for GGML_SYCL build to.

So I can try to rewrite it with c++11 before merge request, but is it really needed? Do you have target that do not support it?

Collaborator:

Yes, this is not a good reason to change the project standard.

@Djip007 (Contributor, Author):

You have it for the GGML_SYCL build; is there a reason to keep that old standard, which was the first draft of modern C++?

I'll have a try at rewriting it with only C++11 and see how it looks.

Collaborator:

Backends are for the most part self-contained and don't affect the rest of the project. What you are trying to merge here would be part of the core ggml and would require changing the standard for the entire project. I would also prefer to use C++17, but that is a discussion for another time.

// - fp8 simple type
typedef struct { uint8_t bits; } ggml_e5m2_t;
typedef struct { uint8_t bits; } ggml_e4m3_t;
typedef struct { uint8_t bits; } ggml_e3m4_t;
@ggerganov (Owner):

Seems unused.

@Djip007 (Contributor, Author):

Yes, my bad. I will remove it in the next commit.

ggml/src/ggml-common.h (resolved)
Comment on lines 36 to 13
template<int N> constexpr float EXP2() {
if constexpr (N==0) return 1;
if constexpr (N>0) return EXP2<N-1>()*2;
if constexpr (N<0) return EXP2<N+1>()/2;
}
Collaborator:

C++11:

Suggested change
template<int N> constexpr float EXP2() {
if constexpr (N==0) return 1;
if constexpr (N>0) return EXP2<N-1>()*2;
if constexpr (N<0) return EXP2<N+1>()/2;
}
constexpr float exp2(int n) {
return n < 0 ? 1.0f / (1 << -n) : 1 << n;
}

@Djip007 (Contributor, Author) commented Oct 28, 2024:

That does not work for n above 32 (or 64)... we need it up to 127 ;)

Collaborator:

constexpr float exp2(int n) {
    return n == 0 ? 1 : n > 0 ? 2 * exp2(n - 1) : 0.5 * exp2(n + 1);
}
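
As a side note (my illustration, not part of the review): the shift-based suggestion breaks because `1 << n` is undefined for n ≥ 31 with a 32-bit int, while the recursive float form stays exact for every power of two an FP8/FP32 exponent needs. The function is renamed here to avoid clashing with std::exp2:

```cpp
// Single-return constexpr, so this is valid C++11.
constexpr float exp2i(int n) {
    return n == 0 ? 1.0f : n > 0 ? 2.0f * exp2i(n - 1) : 0.5f * exp2i(n + 1);
}

static_assert(exp2i(10)  == 1024.0f, "2^10 is exact");
static_assert(exp2i(-3)  == 0.125f,  "negative exponents work too");
static_assert(exp2i(127) > 1e38f,    "still finite at the top of the FP32 exponent range");
```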

using type = FP8<_E>;
static constexpr int E=_E;
static constexpr int M=7-_E;
static constexpr int E_BIAS=EXP2<_E-1>()-1;
@Djip007 (Contributor, Author):

The part that doesn't build as C++11 is this "static constexpr".

Collaborator:

static const int would do the same here.

@Djip007 (Contributor, Author) commented Oct 28, 2024

What are the advantages of the FP8_Q types, for example compared to the existing Q8_0 type?

Some possible advantages:

  • Q8_0 uses blocks of 32; FP8_Q can use blocks of 256, so we get a bpw of 8.125 vs 8.5 (see the breakdown after this list).
  • with a CPU like zen4 we can use the BF16 dot2 instruction for better speed than we get with Q8_0 on CPU (?)
  • on supported hardware we can use FP8 tensor cores and only need to apply the scale after 256 weights vs 32 for Q8_0,
  • ...
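
For reference, the bits-per-weight arithmetic behind the first bullet (Q8_0 stores an F16 scale per 32-weight block; the 8.125 figure implies 32 bits of scale overhead per 256-weight block, which is an inference from that number, not something stated in the PR):

$$
\text{Q8\_0}: \frac{32 \times 8 + 16}{32} = 8.5\ \text{bpw}
\qquad
\text{FP8\_Q}: \frac{256 \times 8 + 32}{256} = 8.125\ \text{bpw}
$$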

Some advantages over plain FP8:

  • to use native FP8 tensor cores with FP8_Q, we only need the max over a 256-element block to convert B to FP8, whereas with "simple" FP8 we need the max over the full row
  • in my experiments FP8_Q gives better accuracy than plain FP8, so it can be used with more models

Note: In fact I also get good accuracy with blocks of 512 or 1024, but I prefer to keep the same block size that is used for the other Qn_K quants.

@Djip007 force-pushed the feature/fp8 branch 4 times, most recently from 58863d0 to 4e81ab0 on October 29, 2024 00:54
@Djip007 (Contributor, Author) commented Oct 29, 2024

OK, I had a hard time with the macOS compiler...
I think I finally removed the need for C++17 (and did a little clean-up) 🤞

thanks @slaren @ggerganov

@ggerganov (Owner):

Got it. So the native FP8 types are useful for models like https://huggingface.co/meta-llama/Llama-3.1-405B-FP8 and https://huggingface.co/neuralmagic/Mistral-Nemo-Instruct-2407-FP8 which are already trained in this format, while the FP8_Q types can be used to quantize BF16 models into a data format that can utilize the NVIDIA FP8 tensor cores.

@Djip007 (Contributor, Author) commented Nov 29, 2024

Rebased on the latest master, with a small change to the FP32->FP8 rounding (added round-to-nearest-even).
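
For readers unfamiliar with round-to-nearest-even in this context, here is a minimal illustrative FP32 → E4M3 conversion (bias 7, max finite value 448). It is my sketch, not the PR's code, and it omits NaN propagation and FP8 subnormal outputs:

```cpp
#include <cstdint>
#include <cstring>

// Illustrative FP32 -> E4M3 conversion with round-to-nearest-even.
static uint8_t f32_to_e4m3_rne(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));

    const uint32_t sign = bits >> 31;
    int      exp  = (int)((bits >> 23) & 0xFF) - 127; // unbiased FP32 exponent
    uint32_t mant = bits & 0x7FFFFF;                  // 23-bit FP32 mantissa

    if ((bits & 0x7FFFFFFF) == 0) {
        return (uint8_t)(sign << 7);                  // +/- 0
    }

    // Keep the top 3 mantissa bits; round the discarded 20 bits half-to-even.
    uint32_t keep = mant >> 20;
    uint32_t rest = mant & 0xFFFFF;
    const uint32_t half = 1u << 19;
    if (rest > half || (rest == half && (keep & 1))) {
        if (++keep == 8) { keep = 0; exp += 1; }      // mantissa carry into the exponent
    }

    int e8 = exp + 7;                                 // E4M3 bias is 7
    if (e8 <= 0) {
        return (uint8_t)(sign << 7);                  // too small: flush to zero (no subnormals here)
    }
    if (e8 > 15 || (e8 == 15 && keep == 7)) {
        return (uint8_t)((sign << 7) | 0x7E);         // saturate to 448 (0x7F would be NaN)
    }
    return (uint8_t)((sign << 7) | ((uint32_t)e8 << 3) | keep);
}
```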

@slaren (Collaborator) left a comment

Only the conversion code should be in ggml base, the vec dot code needs to be built in the CPU backend since that's the only code optimized for the CPU arch.

I think this could be useful to support base models released in FP8, and possibly for faster inference on ADA GPUs that support FP8, but I am not convinced that it is worth adding the new quant types based on FP8. From what I can tell, they perform worse than Q8_0, so there is no clear use case for them.

I am also not sure that we should add the OpenMP SIMD dependency. I don't really understand how it is supposed to work. Does the compiler really produce better code when it is enabled?

Comment on lines +1849 to +1801
LOG("\n");
LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n");
Collaborator:

Is this change necessary?

@Djip007 (Contributor, Author):

Not for FP8, but in your case the header is written at every step, so it adds many lines to the output.
But maybe this is better done in another PR.

@@ -426,6 +427,27 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
// les FP8...
Collaborator:

Suggested change
// les FP8...

ggml/src/ggml-fp8.cpp (outdated, resolved)
ggml/src/ggml-fp8.cpp (outdated, resolved)
ggml/src/ggml-fp8.cpp (outdated, resolved)
ggml/src/ggml-fp8.cpp (outdated, resolved)
ggml/src/ggml.c (outdated, resolved)
@@ -26,7 +26,7 @@ function has_cmd {
}

if has_cmd wget; then
cmd="wget -q --show-progress -c -O %s/%s %s"
cmd="wget -q -c -O %s/%s %s"
Collaborator:

Is it necessary to change this?

Makefile (outdated, resolved)
check_cxx_compiler_flag("-fopenmp-simd" SUPPORTS_OPENMP_SIMD)
if (SUPPORTS_OPENMP_SIMD)
# OpenMP_RUNTIME_MSVC=experimental / if (MSVC)
message(STATUS "Using openmp_simd.")
Collaborator:

Suggested change
message(STATUS "Using openmp_simd.")
message(STATUS "Using OPENMP_SIMD.")

@Djip007 (Contributor, Author) commented Nov 29, 2024

I am also not sure that we should add the OpenMP SIMD dependency. I don't really understand how it is supposed to work. Does the compiler really produce better code when it is enabled?

# without OpenMP SIMD
llama_perf_context_print: prompt eval time =   38115,07 ms /    33 tokens ( 1155,00 ms per token,     0,87 tokens per second)
llama_perf_context_print:        eval time =   19921,22 ms /    15 runs   ( 1328,08 ms per token,     0,75 tokens per second)

# with OpenMP SIMD
llama_perf_context_print: prompt eval time =   12686,11 ms /    33 tokens (  384,43 ms per token,     2,60 tokens per second)
llama_perf_context_print:        eval time =    6604,80 ms /    15 runs   (  440,32 ms per token,     2,27 tokens per second)

So yes, it helps. But it is true that we can do much better than that with more optimised intrinsic code, so as you like.
I will add a high-speed CPU matmul OP after this PR is merged.
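
For context on how -fopenmp-simd helps here: it only enables the `#pragma omp simd` loop hints (no OpenMP runtime, no threading), which lets the compiler vectorize loops it would otherwise be reluctant to. A generic illustration, not the PR's kernel:

```cpp
// Compile with e.g.: g++ -O3 -fopenmp-simd vec_dot.cpp
// The pragma is a vectorization hint only; -fopenmp-simd does not link the OpenMP runtime.
float vec_dot_f32(int n, const float * a, const float * b) {
    float sum = 0.0f;
    #pragma omp simd reduction(+:sum)
    for (int i = 0; i < n; ++i) {
        sum += a[i] * b[i];
    }
    return sum;
}
```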

@slaren (Collaborator) commented Nov 29, 2024

So yes, it helps. But it is true that we can do much better than that with more optimised intrinsic code, so as you like.
I will add a high-speed CPU matmul OP after this PR is merged.

We can keep it until it is replaced with intrinsics then. It should still be limited to only the CPU backend, though.

@Djip007 (Contributor, Author) commented Nov 29, 2024

Only the conversion code should be in ggml base, the vec dot code needs to be built in the CPU backend since that's the only code optimized for the CPU arch.

Good catch, I did it before the CPU code was moved into a backend.
I'll have a look at it and try to split ggml-fp8 into ggml-fp8 and ggml-cpu-fp8.

We can keep it until it is replaced with intrinsics then. It should still be limited to only the CPU backend, though.

And I'll try to move this SIMD code into the CPU backend too.

@Djip007 Djip007 marked this pull request as draft November 29, 2024 13:54
- correct local CI
- correct perplexity log

E5M2 & E4M3: for use with FP8-distributed models
E4M3_Q & E3M4_Q: for gguf-quantized models

E5M2 and E4M3 types are used like native FP16 / BF16.
E4M3_Q and E3M4_Q are defined like Q8_0, but with a block size of 256 (like QK_K).
@rhjdvsgsgks (Contributor):
@0cc4m hi. Would it be possible to save more VRAM by using VK_KHR_8bit_storage once llama.cpp gets FP8 support?

We're already using that extension to support the other quants, but it's only for 8-bit integers, not floats. FP8 is not currently supported by Vulkan at all as far as I can see.

Is there any plan to support FP8 in the Vulkan backend in the future?

I don't see how; the problem is that Vulkan itself doesn't support FP8.

Sorry, I misunderstood what you meant before. Now I get it:
FP8 is not supported by Vulkan itself (not just the Vulkan backend). VK_KHR_8bit_storage provides only int8, not FP8.

Apologies for my nonsense.

@Djip007 force-pushed the feature/fp8 branch 3 times, most recently from e6c4791 to 1488d7d on November 30, 2024 20:33
@Djip007 Djip007 marked this pull request as ready for review November 30, 2024 22:02
@Djip007 (Contributor, Author) commented Nov 30, 2024

OK, rebased on the latest master, and now it is possible to use C++17.

I refactored the FP8 code so the dot product is now in the CPU backend. There is some duplication for now, but it may be removed later with the optimized code.
I have also added an optimised dot for EnMn_Q for some archs (I can't test on all hardware; it works on x86 (AVX2/AVX512)).

# without OpenMP SIMD
llama_perf_context_print: prompt eval time =   38115,07 ms /    33 tokens ( 1155,00 ms per token,     0,87 tokens per second)
llama_perf_context_print:        eval time =   19921,22 ms /    15 runs   ( 1328,08 ms per token,     0,75 tokens per second)

# with OpenMP SIMD
llama_perf_context_print: prompt eval time =   12686,11 ms /    33 tokens (  384,43 ms per token,     2,60 tokens per second)
llama_perf_context_print:        eval time =    6604,80 ms /    15 runs   (  440,32 ms per token,     2,27 tokens per second)

# with AVX512:
llama_perf_context_print: prompt eval time =    4507,33 ms /    33 tokens (  136,59 ms per token,     7,32 tokens per second)
llama_perf_context_print:        eval time =    3207,50 ms /    15 runs   (  213,83 ms per token,     4,68 tokens per second)

# Q8_0:
llama_perf_context_print: prompt eval time =     887,87 ms /    33 tokens (   26,91 ms per token,    37,17 tokens per second)
llama_perf_context_print:        eval time =    3244,96 ms /    15 runs   (  216,33 ms per token,     4,62 tokens per second)

At least the tg is in line with Q8_0 on my CPU.

I hope I managed to take into account all or most of the comments/reviews.

@slaren (Collaborator) commented Dec 1, 2024

  • E5M2 & E4M3: for use with FP8 distributed model
    it is not for create quantized model from FP16/BF16 model, but for FP8 distributed model like Llama-3.1-405B-FP8 so I only add element for run.

As far as I can tell, this conversion is not implemented, so it is not actually possible to use FP8 models.

@slaren (Collaborator) commented Dec 1, 2024

I have taken a look at the available FP8 models, and they use a single F32 scale per weight (edit: it's per row actually). This is something that is not supported in our quantization schemes and would require significant changes to implement. cuBLAS FP8 matrix multiplication also requires a scale per weight.

So as it is, this implementation cannot be used with FP8 models, and the FP8 Q types perform worse than Q8_0. And I don't think that having a different group size than Q8_0 is a good reason to add them either. I am not sure that there are any practical applications for this code as it is at the moment.

I still think that it would be good to have support for FP8 models, but it needs more work to achieve that.

@Djip007 (Contributor, Author) commented Dec 1, 2024

I have taken a look at the available FP8 models, and they use a single F32 scale per weight (edit: it's per row actually). This is something that is not supported in our quantization schemes and would require significant changes to implement.

For this case we don't need to add it to the quantization schemes; we can add it at model load ("just" add a mul op with it after the matmul, like with a bias)?
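
A minimal sketch of that idea (my illustration, not code from this PR; it assumes the per-row F32 scale is loaded as its own tensor and relies on ggml_mul's broadcasting, the same mechanism used for bias-like element-wise ops):

```cpp
#include "ggml.h"

// w_fp8:  FP8 weight tensor, logically [n_in, n_out]
// scale:  F32 per-row scale,           [n_out]
// x:      activations,                 [n_in, n_tokens]
static struct ggml_tensor * fp8_mul_mat_scaled(
        struct ggml_context * ctx,
        struct ggml_tensor  * w_fp8,
        struct ggml_tensor  * scale,
        struct ggml_tensor  * x) {
    // y has shape [n_out, n_tokens]
    struct ggml_tensor * y = ggml_mul_mat(ctx, w_fp8, x);
    // broadcast-multiply each output row by its scale, like applying a bias
    return ggml_mul(ctx, y, scale);
}
```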

So as it is, this implementation cannot be used with FP8 models, and the FP8 Q types perform worse than Q8_0. And I don't think that having a different group size than Q8_0 is a good reason to add them either. I am not sure that there are any practical applications for this code as it is at the moment.

FP8_Q is perfectly usable and gives good results.
Here is what I get with Mistral-Nemo:

quant (vs BF16)    PPL         KLD         top p
FP16               6.339475    0.000002    99.941
Q8_0               6.345049    0.000854    98.503
E3M4_Q             6.345332    0.001324    98.045
E3M4               6.346416    0.001562    97.923
E4M3_Q             6.361256    0.004980    96.182
Q6_K               6.373544    0.006197    95.778
E4M3               6.374483    0.006227    95.598
Q5_K_M             6.397938    0.010158    94.834
Q5_K_S             6.437732    0.013509    94.231
E5M2               6.493300    0.022102    92.156
Q4_K_M             6.489127    0.026584    92.176

(Yes, as you can see, NVIDIA/Intel/AMD were wrong not to add E3M4 in hardware.)

A high-speed FP8 CPU kernel is WIP, but it needs PR #10446 for a clean/simple integration.

Do you want to wait for this kernel before adding FP8?

@slaren (Collaborator) commented Dec 1, 2024

For this case we don't need to add it to the quantization schemes; we can add it at model load ("just" add a mul op with it after the matmul, like with a bias)?

It would be ok for a proof of concept, but it would effectively require adding a special case to every matrix multiplication to support FP8. The best way would be to add support to ggml for quantization types that use one group/scale per row.

FP8_Q is perfectly usable and gives good results.

That's not good enough, it's still worse quality than Q8_0, and much slower. As it is, I don't see any motivation to use these types. It needs to offer a meaningful improvement in some aspect to justify the maintenance cost of adding all this code, especially since it would be adding new file types.

I suppose one use for the FP8_Q types would be to use them with FP8 models by using the same scale for every group in the row, until we are able to support quantization types with one group per row in ggml. That would still require adding support to perform this conversion in convert_hf_to_gguf.py.

@ggerganov (Owner):

I agree with @slaren's evaluation. We have to see more gains from these data types before we merge them.

@Djip007 (Contributor, Author) commented Dec 1, 2024

I agree with @slaren's evaluation. We have to see more gains from these data types before we merge them.

I see your point.

I have some kernels that use BF16 for the compute, with this speed:

cpu_info model_filename test E4M3 E4M3_Q Q8_0
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp1 4.31 4.14 4.68
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp2 8.62 8.33 9.21
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp3 13.36 13.00 13.76
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp4 17.92 17.64 18.19
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp5 19.49 22.15 12.98
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp6 23.71 23.68 15.28
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp7 27.47 25.34 17.32
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp8 29.00 29.06 33.13
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp9 31.31 31.48 22.30
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp10 33.95 32.31 24.26
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp11 35.44 34.25 26.60
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp12 37.22 35.80 38.70
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp13 39.01 38.21 27.62
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp14 37.97 37.55 29.90
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp15 40.50 38.33 31.68
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp16 41.80 38.38 39.84
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp32 43.93 41.50 41.78
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp64 45.62 43.84 43.77
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp128 45.95 45.95 43.94
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp256 46.58 45.84 43.83
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp512 45.58 44.63 42.14
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 pp1024 44.65 44.09 41.14
AMD Ryzen 9 7940HS (znver4) Mistral-Nemo-Instruct-2407 tg16 4.14 4.10 4.67

But I need some time to get it clean for a PR and to port it to the "cpu-extra-buffer". Do you think that would be enough, if I can get there?

Note: I know how to make it faster, but I need more tests/time for that.

@Djip007 Djip007 marked this pull request as draft December 1, 2024 19:17
@sorasoras commented Jan 3, 2025

Any update on this? DeepSeek-V3 uses FP8 as its training format, so this might be useful.
@Djip007

@cpumaxx (Contributor) commented Jan 20, 2025

With the advent of DeepSeek R1 also being FP8, this is looking even more important.
@Djip007 any thoughts or progress?
