From f4138cc78384c136477b89004a2d6785ee9a6449 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Sat, 4 May 2024 18:26:07 -0700
Subject: [PATCH] feat(core): make llama cpp an optional dependency

---
 .github/workflows/test-rust.yml        |  2 +-
 crates/tabby/Cargo.toml                | 13 ++---
 crates/tabby/src/main.rs               |  9 ++--
 crates/tabby/src/serve.rs              |  6 +--
 crates/tabby/src/services/health.rs    |  4 +-
 crates/tabby/src/services/model/mod.rs | 66 ++++++++++++++------------
 crates/tabby/src/worker.rs             |  2 +-
 7 files changed, 55 insertions(+), 47 deletions(-)

diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml
index e9e7ab7acb42..89f533326199 100644
--- a/.github/workflows/test-rust.yml
+++ b/.github/workflows/test-rust.yml
@@ -68,7 +68,7 @@ jobs:
       - run: bash ./ci/prepare_build_environment.sh
 
       - name: Run unit tests on community build
-        run: cargo test --bin tabby --no-default-features
+        run: cargo test --bin tabby --no-default-features --features llama-cpp
 
       - name: Run unit tests
        run: cargo test --bin tabby --lib
diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml
index 06a9ecf58a91..8b30f1bb2cfa 100644
--- a/crates/tabby/Cargo.toml
+++ b/crates/tabby/Cargo.toml
@@ -6,16 +6,17 @@ authors.workspace = true
 homepage.workspace = true
 
 [features]
-default = ["ee"]
+default = ["ee", "llama-cpp"]
 ee = ["dep:tabby-webserver"]
-cuda = ["llama-cpp-bindings/cuda"]
-rocm = ["llama-cpp-bindings/rocm"]
-vulkan = ["llama-cpp-bindings/vulkan"]
+llama-cpp = ["dep:llama-cpp-bindings"]
+cuda = ["llama-cpp", "llama-cpp-bindings/cuda"]
+rocm = ["llama-cpp", "llama-cpp-bindings/rocm"]
+vulkan = ["llama-cpp", "llama-cpp-bindings/vulkan"]
 # If compiling on a system without OpenSSL installed, or cross-compiling for a different
 # architecture, enable this feature to compile OpenSSL as part of the build.
 # See https://docs.rs/openssl/#vendored for more.
 static-ssl = ['openssl/vendored']
-prod = ['tabby-webserver/prod-db']
+prod = ['tabby-webserver/prod-db', "llama-cpp"]
 
 [dependencies]
 tabby-common = { path = "../tabby-common" }
@@ -47,7 +48,7 @@ async-stream = { workspace = true }
 minijinja = { version = "1.0.8", features = ["loader"] }
 textdistance = "1.0.2"
 regex.workspace = true
-llama-cpp-bindings = { path = "../llama-cpp-bindings" }
+llama-cpp-bindings = { path = "../llama-cpp-bindings", optional = true}
 futures.workspace = true
 async-trait.workspace = true
 tabby-webserver = { path = "../../ee/tabby-webserver", optional = true }
diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs
index 50da6517ae3f..4948bd41e8ef 100644
--- a/crates/tabby/src/main.rs
+++ b/crates/tabby/src/main.rs
@@ -68,6 +68,7 @@ pub struct SchedulerArgs {
 
 #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
 pub enum Device {
+    #[cfg(feature = "llama-cpp")]
     #[strum(serialize = "cpu")]
     Cpu,
 
@@ -79,7 +80,7 @@ pub enum Device {
     #[strum(serialize = "rocm")]
     Rocm,
 
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+    #[cfg(all(feature = "llama-cpp", target_os = "macos", target_arch = "aarch64"))]
     #[strum(serialize = "metal")]
     Metal,
 
@@ -87,13 +88,13 @@ pub enum Device {
     #[strum(serialize = "vulkan")]
     Vulkan,
 
-    #[strum(serialize = "experimental_http")]
+    #[strum(serialize = "http")]
     #[clap(hide = true)]
-    ExperimentalHttp,
+    Http,
 }
 
 impl Device {
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+    #[cfg(all(feature = "llama-cpp", target_os = "macos", target_arch = "aarch64"))]
     pub fn ggml_use_gpu(&self) -> bool {
         *self == Device::Metal
     }
diff --git a/crates/tabby/src/serve.rs b/crates/tabby/src/serve.rs
index f6bade3babce..c6592857c29b 100644
--- a/crates/tabby/src/serve.rs
+++ b/crates/tabby/src/serve.rs
@@ -93,7 +93,7 @@ pub struct ServeArgs {
     port: u16,
 
     /// Device to run model inference.
-    #[clap(long, default_value_t=Device::Cpu)]
+    #[clap(long, default_value_t=Device::Http)]
     device: Device,
 
     /// Device to run chat model [default equals --device arg]
@@ -168,14 +168,14 @@ pub async fn main(config: &Config, args: &ServeArgs) {
 }
 
 async fn load_model(args: &ServeArgs) {
-    if args.device != Device::ExperimentalHttp {
+    if args.device != Device::Http {
         if let Some(model) = &args.model {
             download_model_if_needed(model).await;
         }
     }
 
     let chat_device = args.chat_device.as_ref().unwrap_or(&args.device);
-    if chat_device != &Device::ExperimentalHttp {
+    if chat_device != &Device::Http {
         if let Some(chat_model) = &args.chat_model {
             download_model_if_needed(chat_model).await
         }
diff --git a/crates/tabby/src/services/health.rs b/crates/tabby/src/services/health.rs
index 140f6ee9f148..58ec046debfd 100644
--- a/crates/tabby/src/services/health.rs
+++ b/crates/tabby/src/services/health.rs
@@ -41,14 +41,14 @@ impl HealthState {
         };
 
         let http_model_name = Some("Remote");
-        let is_model_http = device == &Device::ExperimentalHttp;
+        let is_model_http = device == &Device::Http;
         let model = if is_model_http {
             http_model_name
         } else {
             model
         };
 
-        let is_chat_model_http = chat_device == Some(&Device::ExperimentalHttp);
+        let is_chat_model_http = chat_device == Some(&Device::Http);
         let chat_model = if is_chat_model_http {
             http_model_name
         } else {
diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs
index 655fba5ba561..c878c77f31c1 100644
--- a/crates/tabby/src/services/model/mod.rs
+++ b/crates/tabby/src/services/model/mod.rs
@@ -3,10 +3,6 @@ mod chat;
 use std::{fs, path::PathBuf, sync::Arc};
 
 use serde::Deserialize;
-use tabby_common::{
-    registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH},
-    terminal::{HeaderFormat, InfoMessage},
-};
 use tabby_download::download_model;
 use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream};
 use tracing::info;
@@ -18,7 +14,7 @@ pub async fn load_chat_completion(
     device: &Device,
     parallelism: u8,
 ) -> Arc<dyn ChatCompletionStream> {
-    if device == &Device::ExperimentalHttp {
+    if device == &Device::Http {
         return http_api_bindings::create_chat(model_id);
     }
 
@@ -44,9 +40,9 @@ pub async fn load_code_generation(
 async fn load_completion(
     model_id: &str,
     device: &Device,
-    parallelism: u8,
+    #[allow(unused_variables)] parallelism: u8,
 ) -> (Arc<dyn CompletionStream>, PromptInfo) {
-    if device == &Device::ExperimentalHttp {
+    if device == &Device::Http {
         let (engine, prompt_template, chat_template) = http_api_bindings::create(model_id);
         return (
             engine,
@@ -57,30 +53,37 @@
         );
     }
 
-    if fs::metadata(model_id).is_ok() {
-        let path = PathBuf::from(model_id);
-        let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
-        let engine = create_ggml_engine(
-            device,
-            model_path.display().to_string().as_str(),
-            parallelism,
-        );
-        let engine_info = PromptInfo::read(path.join("tabby.json"));
-        (Arc::new(engine), engine_info)
-    } else {
-        let (registry, name) = parse_model_id(model_id);
-        let registry = ModelRegistry::new(registry).await;
-        let model_path = registry.get_model_path(name).display().to_string();
-        let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(device, &model_path, parallelism);
-        (
-            Arc::new(engine),
-            PromptInfo {
-                prompt_template: model_info.prompt_template.clone(),
-                chat_template: model_info.chat_template.clone(),
-            },
-        )
+    #[cfg(feature = "llama-cpp")]
+    {
+        use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
+
+        if fs::metadata(model_id).is_ok() {
+            let path = PathBuf::from(model_id);
+            let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+            let engine = create_ggml_engine(
+                device,
+                model_path.display().to_string().as_str(),
+                parallelism,
+            );
+            let engine_info = PromptInfo::read(path.join("tabby.json"));
+            (Arc::new(engine), engine_info)
+        } else {
+            let (registry, name) = parse_model_id(model_id);
+            let registry = ModelRegistry::new(registry).await;
+            let model_path = registry.get_model_path(name).display().to_string();
+            let model_info = registry.get_model_info(name);
+            let engine = create_ggml_engine(device, &model_path, parallelism);
+            (
+                Arc::new(engine),
+                PromptInfo {
+                    prompt_template: model_info.prompt_template.clone(),
+                    chat_template: model_info.chat_template.clone(),
+                },
+            )
+        }
     }
+    #[cfg(not(feature = "llama-cpp"))]
+    panic!("Unsupported device");
 }
 
 #[derive(Deserialize)]
@@ -96,7 +99,10 @@ impl PromptInfo {
     }
 }
 
+#[cfg(feature = "llama-cpp")]
 fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl CompletionStream {
+    use tabby_common::terminal::{HeaderFormat, InfoMessage};
+
     if !device.ggml_use_gpu() {
         InfoMessage::new(
             "CPU Device",
diff --git a/crates/tabby/src/worker.rs b/crates/tabby/src/worker.rs
index f8778bdb87d4..24589846f638 100644
--- a/crates/tabby/src/worker.rs
+++ b/crates/tabby/src/worker.rs
@@ -38,7 +38,7 @@ pub struct WorkerArgs {
     model: String,
 
     /// Device to run model inference.
-    #[clap(long, default_value_t=Device::Cpu, help_heading=Some("Model Options"))]
+    #[clap(long, default_value_t=Device::Http, help_heading=Some("Model Options"))]
     device: Device,
 
     /// Parallelism for model serving - increasing this number will have a significant impact on the
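Roughly, the build matrix this feature gating enables, as a sketch based only on the feature declarations in crates/tabby/Cargo.toml above (invocations not verified against the full workspace):

    cargo build --bin tabby                                       # default = ee + llama-cpp
    cargo build --bin tabby --no-default-features --features ee   # no native llama.cpp bindings
    cargo build --bin tabby --features cuda                       # cuda/rocm/vulkan pull in llama-cpp

Without the llama-cpp feature, Device::Http is the only device variant left (Cpu and Metal are now cfg-gated, and the GPU features imply llama-cpp), which is why the --device default moves from Device::Cpu to Device::Http in serve.rs and worker.rs.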