Skip to content

Commit

Permalink
ggml : fix llamafile sgemm wdata offsets (ggerganov#6710)
Browse files (browse the repository at this point in the history)
ggml-ci
  • Loading branch information
Georgi Gerganov authored Apr 16, 2024
1 parent 8cc91dc commit 666867b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ endif()
# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_BLAS "llama: use BLAS" OFF)
option(LLAMA_LLAMAFILE "llama: use llamafile SGEMM" ON)
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUDA "llama: use CUDA" OFF)
option(LLAMA_CUBLAS "llama: use CUDA (deprecated, use LLAMA_CUDA)" OFF)
Expand Down Expand Up @@ -286,6 +287,7 @@ if (LLAMA_METAL)
${METALKIT_FRAMEWORK}
)
endif()

if (LLAMA_BLAS)
if (LLAMA_STATIC)
set(BLA_STATIC ON)
Expand Down Expand Up @@ -368,6 +370,10 @@ if (LLAMA_BLAS)
endif()
endif()

# When the LLAMA_LLAMAFILE option is enabled, add GGML_USE_LLAMAFILE as a
# compile definition. This block does not define the macro (e.g. to 0) when
# the option is off; it is simply left unset here.
if (LLAMA_LLAMAFILE)
add_compile_definitions(GGML_USE_LLAMAFILE)
endif()

# When LLAMA_QKK_64 is enabled, add GGML_QKK_64 as a compile definition.
if (LLAMA_QKK_64)
add_compile_definitions(GGML_QKK_64)
endif()
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ endif # LLAMA_DISABLE_LOGS
# Control ggml.c's use of sgemm.cpp through the GGML_USE_LLAMAFILE macro:
# defining LLAMA_NO_LLAMAFILE passes GGML_USE_LLAMAFILE=0 (disabled);
# otherwise GGML_USE_LLAMAFILE=1 is passed explicitly (enabled).
ifdef LLAMA_NO_LLAMAFILE
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=0
else
MK_CPPFLAGS += -DGGML_USE_LLAMAFILE=1
endif

# warnings
Expand Down
11 changes: 4 additions & 7 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@
#include <unistd.h>
#endif

#ifndef GGML_USE_LLAMAFILE
#ifdef __ARM_FEATURE_MATMUL_INT8
#define GGML_USE_LLAMAFILE 0
#else
#define GGML_USE_LLAMAFILE 1
#endif
#undef GGML_USE_LLAMAFILE
#endif

#if defined(_MSC_VER)
Expand Down Expand Up @@ -10879,8 +10875,9 @@ UseGgmlGemm1:;
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type),
(const char *)wdata + (nb12/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i12 +
nb13/ggml_type_size(src1->type)*ggml_type_size(vec_dot_type)*i13),
(const char *)wdata + ggml_row_size(vec_dot_type,
nb12/ggml_type_size(src1->type)*i12 +
nb13/ggml_type_size(src1->type)*i13),
row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type),
Expand Down

0 comments on commit 666867b

Please sign in to comment.