From a4bbad4f9bb131d137d0cbadbfa13d863af2e67b Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Mon, 25 Nov 2024 21:13:36 -0500 Subject: [PATCH] Opt class for positional argument handling Added support for positional arguments `MODEL` and `PROMPT`. Added model path resolution, just `file://` for now. Signed-off-by: Eric Curtin --- common/common.cpp | 6 - common/common.h | 11 +- examples/run/CMakeLists.txt | 3 +- examples/run/run.cpp | 323 +++++++++++++++++++----------------- 4 files changed, 179 insertions(+), 164 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 6143516d2250f..c44befe9da163 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1108,12 +1108,6 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p #define CURL_MAX_RETRY 3 #define CURL_RETRY_DELAY_SECONDS 2 - -static bool starts_with(const std::string & str, const std::string & prefix) { - // While we wait for C++20's std::string::starts_with... - return str.rfind(prefix, 0) == 0; -} - static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) { int remaining_attempts = max_attempts; diff --git a/common/common.h b/common/common.h index 95d20401d2a9a..5a49da856dc43 100644 --- a/common/common.h +++ b/common/common.h @@ -37,9 +37,9 @@ using llama_tokens = std::vector; // build info extern int LLAMA_BUILD_NUMBER; -extern char const * LLAMA_COMMIT; -extern char const * LLAMA_COMPILER; -extern char const * LLAMA_BUILD_TARGET; +extern const char * LLAMA_COMMIT; +extern const char * LLAMA_COMPILER; +extern const char * LLAMA_BUILD_TARGET; struct common_control_vector_load_info; @@ -437,6 +437,11 @@ std::vector string_split(const std::string & input, ch return parts; } +static bool starts_with(const std::string & str, + const std::string & prefix) { // While we wait for C++20's std::string::starts_with... + return str.rfind(prefix, 0) == 0; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); diff --git a/examples/run/CMakeLists.txt b/examples/run/CMakeLists.txt index 52add51ef77c3..2bf78f850b322 100644 --- a/examples/run/CMakeLists.txt +++ b/examples/run/CMakeLists.txt @@ -1,5 +1,6 @@ set(TARGET llama-run) add_executable(${TARGET} run.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_include_directories(${TARGET} PUBLIC ../../common) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index cac2faefcc256..db78f42044e4d 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -1,128 +1,126 @@ #if defined(_WIN32) -#include +# include #else -#include +# include #endif -#include #include #include #include #include #include -#include #include +#include "common.h" #include "llama-cpp.h" -typedef std::unique_ptr char_array_ptr; - -struct Argument { - std::string flag; - std::string help_text; -}; - -struct Options { - std::string model_path, prompt_non_interactive; - int ngl = 99; - int n_ctx = 2048; -}; +class Opt { + public: + int init_opt(int argc, const char ** argv) { + construct_help_str_(); + // Parse arguments + if (parse(argc, argv)) { + fprintf(stderr, "Error: Failed to parse arguments.\n"); + help(); + return 1; + } -class ArgumentParser { - public: - ArgumentParser(const char * program_name) : program_name(program_name) {} + // If help is requested, show help and exit + if (help_) { + help(); + return 2; + } - void add_argument(const std::string & flag, std::string & var, const std::string & help_text = "") { - string_args[flag] = &var; - arguments.push_back({flag, help_text}); + return 0; // Success } - void add_argument(const std::string & flag, int & var, const std::string & help_text = "") { - int_args[flag] = &var; - arguments.push_back({flag, help_text}); + std::string model_; + std::string user_; + int context_size_ = 2048, ngl_ = 0; + + private: + std::string help_str_; + bool help_ = false; + + void construct_help_str_() { + help_str_ = + "Description:\n" + " Runs a llm\n" + "\n" + "Usage:\n" + " llama-run [options] MODEL [PROMPT]\n" + "\n" + "Options:\n" + " -c, --context-size \n" + " Context size (default: " + + std::to_string(context_size_); + help_str_ += + ")\n" + " -n, --ngl \n" + " Number of GPU layers (default: " + + std::to_string(ngl_); + help_str_ += + ")\n" + " -h, --help\n" + " Show help message\n" + "\n" + "Examples:\n" + " llama-run your_model.gguf\n" + " llama-run --ngl 99 your_model.gguf\n" + " llama-run --ngl 99 your_model.gguf Hello World\n"; } int parse(int argc, const char ** argv) { + int positional_args_i = 0; for (int i = 1; i < argc; ++i) { - std::string arg = argv[i]; - if (string_args.count(arg)) { - if (i + 1 < argc) { - *string_args[arg] = argv[++i]; - } else { - fprintf(stderr, "error: missing value for %s\n", arg.c_str()); - print_usage(); + if (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0) { + if (i + 1 >= argc) { return 1; } - } else if (int_args.count(arg)) { - if (i + 1 < argc) { - if (parse_int_arg(argv[++i], *int_args[arg]) != 0) { - fprintf(stderr, "error: invalid value for %s: %s\n", arg.c_str(), argv[i]); - print_usage(); - return 1; - } - } else { - fprintf(stderr, "error: missing value for %s\n", arg.c_str()); - print_usage(); + + context_size_ = std::atoi(argv[++i]); + } else if (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "--ngl") == 0) { + if (i + 1 >= argc) { return 1; } + + ngl_ = std::atoi(argv[++i]); + } else if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) { + help_ = true; + return 0; + } else if (!positional_args_i) { + ++positional_args_i; + model_ = argv[i]; + } else if (positional_args_i == 1) { + ++positional_args_i; + user_ = argv[i]; } else { - fprintf(stderr, "error: unrecognized argument %s\n", arg.c_str()); - print_usage(); - return 1; + user_ += " " + std::string(argv[i]); } } - if (string_args["-m"]->empty()) { - fprintf(stderr, "error: -m is required\n"); - print_usage(); - return 1; - } - - return 0; + return model_.empty(); // model_ is the only required value } - private: - const char * program_name; - std::unordered_map string_args; - std::unordered_map int_args; - std::vector arguments; - - int parse_int_arg(const char * arg, int & value) { - char * end; - const long val = std::strtol(arg, &end, 10); - if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) { - value = static_cast(val); - return 0; - } - return 1; - } - - void print_usage() const { - printf("\nUsage:\n"); - printf(" %s [OPTIONS]\n\n", program_name); - printf("Options:\n"); - for (const auto & arg : arguments) { - printf(" %-10s %s\n", arg.flag.c_str(), arg.help_text.c_str()); - } - - printf("\n"); - } + void help() const { printf("%s", help_str_.c_str()); } }; class LlamaData { - public: - llama_model_ptr model; - llama_sampler_ptr sampler; - llama_context_ptr context; + public: + llama_model_ptr model; + llama_sampler_ptr sampler; + llama_context_ptr context; std::vector messages; + std::vector msg_strs; + std::vector fmtted; - int init(const Options & opt) { - model = initialize_model(opt.model_path, opt.ngl); + int init(Opt & opt) { + model = initialize_model(opt); if (!model) { return 1; } - context = initialize_context(model, opt.n_ctx); + context = initialize_context(model, opt.context_size_); if (!context) { return 1; } @@ -131,13 +129,35 @@ class LlamaData { return 0; } - private: + private: + int remove_proto(std::string & model_) { + const std::string::size_type pos = model_.find("://"); + if (pos == std::string::npos) { + return 1; + } + + model_ = model_.substr(pos + 3); // Skip past "://" + return 0; + } + + int resolve_model(std::string & model_, const struct llama_model_params & params) { + if (starts_with(model_, "file://")) { + remove_proto(model_); + } + + // Implement hf://, https://, ollama://, later, if file doesn't exist, assume + // ollama str + + return 0; + } + // Initializes the model and returns a unique pointer to it - llama_model_ptr initialize_model(const std::string & model_path, const int ngl) { + llama_model_ptr initialize_model(Opt & opt) { + ggml_backend_load_all(); llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = ngl; - - llama_model_ptr model(llama_load_model_from_file(model_path.c_str(), model_params)); + model_params.n_gpu_layers = opt.ngl_; + resolve_model(opt.model_, model_params); + llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params)); if (!model) { fprintf(stderr, "%s: error: unable to load model\n", __func__); } @@ -148,9 +168,8 @@ class LlamaData { // Initializes the context with the specified parameters llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) { llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_ctx; - ctx_params.n_batch = n_ctx; - + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = n_ctx; llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params)); if (!context) { fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__); @@ -170,23 +189,22 @@ class LlamaData { } }; -// Add a message to `messages` and store its content in `owned_content` -static void add_message(const char * role, const std::string & text, LlamaData & llama_data, - std::vector & owned_content) { - char_array_ptr content(new char[text.size() + 1]); - std::strcpy(content.get(), text.c_str()); - llama_data.messages.push_back({role, content.get()}); - owned_content.push_back(std::move(content)); +// Add a message to `messages` and store its content in `msg_strs` +static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { + llama_data.msg_strs.push_back(std::move(text)); + llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const LlamaData & llama_data, std::vector & formatted, const bool append) { - int result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), - llama_data.messages.size(), append, formatted.data(), formatted.size()); - if (result > static_cast(formatted.size())) { - formatted.resize(result); +static int apply_chat_template(LlamaData & llama_data, const bool append) { + int result = llama_chat_apply_template( + llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append, + append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); + if (append && result > static_cast(llama_data.fmtted.size())) { + llama_data.fmtted.resize(result); result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), - llama_data.messages.size(), append, formatted.data(), formatted.size()); + llama_data.messages.size(), append, llama_data.fmtted.data(), + llama_data.fmtted.size()); } return result; @@ -199,7 +217,8 @@ static int tokenize_prompt(const llama_model_ptr & model, const std::string & pr prompt_tokens.resize(n_prompt_tokens); if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { - GGML_ABORT("failed to tokenize the prompt\n"); + fprintf(stderr, "failed to tokenize the prompt\n"); + return -1; } return n_prompt_tokens; @@ -207,7 +226,7 @@ static int tokenize_prompt(const llama_model_ptr & model, const std::string & pr // Check if we have enough space in the context to evaluate this batch static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { - const int n_ctx = llama_n_ctx(ctx.get()); + const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx_used = llama_get_kv_cache_used_cells(ctx.get()); if (n_ctx_used + batch.n_tokens > n_ctx) { printf("\033[0m\n"); @@ -221,9 +240,10 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch & // convert the token to a string static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) { char buf[256]; - int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); + int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); if (n < 0) { - GGML_ABORT("failed to convert token to piece\n"); + fprintf(stderr, "failed to convert token to piece\n"); + return 1; } piece = std::string(buf, n); @@ -238,19 +258,19 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st // helper function to evaluate a prompt and generate a response static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { - std::vector prompt_tokens; - const int n_prompt_tokens = tokenize_prompt(llama_data.model, prompt, prompt_tokens); - if (n_prompt_tokens < 0) { + std::vector tokens; + if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) { return 1; } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); if (llama_decode(llama_data.context.get(), batch)) { - GGML_ABORT("failed to decode\n"); + fprintf(stderr, "failed to decode\n"); + return 1; } // sample the next token, check is it an end of generation? @@ -273,22 +293,9 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str return 0; } -static int parse_arguments(const int argc, const char ** argv, Options & opt) { - ArgumentParser parser(argv[0]); - parser.add_argument("-m", opt.model_path, "model"); - parser.add_argument("-p", opt.prompt_non_interactive, "prompt"); - parser.add_argument("-c", opt.n_ctx, "context_size"); - parser.add_argument("-ngl", opt.ngl, "n_gpu_layers"); - if (parser.parse(argc, argv)) { - return 1; - } - - return 0; -} - static int read_user_input(std::string & user) { std::getline(std::cin, user); - return user.empty(); // Indicate an error or empty input + return user.empty(); // Should have data in happy path } // Function to generate a response based on the prompt @@ -306,9 +313,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const LlamaData & llama_data, std::vector & formatted, - const bool is_user_input, int & output_length) { - const int new_len = apply_chat_template(llama_data, formatted, is_user_input); +static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) { + const int new_len = apply_chat_template(llama_data, append); if (new_len < 0) { fprintf(stderr, "failed to apply the chat template\n"); return -1; @@ -319,43 +325,49 @@ static int apply_chat_template_with_error_handling(const LlamaData & llama_data, } // Helper function to handle user input -static bool handle_user_input(std::string & user_input, const std::string & prompt_non_interactive) { - if (!prompt_non_interactive.empty()) { - user_input = prompt_non_interactive; - return true; // No need for interactive input +static int handle_user_input(std::string & user_input, const std::string & user_) { + if (!user_.empty()) { + user_input = user_; + return 0; // No need for interactive input } printf("\033[32m> \033[0m"); - return !read_user_input(user_input); // Returns false if input ends the loop + return read_user_input(user_input); // Returns true if input ends the loop } // Function to tokenize the prompt -static int chat_loop(LlamaData & llama_data, std::string & prompt_non_interactive) { - std::vector owned_content; - std::vector fmtted(llama_n_ctx(llama_data.context.get())); +static int chat_loop(LlamaData & llama_data, const std::string & user_) { int prev_len = 0; - + llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); while (true) { // Get user input std::string user_input; - if (!handle_user_input(user_input, prompt_non_interactive)) { + if (handle_user_input(user_input, user_)) { break; } - add_message("user", prompt_non_interactive.empty() ? user_input : prompt_non_interactive, llama_data, - owned_content); - + add_message("user", user_.empty() ? user_input : user_, llama_data); int new_len; - if (apply_chat_template_with_error_handling(llama_data, fmtted, true, new_len) < 0) { + if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) { return 1; } - std::string prompt(fmtted.begin() + prev_len, fmtted.begin() + new_len); + std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); std::string response; if (generate_response(llama_data, prompt, response)) { return 1; } + + if (!user_.empty()) { + break; + } + + add_message("assistant", response, llama_data); + if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) { + return 1; + } } + return 0; } @@ -368,7 +380,7 @@ static void log_callback(const enum ggml_log_level level, const char * text, voi static bool is_stdin_a_terminal() { #if defined(_WIN32) HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); - DWORD mode; + DWORD mode; return GetConsoleMode(hStdin, &mode); #else return isatty(STDIN_FILENO); @@ -382,17 +394,20 @@ static std::string read_pipe_data() { } int main(int argc, const char ** argv) { - Options opt; - if (parse_arguments(argc, argv, opt)) { + Opt opt; + const int opt_ret = opt.init_opt(argc, argv); + if (opt_ret == 2) { + return 0; + } else if (opt_ret) { return 1; } if (!is_stdin_a_terminal()) { - if (!opt.prompt_non_interactive.empty()) { - opt.prompt_non_interactive += "\n\n"; + if (!opt.user_.empty()) { + opt.user_ += "\n\n"; } - opt.prompt_non_interactive += read_pipe_data(); + opt.user_ += read_pipe_data(); } llama_log_set(log_callback, nullptr); @@ -401,7 +416,7 @@ int main(int argc, const char ** argv) { return 1; } - if (chat_loop(llama_data, opt.prompt_non_interactive)) { + if (chat_loop(llama_data, opt.user_)) { return 1; }