From 3e16ec1e5a760123a8ac1a8fcfae0c6c3d73402b Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Mon, 25 Nov 2024 21:13:36 -0500 Subject: [PATCH] Opt class for positional argument handling Added support for positional arguments `MODEL` and `PROMPT`. Added model path resolution, just `file://` for now. Added functionality to download via strings like: llama-run llama3 llama-run ollama://granite-code llama-run ollama://granite-code:8b llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf llama-run https://example.com/some-file1.gguf llama-run some-file2.gguf llama-run file://some-file3.gguf Signed-off-by: Eric Curtin --- README.md | 15 + common/common.cpp | 6 - common/common.h | 11 +- examples/run/CMakeLists.txt | 13 +- examples/run/run.cpp | 559 ++++++++++++++++++++++++++---------- 5 files changed, 442 insertions(+), 162 deletions(-) diff --git a/README.md b/README.md index 6fdd8d9eefbfb..a3ce4f17c7efb 100644 --- a/README.md +++ b/README.md @@ -433,6 +433,21 @@ To learn more about model quantization, [read this documentation](examples/quant +## [`llama-run`](examples/run) + +#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^1]. + +-
+ Run a model with a specific prompt + + ```bash + llama-run granite-code + > + ``` + +
+ +[^1]: [https://github.com/containers/ramalama](RamaLama) ## [`llama-simple`](examples/simple) diff --git a/common/common.cpp b/common/common.cpp index 6143516d2250f..c44befe9da163 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1108,12 +1108,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p #define CURL_MAX_RETRY 3 #define CURL_RETRY_DELAY_SECONDS 2 - -static bool starts_with(const std::string & str, const std::string & prefix) { - // While we wait for C++20's std::string::starts_with... - return str.rfind(prefix, 0) == 0; -} - static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) { int remaining_attempts = max_attempts; diff --git a/common/common.h b/common/common.h index 95d20401d2a9a..5a49da856dc43 100644 --- a/common/common.h +++ b/common/common.h @@ -37,9 +37,9 @@ using llama_tokens = std::vector; // build info extern int LLAMA_BUILD_NUMBER; -extern char const * LLAMA_COMMIT; -extern char const * LLAMA_COMPILER; -extern char const * LLAMA_BUILD_TARGET; +extern const char * LLAMA_COMMIT; +extern const char * LLAMA_COMPILER; +extern const char * LLAMA_BUILD_TARGET; struct common_control_vector_load_info; @@ -437,6 +437,11 @@ std::vector string_split(const std::string & input, ch return parts; } +static bool starts_with(const std::string & str, + const std::string & prefix) { // While we wait for C++20's std::string::starts_with... + return str.rfind(prefix, 0) == 0; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); diff --git a/examples/run/CMakeLists.txt b/examples/run/CMakeLists.txt index 52add51ef77c3..edb3d501a26b3 100644 --- a/examples/run/CMakeLists.txt +++ b/examples/run/CMakeLists.txt @@ -1,5 +1,16 @@ set(TARGET llama-run) add_executable(${TARGET} run.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) + +# Use curl to download model url +if (LLAMA_CURL) + find_package(CURL REQUIRED) + add_definitions(-DLLAMA_USE_CURL) + include_directories(${CURL_INCLUDE_DIRS}) + find_library(CURL_LIBRARY curl REQUIRED) + set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY}) +endif () + +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_include_directories(${TARGET} PUBLIC ../../common) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index cac2faefcc256..ab660b3239565 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -1,145 +1,405 @@ #if defined(_WIN32) -#include +# include #else -#include +# include #endif -#include +#if defined(LLAMA_USE_CURL) +# include +#endif + +#include #include #include +#include #include #include #include -#include #include +#include "common.h" +#include "json.hpp" #include "llama-cpp.h" -typedef std::unique_ptr char_array_ptr; +static void printe(const char * format, ...) __attribute__((format(printf, 1, 2))); -struct Argument { - std::string flag; - std::string help_text; -}; +static void printe(const char * format, ...) { + va_list args; + va_start(args, format); + vfprintf(stderr, format, args); + va_end(args); +} -struct Options { - std::string model_path, prompt_non_interactive; - int ngl = 99; - int n_ctx = 2048; -}; +class Opt { + public: + int init_opt(int argc, const char ** argv) { + construct_help_str_(); + // Parse arguments + if (parse(argc, argv)) { + printe("Error: Failed to parse arguments.\n"); + help(); + return 1; + } -class ArgumentParser { - public: - ArgumentParser(const char * program_name) : program_name(program_name) {} + // If help is requested, show help and exit + if (help_) { + help(); + return 2; + } - void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") { - string_args[flag] = &var; - arguments.push_back({flag, help_text}); + return 0; // Success } - void add_argument(const std::string & flag, int & var, const std::string & help_text = "") { - int_args[flag] = &var; - arguments.push_back({flag, help_text}); + std::string model_; + std::string user_; + int context_size_ = 2048, ngl_ = 0; + + private: + std::string help_str_; + bool help_ = false; + + void construct_help_str_() { + help_str_ = + "Description:\n" + " Runs a llm\n" + "\n" + "Usage:\n" + " llama-run [options] MODEL [PROMPT]\n" + "\n" + "Options:\n" + " -c, --context-size \n" + " Context size (default: " + + std::to_string(context_size_); + help_str_ += + ")\n" + " -n, --ngl \n" + " Number of GPU layers (default: " + + std::to_string(ngl_); + help_str_ += + ")\n" + " -h, --help\n" + " Show help message\n" + "\n" + "Examples:\n" + " llama-run llama3\n" + " llama-run ollama://granite-code\n" + " llama-run ollama://smollm:135m\n" + " llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" + " llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" + " llama-run https://example.com/some-file1.gguf\n" + " llama-run some-file2.gguf\n" + " llama-run file://some-file3.gguf\n" + " llama-run --ngl 99 some-file4.gguf\n" + " llama-run --ngl 99 some-file5.gguf Hello World\n"; } int parse(int argc, const char ** argv) { + int positional_args_i = 0; for (int i = 1; i < argc; ++i) { - std::string arg = argv[i]; - if (string_args.count(arg)) { - if (i + 1 < argc) { - *string_args[arg] = argv[++i]; - } else { - fprintf(stderr, "error: missing value for %s\n", arg.c_str()); - print_usage(); + if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) { + if (i + 1 >= argc) { return 1; } - } else if (int_args.count(arg)) { - if (i + 1 < argc) { - if (parse_int_arg(argv[++i], *int_args[arg]) != 0) { - fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]); - print_usage(); - return 1; - } - } else { - fprintf(stderr, "error: missing value for %s\n", arg.c_str()); - print_usage(); + + context_size_ = std::atoi(argv[++i]); + } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) { + if (i + 1 >= argc) { return 1; } + + ngl_ = std::atoi(argv[++i]); + } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + help_ = true; + return 0; + } else if (!positional_args_i) { + ++positional_args_i; + model_ = argv[i]; + } else if (positional_args_i == 1) { + ++positional_args_i; + user_ = argv[i]; } else { - fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str()); - print_usage(); - return 1; + user_ += " " + std::string(argv[i]); } } - if (string_args["-m"]->empty()) { - fprintf(stderr, "error: -m is required\n"); - print_usage(); + return model_.empty(); // model_ is the only required value + } + + void help() const { printf("%s", help_str_.c_str()); } +}; + +struct progress_data { + size_t file_size = 0; + bool printed = false; +}; + +struct FileDeleter { + void operator()(FILE * file) const { + if (file) { + fclose(file); + } + } +}; + +typedef std::unique_ptr FILE_ptr; + +class LlamaData { + public: + llama_model_ptr model; + llama_sampler_ptr sampler; + llama_context_ptr context; + std::vector messages; + std::vector msg_strs; + std::vector fmtted; + + int init(Opt & opt) { + model = initialize_model(opt); + if (!model) { return 1; } + context = initialize_context(model, opt.context_size_); + if (!context) { + return 1; + } + + sampler = initialize_sampler(); return 0; } - private: - const char * program_name; - std::unordered_map string_args; - std::unordered_map int_args; - std::vector arguments; + private: + // Function to write data to a file + static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) { + FILE * out = static_cast(stream); + return fwrite(ptr, size, nmemb, out); + } - int parse_int_arg(const char * arg, int & value) { - char * end; - const long val = std::strtol(arg, &end, 10); - if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) { - value = static_cast(val); + // Function to capture data into a string + static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) { + std::string * str = static_cast(stream); + str->append(static_cast(ptr), size * nmemb); + return size * nmemb; + } + +#ifdef LLAMA_USE_CURL + CURL * init_curl() { return curl_easy_init(); } + + void set_write_options(CURL * curl, std::string * response_str, const FILE_ptr & out) { + if (response_str) { + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str); + } else { + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.get()); + } + } + + size_t set_resume_point(CURL * curl, const std::string & output_file) { + size_t file_size = 0; + if (std::filesystem::exists(output_file)) { + file_size = std::filesystem::file_size(output_file); + curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast(file_size)); + } + + return file_size; + } + + // Function to display progress + static int progress_callback(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t, + curl_off_t) { + progress_data * data = static_cast(ptr); + if (total_to_download <= 0) { return 0; } - return 1; + + total_to_download += data->file_size; + now_downloaded += data->file_size; + int percentage = static_cast((now_downloaded * 100) / total_to_download); + printe("\rProgress: %d%% |", percentage); + int pos = (percentage / 5); + for (int i = 0; i < 20; ++i) { + if (i < pos) { + printe("█"); + } else { + printe(" "); + } + } + + printe("| %li/%li bytes", now_downloaded, total_to_download); + fflush(stderr); + data->printed = true; + + return 0; } - void print_usage() const { - printf("\nUsage:\n"); - printf(" %s [OPTIONS]\n\n", program_name); - printf("Options:\n"); - for (const auto & arg : arguments) { - printf(" %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str()); + void set_progress_options(CURL * curl, const bool progress, progress_data & data) { + if (progress) { + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, progress_callback); } + } - printf("\n"); + void set_headers(CURL * curl, const std::vector & headers) { + if (!headers.empty()) { + struct curl_slist * chunk = NULL; + for (const auto & header : headers) { + chunk = curl_slist_append(chunk, header.c_str()); + } + + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk); + } } -}; -class LlamaData { - public: - llama_model_ptr model; - llama_sampler_ptr sampler; - llama_context_ptr context; - std::vector messages; + void perform_curl(CURL * curl, const std::string & url) { + CURLcode res; + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); + curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); - int init(const Options & opt) { - model = initialize_model(opt.model_path, opt.ngl); - if (!model) { + res = curl_easy_perform(curl); + if (res != CURLE_OK) { + printe("curl_easy_perform() failed: %s\n", curl_easy_strerror(res)); + } + } +#endif + +#ifdef LLAMA_USE_CURL + int download(const std::string & url, const std::vector & headers, const std::string & output_file, + const bool progress, std::string * response_str = nullptr) { + CURL * curl = init_curl(); + std::string output_file_partial; + progress_data data; + if (curl) { + FILE_ptr out; + if (!output_file.empty()) { + output_file_partial = output_file + ".partial"; + out.reset(fopen(output_file_partial.c_str(), "ab")); + } + + set_write_options(curl, response_str, out); + data.file_size = set_resume_point(curl, output_file_partial); + set_progress_options(curl, progress, data); + set_headers(curl, headers); + perform_curl(curl, url); + curl_easy_cleanup(curl); + } + + if (!output_file.empty()) { + std::filesystem::rename(output_file_partial, output_file); + } + + if (data.printed) { + printe("\n"); + } + + return 0; + } +#else + int download(const std::string &, const std::vector &, const std::string &, const bool, + std::string * = nullptr) { + printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); + return 1; + } +#endif + + int huggingface_dl(const std::string & model, const std::vector headers, const std::string & bn) { + // Find the second occurrence of '/' after protocol string + size_t pos = model.find('/'); + pos = model.find('/', pos + 1); + if (pos == std::string::npos) { return 1; } - context = initialize_context(model, opt.n_ctx); - if (!context) { + const std::string hfr = model.substr(0, pos); + const std::string hff = model.substr(pos + 1); + const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff; + return download(url, headers, bn, true); + } + + int ollama_dl(std::string & model, const std::vector headers, const std::string & bn) { + if (model.find('/') == std::string::npos) { + model = "library/" + model; + } + + std::string model_tag = "latest"; + size_t colon_pos = model.find(':'); + if (colon_pos != std::string::npos) { + model_tag = model.substr(colon_pos + 1); + model = model.substr(0, colon_pos); + } + + std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag; + std::string manifest_str; + const int ret = download(manifest_url, headers, "", false, &manifest_str); + if (ret) { + return ret; + } + + nlohmann::json manifest = nlohmann::json::parse(manifest_str); + std::string layer; + for (const auto & l : manifest["layers"]) { + if (l["mediaType"] == "application/vnd.ollama.image.model") { + layer = l["digest"]; + break; + } + } + + std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer; + return download(blob_url, headers, bn, true); + } + + std::string basename(const std::string & path) { + size_t pos = path.find_last_of("/\\"); + if (pos == std::string::npos) { + return path; + } + + return path.substr(pos + 1); + } + + int remove_proto(std::string & model_) { + const std::string::size_type pos = model_.find("://"); + if (pos == std::string::npos) { return 1; } - sampler = initialize_sampler(); + model_ = model_.substr(pos + 3); // Skip past "://" return 0; } - private: + int resolve_model(std::string & model_) { + const std::string bn = basename(model_); + const std::vector headers = { "--header", + "Accept: application/vnd.docker.distribution.manifest.v2+json" }; + int ret = 0; + if (starts_with(model_, "hf://") || starts_with(model_, "huggingface://")) { + remove_proto(model_); + ret = huggingface_dl(model_, headers, bn); + } else if (starts_with(model_, "ollama://")) { + remove_proto(model_); + ret = ollama_dl(model_, headers, bn); + } else if (starts_with(model_, "file://") || std::filesystem::exists(model_)) { + remove_proto(model_); + } else { + ret = ollama_dl(model_, headers, bn); + } + + model_ = bn; + + return ret; + } + // Initializes the model and returns a unique pointer to it - llama_model_ptr initialize_model(const std::string & model_path, const int ngl) { + llama_model_ptr initialize_model(Opt & opt) { + ggml_backend_load_all(); llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = ngl; - - llama_model_ptr model(llama_load_model_from_file(model_path.c_str(), model_params)); + model_params.n_gpu_layers = opt.ngl_; + resolve_model(opt.model_); + llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params)); if (!model) { - fprintf(stderr, "%s: error: unable to load model\n", __func__); + printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); } return model; @@ -148,12 +408,11 @@ class LlamaData { // Initializes the context with the specified parameters llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_ctx; - ctx_params.n_batch = n_ctx; - + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = n_ctx; llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params)); if (!context) { - fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__); + printe("%s: error: failed to create the llama_context\n", __func__); } return context; @@ -170,23 +429,22 @@ class LlamaData { } }; -// Add a message to `messages` and store its content in `owned_content` -static void add_message(const char * role, const std::string & text, LlamaData & llama_data, - std::vector & owned_content) { - char_array_ptr content(new char[text.size() + 1]); - std::strcpy(content.get(), text.c_str()); - llama_data.messages.push_back({role, content.get()}); - owned_content.push_back(std::move(content)); +// Add a message to `messages` and store its content in `msg_strs` +static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { + llama_data.msg_strs.push_back(std::move(text)); + llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const LlamaData & llama_data, std::vector & formatted, const bool append) { - int result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), - llama_data.messages.size(), append, formatted.data(), formatted.size()); - if (result > static_cast(formatted.size())) { - formatted.resize(result); +static int apply_chat_template(LlamaData & llama_data, const bool append) { + int result = llama_chat_apply_template( + llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append, + append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); + if (append && result > static_cast(llama_data.fmtted.size())) { + llama_data.fmtted.resize(result); result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), - llama_data.messages.size(), append, formatted.data(), formatted.size()); + llama_data.messages.size(), append, llama_data.fmtted.data(), + llama_data.fmtted.size()); } return result; @@ -199,7 +457,8 @@ static int tokenize_prompt(const llama_model_ptr & model, const std::string & pr prompt_tokens.resize(n_prompt_tokens); if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { - GGML_ABORT("failed to tokenize the prompt\n"); + printe("failed to tokenize the prompt\n"); + return -1; } return n_prompt_tokens; @@ -207,11 +466,11 @@ static int tokenize_prompt(const llama_model_ptr & model, const std::string & pr // Check if we have enough space in the context to evaluate this batch static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { - const int n_ctx = llama_n_ctx(ctx.get()); + const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); - fprintf(stderr, "context size exceeded\n"); + printe("context size exceeded\n"); return 1; } @@ -221,9 +480,10 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & // convert the token to a string static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) { char buf[256]; - int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); + int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); if (n < 0) { - GGML_ABORT("failed to convert token to piece\n"); + printe("failed to convert token to piece\n"); + return 1; } piece = std::string(buf, n); @@ -238,19 +498,19 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st // helper function to evaluate a prompt and generate a response static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { - std::vector prompt_tokens; - const int n_prompt_tokens = tokenize_prompt(llama_data.model, prompt, prompt_tokens); - if (n_prompt_tokens < 0) { + std::vector tokens; + if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) { return 1; } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); if (llama_decode(llama_data.context.get(), batch)) { - GGML_ABORT("failed to decode\n"); + printe("failed to decode\n"); + return 1; } // sample the next token, check is it an end of generation? @@ -273,22 +533,9 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str return 0; } -static int parse_arguments(const int argc, const char ** argv, Options & opt) { - ArgumentParser parser(argv[0]); - parser.add_argument("-m", opt.model_path, "model"); - parser.add_argument("-p", opt.prompt_non_interactive, "prompt"); - parser.add_argument("-c", opt.n_ctx, "context_size"); - parser.add_argument("-ngl", opt.ngl, "n_gpu_layers"); - if (parser.parse(argc, argv)) { - return 1; - } - - return 0; -} - static int read_user_input(std::string & user) { std::getline(std::cin, user); - return user.empty(); // Indicate an error or empty input + return user.empty(); // Should have data in happy path } // Function to generate a response based on the prompt @@ -296,7 +543,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, // Set response color printf("\033[33m"); if (generate(llama_data, prompt, response)) { - fprintf(stderr, "failed to generate response\n"); + printe("failed to generate response\n"); return 1; } @@ -306,11 +553,10 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const LlamaData & llama_data, std::vector & formatted, - const bool is_user_input, int & output_length) { - const int new_len = apply_chat_template(llama_data, formatted, is_user_input); +static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) { + const int new_len = apply_chat_template(llama_data, append); if (new_len < 0) { - fprintf(stderr, "failed to apply the chat template\n"); + printe("failed to apply the chat template\n"); return -1; } @@ -319,56 +565,62 @@ static int apply_chat_template_with_error_handling(const LlamaData & llama_data, } // Helper function to handle user input -static bool handle_user_input(std::string & user_input, const std::string & prompt_non_interactive) { - if (!prompt_non_interactive.empty()) { - user_input = prompt_non_interactive; - return true; // No need for interactive input +static int handle_user_input(std::string & user_input, const std::string & user_) { + if (!user_.empty()) { + user_input = user_; + return 0; // No need for interactive input } printf("\033[32m> \033[0m"); - return !read_user_input(user_input); // Returns false if input ends the loop + return read_user_input(user_input); // Returns true if input ends the loop } // Function to tokenize the prompt -static int chat_loop(LlamaData & llama_data, std::string & prompt_non_interactive) { - std::vector owned_content; - std::vector fmtted(llama_n_ctx(llama_data.context.get())); +static int chat_loop(LlamaData & llama_data, const std::string & user_) { int prev_len = 0; - + llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); while (true) { // Get user input std::string user_input; - if (!handle_user_input(user_input, prompt_non_interactive)) { + if (handle_user_input(user_input, user_)) { break; } - add_message("user", prompt_non_interactive.empty() ? user_input : prompt_non_interactive, llama_data, - owned_content); - + add_message("user", user_.empty() ? user_input : user_, llama_data); int new_len; - if (apply_chat_template_with_error_handling(llama_data, fmtted, true, new_len) < 0) { + if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) { return 1; } - std::string prompt(fmtted.begin() + prev_len, fmtted.begin() + new_len); + std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); std::string response; if (generate_response(llama_data, prompt, response)) { return 1; } + + if (!user_.empty()) { + break; + } + + add_message("assistant", response, llama_data); + if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) { + return 1; + } } + return 0; } static void log_callback(const enum ggml_log_level level, const char * text, void *) { if (level == GGML_LOG_LEVEL_ERROR) { - fprintf(stderr, "%s", text); + printe("%s", text); } } static bool is_stdin_a_terminal() { #if defined(_WIN32) HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); - DWORD mode; + DWORD mode; return GetConsoleMode(hStdin, &mode); #else return isatty(STDIN_FILENO); @@ -382,17 +634,20 @@ static std::string read_pipe_data() { } int main(int argc, const char ** argv) { - Options opt; - if (parse_arguments(argc, argv, opt)) { + Opt opt; + const int opt_ret = opt.init_opt(argc, argv); + if (opt_ret == 2) { + return 0; + } else if (opt_ret) { return 1; } if (!is_stdin_a_terminal()) { - if (!opt.prompt_non_interactive.empty()) { - opt.prompt_non_interactive += "\n\n"; + if (!opt.user_.empty()) { + opt.user_ += "\n\n"; } - opt.prompt_non_interactive += read_pipe_data(); + opt.user_ += read_pipe_data(); } llama_log_set(log_callback, nullptr); @@ -401,7 +656,7 @@ int main(int argc, const char ** argv) { return 1; } - if (chat_loop(llama_data, opt.prompt_non_interactive)) { + if (chat_loop(llama_data, opt.user_)) { return 1; }