Skip to content

Commit

Permalink
feat: enable ggml-rpc
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 2, 2024
1 parent 05b13a4 commit d6d1e8f
Show file tree
Hide file tree
Showing 12 changed files with 1,476 additions and 9 deletions.
4 changes: 3 additions & 1 deletion android/src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ set(
${RNLLAMA_LIB_DIR}/ggml-alloc.c
${RNLLAMA_LIB_DIR}/ggml-backend.cpp
${RNLLAMA_LIB_DIR}/ggml.c
${RNLLAMA_LIB_DIR}/ggml-rpc.cpp
${RNLLAMA_LIB_DIR}/ggml-quants.c
${RNLLAMA_LIB_DIR}/common.cpp
${RNLLAMA_LIB_DIR}/json.hpp
Expand Down Expand Up @@ -55,6 +56,7 @@ function(build_library target_name cpu_flags)
target_compile_options(${target_name} PRIVATE -O3 -DNDEBUG)
target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden)
target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections)
target_compile_options(${target_name} PRIVATE -DLM_GGML_USE_RPC)

target_link_options(${target_name} PRIVATE -Wl,--gc-sections)
target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL)
Expand All @@ -77,7 +79,7 @@ if (${ANDROID_ABI} STREQUAL "arm64-v8a")

# https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md#cross-compile-using-android-ndk
# llama.cpp will deal with the cpu features
# build_library("rnllama_v8_7" "-march=armv8.7-a")
# TODO: Add support runtime check for cpu features
# At the moment runtime check is failing.

Expand Down
9 changes: 6 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
// float rope_freq_base,
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
// float rope_freq_scale
params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f
// float rope_freq_scale,
params.hasKey("rope_freq_scale") ? (float) params.getDouble("rope_freq_scale") : 0.0f,
// String rpc_servers
params.hasKey("rpc_servers") ? String.join(",", params.getArray("rpc_servers").toArrayList()) : ""
);
this.modelDetails = loadModelDetails(this.context);
this.reactContext = reactContext;
Expand Down Expand Up @@ -346,7 +348,8 @@ protected static native long initContext(
String lora,
float lora_scaled,
float rope_freq_base,
float rope_freq_scale
float rope_freq_scale,
String rpc_servers
);
protected static native WritableMap loadModelDetails(
long contextPtr
Expand Down
9 changes: 8 additions & 1 deletion android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ Java_com_rnllama_LlamaContext_initContext(
jstring lora_str,
jfloat lora_scaled,
jfloat rope_freq_base,
jfloat rope_freq_scale
jfloat rope_freq_scale,
jstring rpc_servers
) {
UNUSED(thiz);

Expand Down Expand Up @@ -189,6 +190,11 @@ Java_com_rnllama_LlamaContext_initContext(
defaultParams.rope_freq_base = rope_freq_base;
defaultParams.rope_freq_scale = rope_freq_scale;

const char *rpc_servers_chars = env->GetStringUTFChars(rpc_servers, nullptr);
if (rpc_servers_chars != nullptr && rpc_servers_chars[0] != '\0') {
defaultParams.rpc_servers = rpc_servers_chars;
}

auto llama = new rnllama::llama_rn_context();
bool is_model_loaded = llama->loadModel(defaultParams);

Expand All @@ -201,6 +207,7 @@ Java_com_rnllama_LlamaContext_initContext(

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(lora_str, lora_chars);
env->ReleaseStringUTFChars(rpc_servers, rpc_servers_chars);

return reinterpret_cast<jlong>(llama->ctx);
}
Expand Down
Loading

0 comments on commit d6d1e8f

Please sign in to comment.