feat(core): make llama cpp an optional dependency #2052

Status: Closed (wants to merge 1 commit)
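In short: this PR introduces a `llama-cpp` Cargo feature (enabled by default) that gates the `llama-cpp-bindings` dependency, renames the `experimental_http` device to `http`, and makes `http` the default device, so Tabby can be compiled without any local GGML inference backend.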
2 changes: 1 addition & 1 deletion .github/workflows/test-rust.yml
@@ -68,7 +68,7 @@ jobs:
       - run: bash ./ci/prepare_build_environment.sh

       - name: Run unit tests on community build
-        run: cargo test --bin tabby --no-default-features
+        run: cargo test --bin tabby --no-default-features --features llama-cpp

       - name: Run unit tests
         run: cargo test --bin tabby --lib
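With llama.cpp behind a feature flag, `--no-default-features` alone no longer builds the local-inference path, so the community-build test opts back in via `--features llama-cpp`; presumably a remote-only build would simply omit that flag.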
13 changes: 7 additions & 6 deletions crates/tabby/Cargo.toml
@@ -6,16 +6,17 @@ authors.workspace = true
 homepage.workspace = true

 [features]
-default = ["ee"]
+default = ["ee", "llama-cpp"]
 ee = ["dep:tabby-webserver"]
-cuda = ["llama-cpp-bindings/cuda"]
-rocm = ["llama-cpp-bindings/rocm"]
-vulkan = ["llama-cpp-bindings/vulkan"]
+llama-cpp = ["dep:llama-cpp-bindings"]
+cuda = ["llama-cpp", "llama-cpp-bindings/cuda"]
+rocm = ["llama-cpp", "llama-cpp-bindings/rocm"]
+vulkan = ["llama-cpp", "llama-cpp-bindings/vulkan"]
 # If compiling on a system without OpenSSL installed, or cross-compiling for a different
 # architecture, enable this feature to compile OpenSSL as part of the build.
 # See https://docs.rs/openssl/#vendored for more.
 static-ssl = ['openssl/vendored']
-prod = ['tabby-webserver/prod-db']
+prod = ['tabby-webserver/prod-db', "llama-cpp"]

 [dependencies]
 tabby-common = { path = "../tabby-common" }
@@ -47,7 +48,7 @@ async-stream = { workspace = true }
 minijinja = { version = "1.0.8", features = ["loader"] }
 textdistance = "1.0.2"
 regex.workspace = true
-llama-cpp-bindings = { path = "../llama-cpp-bindings" }
+llama-cpp-bindings = { path = "../llama-cpp-bindings", optional = true}
 futures.workspace = true
 async-trait.workspace = true
 tabby-webserver = { path = "../../ee/tabby-webserver", optional = true }
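Design note: `llama-cpp = ["dep:llama-cpp-bindings"]` is what actually toggles the optional dependency; `cuda`, `rocm`, `vulkan`, and `prod` each list `llama-cpp` so that enabling a GPU backend (or the prod build) transitively pulls the bindings back in. Only a build that disables default features and requests none of those features ends up without llama.cpp.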
9 changes: 5 additions & 4 deletions crates/tabby/src/main.rs
@@ -68,6 +68,7 @@ pub struct SchedulerArgs {

 #[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
 pub enum Device {
+    #[cfg(feature = "llama-cpp")]
     #[strum(serialize = "cpu")]
     Cpu,

@@ -79,21 +80,21 @@ pub enum Device {
     #[strum(serialize = "rocm")]
     Rocm,

-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+    #[cfg(all(feature = "llama-cpp", target_os = "macos", target_arch = "aarch64"))]
     #[strum(serialize = "metal")]
     Metal,

     #[cfg(feature = "vulkan")]
     #[strum(serialize = "vulkan")]
     Vulkan,

-    #[strum(serialize = "experimental_http")]
+    #[strum(serialize = "http")]
     #[clap(hide = true)]
-    ExperimentalHttp,
+    Http,
 }

 impl Device {
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
+    #[cfg(all(feature = "llama-cpp", target_os = "macos", target_arch = "aarch64"))]
     pub fn ggml_use_gpu(&self) -> bool {
         *self == Device::Metal
     }
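The pattern above compiles the `Cpu` and `Metal` variants only when local inference is built in, while `Http` is always present. A minimal self-contained sketch of the same technique (assuming clap 4 with the `derive` feature and strum are in `Cargo.toml`; the enum and flag names mirror the PR, the rest is illustrative):

```rust
use clap::Parser;

// Variants gated on a Cargo feature simply vanish from the enum (and from
// clap's --device value list) when the feature is off; `Http` always remains.
#[derive(clap::ValueEnum, strum::Display, PartialEq, Clone)]
pub enum Device {
    #[cfg(feature = "llama-cpp")]
    #[strum(serialize = "cpu")]
    Cpu,

    #[strum(serialize = "http")]
    Http,
}

#[derive(Parser)]
struct Args {
    // `default_value_t` uses the Display impl derived by strum.
    #[clap(long, default_value_t = Device::Http)]
    device: Device,
}

fn main() {
    let args = Args::parse();
    println!("device: {}", args.device);
}
```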
6 changes: 3 additions & 3 deletions crates/tabby/src/serve.rs
@@ -93,7 +93,7 @@
     port: u16,

     /// Device to run model inference.
-    #[clap(long, default_value_t=Device::Cpu)]
+    #[clap(long, default_value_t=Device::Http)]
     device: Device,

     /// Device to run chat model [default equals --device arg]
@@ -168,14 +168,14 @@
 }

 async fn load_model(args: &ServeArgs) {
-    if args.device != Device::ExperimentalHttp {
+    if args.device != Device::Http {
         if let Some(model) = &args.model {
             download_model_if_needed(model).await;
         }
     }

     let chat_device = args.chat_device.as_ref().unwrap_or(&args.device);
-    if chat_device != &Device::ExperimentalHttp {
+    if chat_device != &Device::Http {
         if let Some(chat_model) = &args.chat_model {
             download_model_if_needed(chat_model).await
         }
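Two behavioral changes ride along here: the default `--device` flips from `Cpu` to `Http` (the only variant guaranteed to exist in every build configuration), and model downloads are skipped whenever the effective device is `Http`, since the weights live behind the remote endpoint.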
4 changes: 2 additions & 2 deletions crates/tabby/src/services/health.rs
@@ -41,14 +41,14 @@
     };

     let http_model_name = Some("Remote");
-    let is_model_http = device == &Device::ExperimentalHttp;
+    let is_model_http = device == &Device::Http;
     let model = if is_model_http {
         http_model_name
     } else {
         model
     };

-    let is_chat_model_http = chat_device == Some(&Device::ExperimentalHttp);
+    let is_chat_model_http = chat_device == Some(&Device::Http);
     let chat_model = if is_chat_model_http {
         http_model_name
     } else {
66 changes: 36 additions & 30 deletions crates/tabby/src/services/model/mod.rs
@@ -3,10 +3,6 @@
 use std::{fs, path::PathBuf, sync::Arc};

 use serde::Deserialize;
-use tabby_common::{
-    registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH},
-    terminal::{HeaderFormat, InfoMessage},
-};
 use tabby_download::download_model;
 use tabby_inference::{ChatCompletionStream, CodeGeneration, CompletionStream};
 use tracing::info;
@@ -18,7 +14,7 @@
     device: &Device,
     parallelism: u8,
 ) -> Arc<dyn ChatCompletionStream> {
-    if device == &Device::ExperimentalHttp {
+    if device == &Device::Http {
         return http_api_bindings::create_chat(model_id);
     }

@@ -44,9 +40,9 @@
 async fn load_completion(
     model_id: &str,
     device: &Device,
-    parallelism: u8,
+    #[allow(unused_variables)] parallelism: u8,
 ) -> (Arc<dyn CompletionStream>, PromptInfo) {
-    if device == &Device::ExperimentalHttp {
+    if device == &Device::Http {
         let (engine, prompt_template, chat_template) = http_api_bindings::create(model_id);
         return (
             engine,
@@ -57,30 +53,37 @@
         );
     }

-    if fs::metadata(model_id).is_ok() {
-        let path = PathBuf::from(model_id);
-        let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
-        let engine = create_ggml_engine(
-            device,
-            model_path.display().to_string().as_str(),
-            parallelism,
-        );
-        let engine_info = PromptInfo::read(path.join("tabby.json"));
-        (Arc::new(engine), engine_info)
-    } else {
-        let (registry, name) = parse_model_id(model_id);
-        let registry = ModelRegistry::new(registry).await;
-        let model_path = registry.get_model_path(name).display().to_string();
-        let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(device, &model_path, parallelism);
-        (
-            Arc::new(engine),
-            PromptInfo {
-                prompt_template: model_info.prompt_template.clone(),
-                chat_template: model_info.chat_template.clone(),
-            },
-        )
+    #[cfg(feature = "llama-cpp")]
+    {
+        use tabby_common::registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH};
+
+        if fs::metadata(model_id).is_ok() {
+            let path = PathBuf::from(model_id);
+            let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+            let engine = create_ggml_engine(
+                device,
+                model_path.display().to_string().as_str(),
+                parallelism,
+            );
+            let engine_info = PromptInfo::read(path.join("tabby.json"));
+            (Arc::new(engine), engine_info)
+        } else {
+            let (registry, name) = parse_model_id(model_id);
+            let registry = ModelRegistry::new(registry).await;
+            let model_path = registry.get_model_path(name).display().to_string();
+            let model_info = registry.get_model_info(name);
+            let engine = create_ggml_engine(device, &model_path, parallelism);
+            (
+                Arc::new(engine),
+                PromptInfo {
+                    prompt_template: model_info.prompt_template.clone(),
+                    chat_template: model_info.chat_template.clone(),
+                },
+            )
+        }
     }
+    #[cfg(not(feature = "llama-cpp"))]
+    panic!("Unsupported device");
 }

 #[derive(Deserialize)]
@@ -96,7 +99,10 @@
     }
 }

+#[cfg(feature = "llama-cpp")]
 fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl CompletionStream {
+    use tabby_common::terminal::{HeaderFormat, InfoMessage};
+
     if !device.ggml_use_gpu() {
         InfoMessage::new(
             "CPU Device",
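The refactor in `load_completion` relies on a pair of mutually exclusive `cfg` attributes so the function type-checks in both build configurations: when `llama-cpp` is enabled the gated block is the function's tail expression, otherwise the `panic!` (type `!`) stands in for the return value. A minimal sketch of the pattern (names are illustrative, not from the PR):

```rust
// Sketch: exactly one of the two cfg-gated tails survives compilation.
fn load_backend() -> String {
    #[cfg(feature = "llama-cpp")]
    {
        // Feature on: this block is the function's tail expression.
        String::from("local llama.cpp engine")
    }
    #[cfg(not(feature = "llama-cpp"))]
    // Feature off: panic! diverges, so no String value is needed.
    panic!("Unsupported device: rebuild with --features llama-cpp");
}

fn main() {
    println!("{}", load_backend());
}
```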
2 changes: 1 addition & 1 deletion crates/tabby/src/worker.rs
@@ -38,7 +38,7 @@
     model: String,

     /// Device to run model inference.
-    #[clap(long, default_value_t=Device::Cpu, help_heading=Some("Model Options"))]
+    #[clap(long, default_value_t=Device::Http, help_heading=Some("Model Options"))]
     device: Device,

     /// Parallelism for model serving - increasing this number will have a significant impact on the