From 835d96a18e2055687c92d83249b706af7b0aa913 Mon Sep 17 00:00:00 2001 From: mikbry Date: Thu, 7 Mar 2024 12:50:16 +0100 Subject: [PATCH 1/9] feat: add tokenize function to ImplProvider (prev ProviderDefinition) --- webapp/types/index.ts | 31 +++++++++++++------------- webapp/utils/providers/index.ts | 8 +++++++ webapp/utils/providers/openai/index.ts | 4 ++-- webapp/utils/providers/opla/index.ts | 4 ++-- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/webapp/types/index.ts b/webapp/types/index.ts index 13f245ac..92f7aeaa 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -62,11 +62,11 @@ export enum ContentType { export type Content = | string | { - type: ContentType; - parts: string[]; - raw?: string[]; - metadata?: Metadata; - }; + type: ContentType; + parts: string[]; + raw?: string[]; + metadata?: Metadata; + }; export enum MessageStatus { Pending = 'pending', @@ -87,13 +87,13 @@ export type Asset = BaseIdRecord & { state?: AssetState; } & ( | { - type: 'link'; - url: string; - } + type: 'link'; + url: string; + } | { - type: 'file'; - file: string; - } + type: 'file'; + file: string; + } ); export type Message = BaseIdRecord & { @@ -163,17 +163,17 @@ export enum AIServiceType { export type AIService = { disabled?: boolean; } & ( - | { + | { type: AIServiceType.Model; modelId: string; providerType?: ProviderType; } - | { + | { type: AIServiceType.Assistant; assistantId: string; targetId?: string; } -); + ); export type AIImplService = AIService & { model: Model | undefined; @@ -240,7 +240,7 @@ export type CompletionParameterDefinition = ParameterDefinition; export type CompletionParametersDefinition = Record; -export type ProviderDefinition = { +export type ImplProvider = { name: string; type: ProviderType; description: string; @@ -257,6 +257,7 @@ export type ProviderDefinition = { parameters?: LlmParameters[], ) => Promise; }; + tokenize?: (text: string) => Promise; }; export type Provider = BaseNamedRecord & { diff --git a/webapp/utils/providers/index.ts b/webapp/utils/providers/index.ts index 7e43ae16..06f82392 100644 --- a/webapp/utils/providers/index.ts +++ b/webapp/utils/providers/index.ts @@ -113,3 +113,11 @@ export const models = async (provider: Provider): Promise => { } return []; }; + +export const tokenize = async (activeService: AIImplService, text: string): Promise => { + const { provider } = activeService; + if (provider?.type === ProviderType.openai) { + return OpenAI.tokenize(text); + } + return Opla.tokenize(text); +} \ No newline at end of file diff --git a/webapp/utils/providers/openai/index.ts b/webapp/utils/providers/openai/index.ts index 758a1dda..68d7ef9c 100644 --- a/webapp/utils/providers/openai/index.ts +++ b/webapp/utils/providers/openai/index.ts @@ -10,7 +10,7 @@ import { LlmResponse, Model, Provider, - ProviderDefinition, + ImplProvider, ProviderType, } from '@/types'; import { mapKeys } from '@/utils/data'; @@ -165,7 +165,7 @@ const completion = async ( throw new Error(`${NAME} completion completion error ${response}`); }; -const OpenAIProvider: ProviderDefinition = { +const OpenAIProvider: ImplProvider = { name: NAME, type: TYPE, description: DESCRIPTION, diff --git a/webapp/utils/providers/opla/index.ts b/webapp/utils/providers/opla/index.ts index 78151b97..2cfbed61 100644 --- a/webapp/utils/providers/opla/index.ts +++ b/webapp/utils/providers/opla/index.ts @@ -20,7 +20,7 @@ import { LlmResponse, Model, Provider, - ProviderDefinition, + ImplProvider, ProviderType, } from '@/types'; import { mapKeys } from '@/utils/data'; 
@@ -248,7 +248,7 @@ const completion = async ( throw new Error(`${NAME} completion error ${response}`); }; -const OplaProvider: ProviderDefinition = { +const OplaProvider: ImplProvider = { name: NAME, type: TYPE, description: DESCRIPTION, From bc6b904abfe545fcd8b09e4be087ea4efb97fd31 Mon Sep 17 00:00:00 2001 From: mikbry Date: Thu, 7 Mar 2024 12:51:46 +0100 Subject: [PATCH 2/9] chore: prettier --- webapp/types/index.ts | 28 ++++++++++++++-------------- webapp/utils/providers/index.ts | 8 ++++---- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/webapp/types/index.ts b/webapp/types/index.ts index 92f7aeaa..4314f1f3 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -62,11 +62,11 @@ export enum ContentType { export type Content = | string | { - type: ContentType; - parts: string[]; - raw?: string[]; - metadata?: Metadata; - }; + type: ContentType; + parts: string[]; + raw?: string[]; + metadata?: Metadata; + }; export enum MessageStatus { Pending = 'pending', @@ -87,13 +87,13 @@ export type Asset = BaseIdRecord & { state?: AssetState; } & ( | { - type: 'link'; - url: string; - } + type: 'link'; + url: string; + } | { - type: 'file'; - file: string; - } + type: 'file'; + file: string; + } ); export type Message = BaseIdRecord & { @@ -163,17 +163,17 @@ export enum AIServiceType { export type AIService = { disabled?: boolean; } & ( - | { + | { type: AIServiceType.Model; modelId: string; providerType?: ProviderType; } - | { + | { type: AIServiceType.Assistant; assistantId: string; targetId?: string; } - ); +); export type AIImplService = AIService & { model: Model | undefined; diff --git a/webapp/utils/providers/index.ts b/webapp/utils/providers/index.ts index 06f82392..fcbc0b3c 100644 --- a/webapp/utils/providers/index.ts +++ b/webapp/utils/providers/index.ts @@ -114,10 +114,10 @@ export const models = async (provider: Provider): Promise => { return []; }; -export const tokenize = async (activeService: AIImplService, text: string): Promise => { +export const tokenize = async (activeService: AIImplService, text: string): Promise => { const { provider } = activeService; if (provider?.type === ProviderType.openai) { - return OpenAI.tokenize(text); + return OpenAI.tokenize?.(text) ?? []; } - return Opla.tokenize(text); -} \ No newline at end of file + return Opla.tokenize?.(text) ?? 
[]; +}; From 4886ab88d9ee6c840c3bbcd1fc2cc493386e693e Mon Sep 17 00:00:00 2001 From: mikbry Date: Thu, 7 Mar 2024 15:57:53 +0100 Subject: [PATCH 3/9] feat: llama.cpp tokenize endpoint feat: implements ContextWindowPolicy, keepSystem in ts --- webapp/hooks/useBackendContext.tsx | 4 +- webapp/native/src/llm/llama_cpp.rs | 70 +++++++++++++++++++++++--- webapp/native/src/llm/mod.rs | 10 +++- webapp/native/src/llm/openai.rs | 20 ++++---- webapp/native/src/main.rs | 34 ++++++++++--- webapp/native/src/server.rs | 38 +++++++++----- webapp/types/index.ts | 36 +++++++------ webapp/utils/providers/index.ts | 54 ++++++++++++++------ webapp/utils/providers/openai/index.ts | 8 +-- webapp/utils/providers/opla/index.ts | 8 +-- 10 files changed, 202 insertions(+), 80 deletions(-) diff --git a/webapp/hooks/useBackendContext.tsx b/webapp/hooks/useBackendContext.tsx index d39ff429..09d23b1d 100644 --- a/webapp/hooks/useBackendContext.tsx +++ b/webapp/hooks/useBackendContext.tsx @@ -23,7 +23,7 @@ import { OplaContext, ServerStatus, Settings, - LlmResponse, + LlmCompletionResponse, LlmStreamResponse, Download, ServerParameters, @@ -202,7 +202,7 @@ function BackendProvider({ children }: { children: React.ReactNode }) { return; } logger.info('stream event', event, backendContext, context); - const response = (await mapKeys(event.payload, toCamelCase)) as LlmResponse; + const response = (await mapKeys(event.payload, toCamelCase)) as LlmCompletionResponse; if (!response.conversationId) { logger.error('stream event without conversationId', response); return; diff --git a/webapp/native/src/llm/llama_cpp.rs b/webapp/native/src/llm/llama_cpp.rs index e373b739..b5b05e8a 100644 --- a/webapp/native/src/llm/llama_cpp.rs +++ b/webapp/native/src/llm/llama_cpp.rs @@ -16,7 +16,9 @@ use crate::{ error::Error, llm::LlmQueryCompletion, store::ServerParameters }; use tauri::Runtime; use serde::{ Deserialize, Serialize }; -use crate::llm::{ LlmQuery, LlmResponse, LlmUsage }; +use crate::llm::{ LlmQuery, LlmCompletionResponse, LlmUsage }; + +use super::LlmTokenizeResponse; #[serde_with::skip_serializing_none] #[derive(Clone, Debug, Serialize, Deserialize)] @@ -127,8 +129,8 @@ pub struct LlamaCppChatCompletion { } impl LlamaCppChatCompletion { - pub fn to_llm_response(&self) -> LlmResponse { - LlmResponse { + pub fn to_llm_response(&self) -> LlmCompletionResponse { + LlmCompletionResponse { created: None, status: None, content: self.content.clone(), @@ -138,6 +140,23 @@ impl LlamaCppChatCompletion { } } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LlamaCppQueryTokenize { + pub content: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LlamaCppTokenize { + pub tokens: Vec, +} +impl LlamaCppTokenize { + pub fn to_llm_response(&self) -> LlmTokenizeResponse { + LlmTokenizeResponse { + tokens: self.tokens.clone(), + } + } +} #[derive(Clone, Debug, Serialize, Deserialize)] pub struct LLamaCppServer {} @@ -146,18 +165,22 @@ impl LLamaCppServer { LLamaCppServer {} } + fn get_api(&self, server_parameters: &ServerParameters, endpoint: String) -> String { + // TODO https support + format!("http://{:}:{:}/{}", server_parameters.host, server_parameters.port, endpoint) + } + pub async fn call_completion( &mut self, query: LlmQuery, server_parameters: &ServerParameters - ) -> Result> { + ) -> Result> { let parameters = query.options.to_llama_cpp_parameters(); - // TODO https support - let api = format!("http://{:}:{:}", server_parameters.host, server_parameters.port); + let api_url = 
self.get_api(server_parameters, query.command); let client = reqwest::Client::new(); let res = client - .post(format!("{}/{}", api, query.command)) // TODO remove hardcoding + .post(api_url) // TODO remove hardcoding .json(¶meters) .send().await; let response = match res { @@ -178,4 +201,37 @@ impl LLamaCppServer { }; Ok(response.to_llm_response()) } + + pub async fn call_tokenize( + &mut self, + text: String, + server_parameters: &ServerParameters + ) -> Result> { + let parameters = LlamaCppQueryTokenize { + content: text, + }; + let api_url = self.get_api(server_parameters, "tokenize".to_owned()); + let client = reqwest::Client::new(); + let res = client + .post(api_url) // TODO remove hardcoding + .json(¶meters) + .send().await; + let response = match res { + Ok(res) => res, + Err(error) => { + println!("Failed to get Response: {}", error); + return Err(Box::new(Error::BadResponse)); + } + }; + let status = response.status(); + println!("Response Status: {}", status); + let response = match response.json::().await { + Ok(r) => r, + Err(error) => { + println!("Failed to parse response: {}", error); + return Err(Box::new(Error::BadResponse)); + } + }; + Ok(response.to_llm_response()) + } } diff --git a/webapp/native/src/llm/mod.rs b/webapp/native/src/llm/mod.rs index b397db3e..ff9c7418 100644 --- a/webapp/native/src/llm/mod.rs +++ b/webapp/native/src/llm/mod.rs @@ -62,6 +62,7 @@ pub struct LlmParameter { pub struct LlmQueryCompletion { pub conversation_id: Option, pub messages: Vec, + pub prompt: Option, pub parameters: Option>, } @@ -161,14 +162,14 @@ impl LlmUsage { #[serde_with::skip_serializing_none] #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct LlmResponse { +pub struct LlmCompletionResponse { pub created: Option, pub status: Option, pub content: String, pub conversation_id: Option, pub usage: Option, } -impl LlmResponse { +impl LlmCompletionResponse { pub fn new(created: i64, status: &str, content: &str) -> Self { Self { created: Some(created), @@ -184,3 +185,8 @@ impl LlmResponse { pub struct LlmResponseError { pub error: LlmError, } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LlmTokenizeResponse { + pub tokens: Vec, +} \ No newline at end of file diff --git a/webapp/native/src/llm/openai.rs b/webapp/native/src/llm/openai.rs index 000a74a7..9aae3d11 100644 --- a/webapp/native/src/llm/openai.rs +++ b/webapp/native/src/llm/openai.rs @@ -21,7 +21,7 @@ use futures_util::stream::StreamExt; use crate::llm::{ LlmQuery, LlmQueryCompletion, - LlmResponse, + LlmCompletionResponse, LlmResponseError, LlmUsage, LlmError, @@ -160,8 +160,8 @@ impl OpenAIChatCompletion { } } - fn to_llm_response(&self) -> LlmResponse { - let mut response = LlmResponse::new( + fn to_llm_response(&self) -> LlmCompletionResponse { + let mut response = LlmCompletionResponse::new( self.created, "success", &self.choices[0].message.content @@ -186,7 +186,7 @@ async fn request( url: String, secret_key: &str, parameters: OpenAIBodyCompletion -) -> Result> { +) -> Result> { let client = reqwest::Client::new(); let result = client.post(url).bearer_auth(&secret_key).json(¶meters).send().await; let response = match result { @@ -223,8 +223,8 @@ async fn stream_request( url: String, secret_key: &str, parameters: OpenAIBodyCompletion, - callback: Option) + Copy> -) -> Result> { + callback: Option) + Copy> +) -> Result> { let client = reqwest::Client::new(); let result = client.post(url).bearer_auth(&secret_key).json(¶meters).send().await; let response = match result { @@ -278,7 +278,7 @@ async fn 
stream_request( Some(mut cb) => { cb( Ok( - LlmResponse::new( + LlmCompletionResponse::new( chrono::Utc::now().timestamp_millis(), "finished", "done" @@ -303,7 +303,7 @@ async fn stream_request( Some(mut cb) => { cb( Ok( - LlmResponse::new( + LlmCompletionResponse::new( chunk.created, "success", chunk.choices[0].delta.content @@ -340,8 +340,8 @@ pub async fn call_completion( secret_key: &str, model: &str, query: LlmQuery, - callback: Option) + Copy> -) -> Result> { + callback: Option) + Copy> +) -> Result> { let start_time = chrono::Utc::now().timestamp_millis(); let url = format!("{}/chat/{}s", api, query.command); println!( diff --git a/webapp/native/src/main.rs b/webapp/native/src/main.rs index f03a48fa..32b62bd0 100644 --- a/webapp/native/src/main.rs +++ b/webapp/native/src/main.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use api::{ hf::search_hf_models, models }; use data::model::Model; use downloader::Downloader; -use llm::{ LlmQuery, LlmResponse, LlmQueryCompletion, openai::call_completion, LlmError }; +use llm::{ openai::call_completion, LlmCompletionResponse, LlmError, LlmQuery, LlmQueryCompletion, LlmTokenizeResponse }; use models::{ fetch_models_collection, ModelsCollection }; use serde::Serialize; use store::{ Store, Provider, ProviderType, ProviderMetadata, Settings, ServerConfiguration }; @@ -419,7 +419,7 @@ async fn llm_call_completion( model: String, llm_provider: Option, query: LlmQuery -) -> Result { +) -> Result { let (llm_provider, llm_provider_type) = match llm_provider { Some(p) => { (p.clone(), p.r#type) } None => { @@ -484,7 +484,7 @@ async fn llm_call_completion( &secret_key, &model, query, - Some(|result: Result| { + Some(|result: Result| { match result { Ok(response) => { let mut response = response.clone(); @@ -510,6 +510,24 @@ async fn llm_call_completion( return Err(format!("LLM provider not found: {:?}", llm_provider_type)); } +#[tauri::command] +async fn llm_call_tokenize( + _app: tauri::AppHandle, + _window: tauri::Window, + context: State<'_, OplaContext>, + model: String, + provider: String, + text: String +) -> Result { + if provider == "opla" { + let context_server = Arc::clone(&context.server); + let mut server = context_server.lock().await; + let response = server.call_tokenize::(&model, text).await.map_err(|err| err.to_string())?; + return Ok(response); + + } + return Err(format!("LLM provider not found: {:?}", provider)); +} async fn start_server( app: tauri::AppHandle, context: State<'_, OplaContext> @@ -556,7 +574,10 @@ async fn model_download( drop(store); println!("model_download {} {}", state, model_id); let server = context.server.lock().await; - if state == "ok" && (m.reference.is_some_id_or_name(&server.model) || server.model.is_none()) { + if + state == "ok" && + (m.reference.is_some_id_or_name(&server.model) || server.model.is_none()) + { drop(server); let res = start_server(handle, context).await; match res { @@ -634,7 +655,7 @@ fn handle_download_event(app: &tauri::AppHandle, payload: &str let vec: Vec<&str> = payload.split(':').collect(); let (state, id) = (vec[0].to_string(), vec[1].to_string()); - let handler = app.app_handle(); + let handler = app.app_handle(); tauri::async_runtime::spawn(async move { let handler = handler.app_handle(); match model_download(handler, id.to_string(), state.to_string()).await { @@ -800,7 +821,8 @@ fn main() { uninstall_model, update_model, set_active_model, - llm_call_completion + llm_call_completion, + llm_call_tokenize, ] ) .run(tauri::generate_context!()) diff --git a/webapp/native/src/server.rs 
b/webapp/native/src/server.rs index 3b27cfd2..49f2dcba 100644 --- a/webapp/native/src/server.rs +++ b/webapp/native/src/server.rs @@ -16,14 +16,13 @@ use tokio::sync::Mutex; use std::sync::Arc; use crate::{ error::Error, - llm::LlmQueryCompletion, - llm::llama_cpp::LLamaCppServer, - store::{ ServerParameters, ServerConfiguration }, + llm::{llama_cpp::LLamaCppServer, LlmQueryCompletion, LlmTokenizeResponse}, + store::{ ServerConfiguration, ServerParameters }, }; use sysinfo::System; use tauri::async_runtime::JoinHandle; use tauri::{ api::process::{ Command, CommandEvent }, Runtime, Manager }; -use crate::llm::{ LlmQuery, LlmResponse }; +use crate::llm::{ LlmQuery, LlmCompletionResponse }; use std::time::Duration; use std::thread; @@ -323,9 +322,7 @@ impl OplaServer { }; let arguments = parameters.to_args(&model_path); println!("Opla server arguments: {}", arguments.join(" ")); - if - status == ServerStatus::Starting.as_str().to_string() - { + if status == ServerStatus::Starting.as_str().to_string() { println!("Opla server is starting: stop it"); match self.stop(&app).await { Ok(_) => {} @@ -333,10 +330,7 @@ impl OplaServer { return Err(e); } } - - } else if - status == ServerStatus::Started.as_str().to_string() - { + } else if status == ServerStatus::Started.as_str().to_string() { println!("Opla server already started "); return Ok(Payload { status: status.to_string(), @@ -406,7 +400,7 @@ impl OplaServer { }; process.kill(); println!("Kill Opla server {}", pid); */ - match (self.handle).take() { + match self.handle.take() { Some(handle) => { handle.abort(); } @@ -518,7 +512,7 @@ impl OplaServer { &mut self, model: &str, query: LlmQuery - ) -> Result> { + ) -> Result> { println!("{}", format!("Opla llm call: {:?} / {:?}", query.command, &model)); let server_parameters = match &self.parameters { @@ -531,4 +525,22 @@ impl OplaServer { self.server.call_completion::(query, server_parameters).await } + + pub async fn call_tokenize( + &mut self, + model: &str, + text: String, + ) -> Result> { + println!("{}", format!("Opla llm call tokenize: {:?}", &model)); + + let server_parameters = match &self.parameters { + Some(p) => p, + None => { + println!("Opla server error try to read parameters"); + return Err(Box::new(Error::BadParameters)); + } + }; + + self.server.call_tokenize::(text, server_parameters).await + } } diff --git a/webapp/types/index.ts b/webapp/types/index.ts index 4314f1f3..ce75b6fe 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -62,11 +62,11 @@ export enum ContentType { export type Content = | string | { - type: ContentType; - parts: string[]; - raw?: string[]; - metadata?: Metadata; - }; + type: ContentType; + parts: string[]; + raw?: string[]; + metadata?: Metadata; + }; export enum MessageStatus { Pending = 'pending', @@ -87,13 +87,13 @@ export type Asset = BaseIdRecord & { state?: AssetState; } & ( | { - type: 'link'; - url: string; - } + type: 'link'; + url: string; + } | { - type: 'file'; - file: string; - } + type: 'file'; + file: string; + } ); export type Message = BaseIdRecord & { @@ -163,17 +163,17 @@ export enum AIServiceType { export type AIService = { disabled?: boolean; } & ( - | { + | { type: AIServiceType.Model; modelId: string; providerType?: ProviderType; } - | { + | { type: AIServiceType.Assistant; assistantId: string; targetId?: string; } -); + ); export type AIImplService = AIService & { model: Model | undefined; @@ -255,7 +255,7 @@ export type ImplProvider = { system?: string, conversationId?: string, parameters?: LlmParameters[], - ) => 
Promise; + ) => Promise; }; tokenize?: (text: string) => Promise; }; @@ -485,7 +485,7 @@ export type LlmUsage = { totalPerSecond?: number; }; -export type LlmResponse = { +export type LlmCompletionResponse = { created?: number; status?: string; content: string; @@ -505,6 +505,10 @@ export type Cpu = { usage: number; }; +export type LlmTokenizeResponse = { + tokens: number[]; +}; + export type Sys = { name: string; kernelVersion: string; diff --git a/webapp/utils/providers/index.ts b/webapp/utils/providers/index.ts index fcbc0b3c..44067318 100644 --- a/webapp/utils/providers/index.ts +++ b/webapp/utils/providers/index.ts @@ -7,12 +7,14 @@ import { Conversation, LlmMessage, LlmParameters, - LlmResponse, + LlmCompletionResponse, Message, Model, Preset, Provider, ProviderType, + LlmTokenizeResponse, + ContextWindowPolicy, } from '@/types'; import OpenAI from './openai'; import Opla from './opla'; @@ -20,18 +22,45 @@ import { findCompatiblePreset, getCompletePresetProperties } from '../data/prese import { getMessageContentAsString } from '../data/messages'; import { ParsedPrompt } from '../parsers'; import { CommandManager } from '../commands/types'; +import { invokeTauri } from '../backend/tauri'; + + +export const tokenize = async (activeService: AIImplService, text: string): Promise => { + const { provider, model } = activeService; + let response: LlmTokenizeResponse; + if (model && provider){ + response = await invokeTauri('llm_call_tokenize', { + model: model.name, + provider: provider.name, + text, + }); + } else { + throw new Error('Model or provider not found'); + } + return response; +}; -// TODO: code it in Rust -// and use ContextWindowPolicy from webapp/utils/constants.ts export const buildContext = ( conversation: Conversation, messages: Message[], index: number, + policy: ContextWindowPolicy, + keepSystemMessages: boolean, ): LlmMessage[] => { const context: Message[] = []; - // Only ContextWindowPolicy.Last is implemented - if (index > 0) { - context.push(messages[index - 1]); + + if (policy === ContextWindowPolicy.Last) { + const message = messages.findLast((m) => m.author?.role === 'system'); + if (message) { + context.push(message); + } + } else { + // For other policies, we include all messages, and handle system messages accordingly + messages.forEach((message) => { + if (message.author?.role !== 'system' || keepSystemMessages) { + context.push(message); + } + }); } const llmMessages: LlmMessage[] = context.map((m) => ({ @@ -59,14 +88,14 @@ export const completion = async ( presets: Preset[], prompt: ParsedPrompt, commandManager: CommandManager, -): Promise => { +): Promise => { if (!activeService.model) { throw new Error('Model not set'); } const { model, provider } = activeService; const llmParameters: LlmParameters[] = []; const preset = findCompatiblePreset(conversation?.preset, presets, model?.name, provider); - const { parameters: presetParameters, system } = getCompletePresetProperties( + const { parameters: presetParameters, system, contextWindowPolicy = ContextWindowPolicy.None, keepSystem = true } = getCompletePresetProperties( preset, conversation, presets, @@ -91,7 +120,7 @@ export const completion = async ( }); } const index = conversationMessages.findIndex((m) => m.id === message.id); - const messages = buildContext(conversation, conversationMessages, index); + const messages = buildContext(conversation, conversationMessages, index, contextWindowPolicy, keepSystem); if (provider?.type === ProviderType.openai) { return OpenAI.completion.invoke( @@ -114,10 +143,3 
@@ export const models = async (provider: Provider): Promise => { return []; }; -export const tokenize = async (activeService: AIImplService, text: string): Promise => { - const { provider } = activeService; - if (provider?.type === ProviderType.openai) { - return OpenAI.tokenize?.(text) ?? []; - } - return Opla.tokenize?.(text) ?? []; -}; diff --git a/webapp/utils/providers/openai/index.ts b/webapp/utils/providers/openai/index.ts index 68d7ef9c..7aa7d84e 100644 --- a/webapp/utils/providers/openai/index.ts +++ b/webapp/utils/providers/openai/index.ts @@ -7,7 +7,7 @@ import { LlmMessage, LlmParameters, LlmQueryCompletion, - LlmResponse, + LlmCompletionResponse, Model, Provider, ImplProvider, @@ -130,7 +130,7 @@ const completion = async ( system = DEFAULT_SYSTEM, conversationId?: string, parameters: LlmParameters[] = [], -): Promise => { +): Promise => { if (!model) { throw new Error('Model not found'); } @@ -151,11 +151,11 @@ const completion = async ( }, toSnakeCase, ); - const response: LlmResponse = (await invokeTauri('llm_call_completion', { + const response: LlmCompletionResponse = (await invokeTauri('llm_call_completion', { model: model.id, llmProvider: provider, query: { command: 'completion', options }, - })) as LlmResponse; + })) as LlmCompletionResponse; const { content } = response; if (content) { diff --git a/webapp/utils/providers/opla/index.ts b/webapp/utils/providers/opla/index.ts index 2cfbed61..ed0f4e17 100644 --- a/webapp/utils/providers/opla/index.ts +++ b/webapp/utils/providers/opla/index.ts @@ -17,7 +17,7 @@ import { LlmMessage, LlmParameters, LlmQueryCompletion, - LlmResponse, + LlmCompletionResponse, Model, Provider, ImplProvider, @@ -217,7 +217,7 @@ const completion = async ( system = DEFAULT_SYSTEM, conversationId?: string, parameters: LlmParameters[] = DEFAULT_PROPERTIES, -): Promise => { +): Promise => { if (!model) { throw new Error('Model not found'); } @@ -234,11 +234,11 @@ const completion = async ( }; const llmProvider = mapKeys(provider, toSnakeCase); - const response: LlmResponse = (await invokeTauri('llm_call_completion', { + const response: LlmCompletionResponse = (await invokeTauri('llm_call_completion', { model: model.name, llmProvider, query: { command: 'completion', options }, - })) as LlmResponse; + })) as LlmCompletionResponse; const { content } = response; if (content) { From 1b8aea6be7bd884ea1c9c5e32c86e9a982bbb233 Mon Sep 17 00:00:00 2001 From: mikbry Date: Thu, 7 Mar 2024 15:58:32 +0100 Subject: [PATCH 4/9] chore: prettier --- webapp/types/index.ts | 28 ++++++++++++++-------------- webapp/utils/providers/index.ts | 28 ++++++++++++++++++---------- 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/webapp/types/index.ts b/webapp/types/index.ts index ce75b6fe..585046a2 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -62,11 +62,11 @@ export enum ContentType { export type Content = | string | { - type: ContentType; - parts: string[]; - raw?: string[]; - metadata?: Metadata; - }; + type: ContentType; + parts: string[]; + raw?: string[]; + metadata?: Metadata; + }; export enum MessageStatus { Pending = 'pending', @@ -87,13 +87,13 @@ export type Asset = BaseIdRecord & { state?: AssetState; } & ( | { - type: 'link'; - url: string; - } + type: 'link'; + url: string; + } | { - type: 'file'; - file: string; - } + type: 'file'; + file: string; + } ); export type Message = BaseIdRecord & { @@ -163,17 +163,17 @@ export enum AIServiceType { export type AIService = { disabled?: boolean; } & ( - | { + | { type: AIServiceType.Model; 
modelId: string; providerType?: ProviderType; } - | { + | { type: AIServiceType.Assistant; assistantId: string; targetId?: string; } - ); +); export type AIImplService = AIService & { model: Model | undefined; diff --git a/webapp/utils/providers/index.ts b/webapp/utils/providers/index.ts index 44067318..db4fe1c7 100644 --- a/webapp/utils/providers/index.ts +++ b/webapp/utils/providers/index.ts @@ -24,11 +24,13 @@ import { ParsedPrompt } from '../parsers'; import { CommandManager } from '../commands/types'; import { invokeTauri } from '../backend/tauri'; - -export const tokenize = async (activeService: AIImplService, text: string): Promise => { +export const tokenize = async ( + activeService: AIImplService, + text: string, +): Promise => { const { provider, model } = activeService; let response: LlmTokenizeResponse; - if (model && provider){ + if (model && provider) { response = await invokeTauri('llm_call_tokenize', { model: model.name, provider: provider.name, @@ -95,11 +97,12 @@ export const completion = async ( const { model, provider } = activeService; const llmParameters: LlmParameters[] = []; const preset = findCompatiblePreset(conversation?.preset, presets, model?.name, provider); - const { parameters: presetParameters, system, contextWindowPolicy = ContextWindowPolicy.None, keepSystem = true } = getCompletePresetProperties( - preset, - conversation, - presets, - ); + const { + parameters: presetParameters, + system, + contextWindowPolicy = ContextWindowPolicy.None, + keepSystem = true, + } = getCompletePresetProperties(preset, conversation, presets); const commandParameters = commandManager.findCommandParameters(prompt); const parameters = { ...presetParameters, ...commandParameters }; let { key } = provider || {}; @@ -120,7 +123,13 @@ export const completion = async ( }); } const index = conversationMessages.findIndex((m) => m.id === message.id); - const messages = buildContext(conversation, conversationMessages, index, contextWindowPolicy, keepSystem); + const messages = buildContext( + conversation, + conversationMessages, + index, + contextWindowPolicy, + keepSystem, + ); if (provider?.type === ProviderType.openai) { return OpenAI.completion.invoke( @@ -142,4 +151,3 @@ export const models = async (provider: Provider): Promise => { } return []; }; - From 4f1a51cb707a5c2da0a278b8d3ec5b34171f5d98 Mon Sep 17 00:00:00 2001 From: mikbry Date: Thu, 7 Mar 2024 20:11:42 +0100 Subject: [PATCH 5/9] feat: pass completion options and refactor implProviders --- webapp/native/src/llm/llama_cpp.rs | 29 +++++++--- webapp/native/src/llm/mod.rs | 8 +++ webapp/native/src/llm/openai.rs | 36 +++++++++++-- webapp/native/src/main.rs | 8 +-- webapp/native/src/server.rs | 7 +-- webapp/types/index.ts | 14 +++-- webapp/utils/providers/index.ts | 73 +++++++++++++++++++------- webapp/utils/providers/openai/index.ts | 53 +------------------ webapp/utils/providers/opla/index.ts | 58 ++------------------ 9 files changed, 137 insertions(+), 149 deletions(-) diff --git a/webapp/native/src/llm/llama_cpp.rs b/webapp/native/src/llm/llama_cpp.rs index b5b05e8a..a7f8ee60 100644 --- a/webapp/native/src/llm/llama_cpp.rs +++ b/webapp/native/src/llm/llama_cpp.rs @@ -18,7 +18,7 @@ use tauri::Runtime; use serde::{ Deserialize, Serialize }; use crate::llm::{ LlmQuery, LlmCompletionResponse, LlmUsage }; -use super::LlmTokenizeResponse; +use super::{ LlmCompletionOptions, LlmTokenizeResponse }; #[serde_with::skip_serializing_none] #[derive(Clone, Debug, Serialize, Deserialize)] @@ -49,8 +49,23 @@ pub struct 
LlamaCppQueryCompletion { } impl LlmQueryCompletion { - fn to_llama_cpp_parameters(&self) -> LlamaCppQueryCompletion { + fn to_llama_cpp_parameters( + &self, + options: Option + ) -> LlamaCppQueryCompletion { let mut prompt = String::new(); + match options { + Some(options) => { + match options.system { + Some(system) => { + prompt.push_str(&format!("{}\n", system)); + } + None => {} + } + } + None => {} + } + // TODO: handle context_window_policy and keep_system for message in &self.messages { match message.role.as_str() { "user" => { @@ -140,7 +155,6 @@ impl LlamaCppChatCompletion { } } - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct LlamaCppQueryTokenize { pub content: String, @@ -166,16 +180,17 @@ impl LLamaCppServer { } fn get_api(&self, server_parameters: &ServerParameters, endpoint: String) -> String { - // TODO https support + // TODO https support format!("http://{:}:{:}/{}", server_parameters.host, server_parameters.port, endpoint) } pub async fn call_completion( &mut self, query: LlmQuery, - server_parameters: &ServerParameters + server_parameters: &ServerParameters, + completion_options: Option ) -> Result> { - let parameters = query.options.to_llama_cpp_parameters(); + let parameters = query.options.to_llama_cpp_parameters(completion_options); let api_url = self.get_api(server_parameters, query.command); let client = reqwest::Client::new(); @@ -202,7 +217,7 @@ impl LLamaCppServer { Ok(response.to_llm_response()) } - pub async fn call_tokenize( + pub async fn call_tokenize( &mut self, text: String, server_parameters: &ServerParameters diff --git a/webapp/native/src/llm/mod.rs b/webapp/native/src/llm/mod.rs index ff9c7418..f5752ea7 100644 --- a/webapp/native/src/llm/mod.rs +++ b/webapp/native/src/llm/mod.rs @@ -66,6 +66,14 @@ pub struct LlmQueryCompletion { pub parameters: Option>, } +#[serde_with::skip_serializing_none] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct LlmCompletionOptions { + pub context_window_policy: Option, + pub keep_system: Option, + pub system: Option, +} + impl LlmQueryCompletion { pub fn get_parameter_value(&self, key: &str) -> Option { let parameters = match &self.parameters { diff --git a/webapp/native/src/llm/openai.rs b/webapp/native/src/llm/openai.rs index 9aae3d11..60abe59b 100644 --- a/webapp/native/src/llm/openai.rs +++ b/webapp/native/src/llm/openai.rs @@ -28,6 +28,8 @@ use crate::llm::{ LlmMessage, }; +use super::LlmCompletionOptions; + #[serde_with::skip_serializing_none] #[derive(Clone, Debug, Serialize, Deserialize)] pub struct OpenAIBodyCompletion { @@ -44,10 +46,33 @@ pub struct OpenAIBodyCompletion { } impl OpenAIBodyCompletion { - pub fn new(model: String, from: &LlmQueryCompletion) -> Self { + pub fn new( + model: String, + from: &LlmQueryCompletion, + options: Option + ) -> Self { + let mut messages: Vec = vec![]; + + match options { + Some(options) => { + match options.system { + Some(system) => { + messages.push(LlmMessage { + content: system.clone(), + role: "system".to_owned(), + name: None, + }); + } + None => {} + } + } + None => {} + } + messages.extend(from.messages.clone()); + // TODO: handle context_window_policy and keep_system Self { model, - messages: from.messages.clone(), + messages, stream: from.get_parameter_as_boolean("stream"), temperature: from.get_parameter_as_f32("temperature"), stop: from.get_parameter_array("stop"), @@ -340,6 +365,7 @@ pub async fn call_completion( secret_key: &str, model: &str, query: LlmQuery, + completion_options: Option, callback: Option) + Copy> ) -> Result> { let 
start_time = chrono::Utc::now().timestamp_millis(); @@ -349,7 +375,11 @@ pub async fn call_completion( format!("llm call: {:?} / {:?} / {:?} / {:?}", query.command, url, &model, query.options) ); - let parameters = OpenAIBodyCompletion::new(model.to_owned(), &query.options); + let parameters = OpenAIBodyCompletion::new( + model.to_owned(), + &query.options, + completion_options + ); println!("llm call parameters: {:?}", parameters); let stream = match parameters.stream { Some(t) => t, diff --git a/webapp/native/src/main.rs b/webapp/native/src/main.rs index 32b62bd0..8b5ca5e6 100644 --- a/webapp/native/src/main.rs +++ b/webapp/native/src/main.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use api::{ hf::search_hf_models, models }; use data::model::Model; use downloader::Downloader; -use llm::{ openai::call_completion, LlmCompletionResponse, LlmError, LlmQuery, LlmQueryCompletion, LlmTokenizeResponse }; +use llm::{ openai::call_completion, LlmCompletionOptions, LlmCompletionResponse, LlmError, LlmQuery, LlmQueryCompletion, LlmTokenizeResponse }; use models::{ fetch_models_collection, ModelsCollection }; use serde::Serialize; use store::{ Store, Provider, ProviderType, ProviderMetadata, Settings, ServerConfiguration }; @@ -418,7 +418,8 @@ async fn llm_call_completion( context: State<'_, OplaContext>, model: String, llm_provider: Option, - query: LlmQuery + query: LlmQuery, + completion_options: Option ) -> Result { let (llm_provider, llm_provider_type) = match llm_provider { Some(p) => { (p.clone(), p.r#type) } @@ -454,7 +455,7 @@ async fn llm_call_completion( let mut server = context_server.lock().await; server.bind::(app, &model_name, &model_path).await.map_err(|err| err.to_string())?; let response = { - server.call_completion::(&model_name, query).await.map_err(|err| err.to_string())? + server.call_completion::(&model_name, query, completion_options).await.map_err(|err| err.to_string())? 
}; let parameters = server.parameters.clone(); server.set_parameters(&model, &model_path, parameters); @@ -484,6 +485,7 @@ async fn llm_call_completion( &secret_key, &model, query, + completion_options, Some(|result: Result| { match result { Ok(response) => { diff --git a/webapp/native/src/server.rs b/webapp/native/src/server.rs index 49f2dcba..25d13269 100644 --- a/webapp/native/src/server.rs +++ b/webapp/native/src/server.rs @@ -16,7 +16,7 @@ use tokio::sync::Mutex; use std::sync::Arc; use crate::{ error::Error, - llm::{llama_cpp::LLamaCppServer, LlmQueryCompletion, LlmTokenizeResponse}, + llm::{llama_cpp::LLamaCppServer, LlmCompletionOptions, LlmQueryCompletion, LlmTokenizeResponse}, store::{ ServerConfiguration, ServerParameters }, }; use sysinfo::System; @@ -511,7 +511,8 @@ impl OplaServer { pub async fn call_completion( &mut self, model: &str, - query: LlmQuery + query: LlmQuery, + completion_options: Option ) -> Result> { println!("{}", format!("Opla llm call: {:?} / {:?}", query.command, &model)); @@ -523,7 +524,7 @@ impl OplaServer { } }; - self.server.call_completion::(query, server_parameters).await + self.server.call_completion::(query, server_parameters, completion_options).await } pub async fn call_tokenize( diff --git a/webapp/types/index.ts b/webapp/types/index.ts index 585046a2..b9723375 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -240,22 +240,20 @@ export type CompletionParameterDefinition = ParameterDefinition; export type CompletionParametersDefinition = Record; +export type CompletionOptions = { + contextWindowPolicy?: ContextWindowPolicy; + keepSystem?: boolean; + system?: string; +}; export type ImplProvider = { name: string; type: ProviderType; description: string; system: string; + defaultParameters: LlmParameters[]; template: Partial; completion: { parameters: CompletionParametersDefinition; - invoke: ( - model: Model | undefined, - provider: Provider | undefined, - messages: LlmMessage[], - system?: string, - conversationId?: string, - parameters?: LlmParameters[], - ) => Promise; }; tokenize?: (text: string) => Promise; }; diff --git a/webapp/utils/providers/index.ts b/webapp/utils/providers/index.ts index db4fe1c7..77e8611e 100644 --- a/webapp/utils/providers/index.ts +++ b/webapp/utils/providers/index.ts @@ -15,6 +15,8 @@ import { ProviderType, LlmTokenizeResponse, ContextWindowPolicy, + ImplProvider, + LlmQueryCompletion, } from '@/types'; import OpenAI from './openai'; import Opla from './opla'; @@ -23,6 +25,9 @@ import { getMessageContentAsString } from '../data/messages'; import { ParsedPrompt } from '../parsers'; import { CommandManager } from '../commands/types'; import { invokeTauri } from '../backend/tauri'; +import { mapKeys } from '../data'; +import { toCamelCase, toSnakeCase } from '../string'; +import logger from '../logger'; export const tokenize = async ( activeService: AIImplService, @@ -42,7 +47,7 @@ export const tokenize = async ( return response; }; -export const buildContext = ( +export const createLlmMessages = ( conversation: Conversation, messages: Message[], index: number, @@ -67,7 +72,7 @@ export const buildContext = ( const llmMessages: LlmMessage[] = context.map((m) => ({ content: getMessageContentAsString(m), - role: m.author?.role === 'user' ? 
'user' : 'assistant', + role: m.author?.role, name: m.author?.name, })); return llmMessages; @@ -97,12 +102,20 @@ export const completion = async ( const { model, provider } = activeService; const llmParameters: LlmParameters[] = []; const preset = findCompatiblePreset(conversation?.preset, presets, model?.name, provider); - const { - parameters: presetParameters, - system, - contextWindowPolicy = ContextWindowPolicy.None, - keepSystem = true, - } = getCompletePresetProperties(preset, conversation, presets); + const { parameters: presetParameters, ...completionOptions } = getCompletePresetProperties( + preset, + conversation, + presets, + ); + + let implProvider: ImplProvider; + if (provider?.type === ProviderType.opla) { + implProvider = Opla; + } else { + implProvider = OpenAI; + } + + const { contextWindowPolicy = ContextWindowPolicy.None, keepSystem = true } = completionOptions; const commandParameters = commandManager.findCommandParameters(prompt); const parameters = { ...presetParameters, ...commandParameters }; let { key } = provider || {}; @@ -122,8 +135,18 @@ export const completion = async ( } }); } + if ( + implProvider.completion.parameters.stream.defaultValue && + !llmParameters.find((p) => p.key === 'stream') + ) { + llmParameters.push({ + key: 'stream', + value: String(implProvider.completion.parameters.stream.defaultValue), + }); + } + const index = conversationMessages.findIndex((m) => m.id === message.id); - const messages = buildContext( + const messages = createLlmMessages( conversation, conversationMessages, index, @@ -131,17 +154,29 @@ export const completion = async ( keepSystem, ); - if (provider?.type === ProviderType.openai) { - return OpenAI.completion.invoke( - model, - { ...provider, key }, - messages, - system, - conversation.id, - llmParameters, - ); + const options: LlmQueryCompletion = mapKeys( + { + messages, // : [systemMessage, ...messages], + conversationId: conversation.id, + parameters: llmParameters, + }, + toSnakeCase, + ); + + const llmProvider = mapKeys({ ...provider, key }, toSnakeCase); + const response: LlmCompletionResponse = (await invokeTauri('llm_call_completion', { + model: model.name, + llmProvider, + query: { command: 'completion', options }, + completionOptions: mapKeys(completionOptions, toSnakeCase), + })) as LlmCompletionResponse; + + const { content } = response; + if (content) { + logger.info(`${implProvider.name} completion response`, response); + return mapKeys(response, toCamelCase); } - return Opla.completion.invoke(model, provider, messages, system, conversation.id, llmParameters); + throw new Error(`${implProvider.name} completion error ${response}`); }; export const models = async (provider: Provider): Promise => { diff --git a/webapp/utils/providers/openai/index.ts b/webapp/utils/providers/openai/index.ts index 7aa7d84e..fab3eefe 100644 --- a/webapp/utils/providers/openai/index.ts +++ b/webapp/utils/providers/openai/index.ts @@ -4,19 +4,11 @@ import { z } from 'zod'; import { CompletionParametersDefinition, - LlmMessage, LlmParameters, - LlmQueryCompletion, - LlmCompletionResponse, - Model, Provider, ImplProvider, ProviderType, } from '@/types'; -import { mapKeys } from '@/utils/data'; -import logger from '@/utils/logger'; -import { toCamelCase, toSnakeCase } from '@/utils/string'; -import { invokeTauri } from '@/utils/backend/tauri'; const NAME = 'OpenAI'; const TYPE = ProviderType.openai; @@ -24,6 +16,7 @@ const DESCRIPTION = 'OpenAI API'; const DEFAULT_SYSTEM = ` You are an expert in retrieving information. 
`; +const DEFAULT_PARAMETERS: LlmParameters[] = []; const openAIProviderTemplate: Partial = { name: NAME, @@ -123,57 +116,15 @@ const CompletionParameters: CompletionParametersDefinition = { }, }; -const completion = async ( - model: Model | undefined, - provider: Provider | undefined, - messages: LlmMessage[], - system = DEFAULT_SYSTEM, - conversationId?: string, - parameters: LlmParameters[] = [], -): Promise => { - if (!model) { - throw new Error('Model not found'); - } - - const systemMessage: LlmMessage = { - role: 'system', - content: system, - }; - - if (!parameters.find((p) => p.key === 'stream')) { - parameters.push({ key: 'stream', value: String(CompletionParameters.stream.defaultValue) }); - } - const options: LlmQueryCompletion = mapKeys( - { - messages: [systemMessage, ...messages], - conversationId, - parameters, - }, - toSnakeCase, - ); - const response: LlmCompletionResponse = (await invokeTauri('llm_call_completion', { - model: model.id, - llmProvider: provider, - query: { command: 'completion', options }, - })) as LlmCompletionResponse; - - const { content } = response; - if (content) { - logger.info(`${NAME} completion response`, content); - return mapKeys(response, toCamelCase); - } - throw new Error(`${NAME} completion completion error ${response}`); -}; - const OpenAIProvider: ImplProvider = { name: NAME, type: TYPE, description: DESCRIPTION, system: DEFAULT_SYSTEM, + defaultParameters: DEFAULT_PARAMETERS, template: openAIProviderTemplate, completion: { parameters: CompletionParameters, - invoke: completion, }, }; diff --git a/webapp/utils/providers/opla/index.ts b/webapp/utils/providers/opla/index.ts index ed0f4e17..0c4bcaaf 100644 --- a/webapp/utils/providers/opla/index.ts +++ b/webapp/utils/providers/opla/index.ts @@ -12,21 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-import { - CompletionParametersDefinition, - LlmMessage, - LlmParameters, - LlmQueryCompletion, - LlmCompletionResponse, - Model, - Provider, - ImplProvider, - ProviderType, -} from '@/types'; -import { mapKeys } from '@/utils/data'; -import logger from '@/utils/logger'; -import { toCamelCase, toSnakeCase } from '@/utils/string'; -import { invokeTauri } from '@/utils/backend/tauri'; +import { CompletionParametersDefinition, LlmParameters, ImplProvider, ProviderType } from '@/types'; import { z } from 'zod'; const NAME = 'Opla'; @@ -34,7 +20,7 @@ const TYPE = ProviderType.opla; const DESCRIPTION = 'Opla Open source local LLM'; const DEFAULT_SYSTEM = 'You are an expert in retrieving information.\n'; -const DEFAULT_PROPERTIES: LlmParameters[] = [ +const DEFAULT_PARAMETERS: LlmParameters[] = [ { key: 'stop', value: "['Llama:', 'User:', 'Question:']" }, ]; @@ -210,53 +196,15 @@ export const CompletionParameters: CompletionParametersDefinition = { }, }; -const completion = async ( - model: Model | undefined, - provider: Provider | undefined, - messages: LlmMessage[], - system = DEFAULT_SYSTEM, - conversationId?: string, - parameters: LlmParameters[] = DEFAULT_PROPERTIES, -): Promise => { - if (!model) { - throw new Error('Model not found'); - } - - const systemMessage: LlmMessage = { - role: 'system', - content: system, - }; - - const options: LlmQueryCompletion = { - messages: [systemMessage, ...messages], - conversationId, - parameters, - }; - - const llmProvider = mapKeys(provider, toSnakeCase); - const response: LlmCompletionResponse = (await invokeTauri('llm_call_completion', { - model: model.name, - llmProvider, - query: { command: 'completion', options }, - })) as LlmCompletionResponse; - - const { content } = response; - if (content) { - logger.info(`${NAME} completion response`, response); - return mapKeys(response, toCamelCase); - } - throw new Error(`${NAME} completion error ${response}`); -}; - const OplaProvider: ImplProvider = { name: NAME, type: TYPE, description: DESCRIPTION, system: DEFAULT_SYSTEM, + defaultParameters: DEFAULT_PARAMETERS, template: {}, // TODO: add template completion: { parameters: CompletionParameters, - invoke: completion, }, }; From ddc669aeb541a92068a48ab235ff9b8346f786d7 Mon Sep 17 00:00:00 2001 From: mikbry Date: Fri, 8 Mar 2024 10:52:07 +0100 Subject: [PATCH 6/9] fix: change service providerType to providerIdOrName feat: updateMessagesAndConversation partial conversation fix: tempConversationName get right content fix: activeServiceFrom refactoring and cleaner --- webapp/components/views/Threads/Thread.tsx | 44 ++++++++---------- webapp/context/index.tsx | 6 +-- webapp/types/index.ts | 2 +- webapp/utils/conversations/index.ts | 5 ++- webapp/utils/data/conversations.ts | 12 ++--- webapp/utils/services/index.ts | 52 ++++++++++++---------- 6 files changed, 62 insertions(+), 59 deletions(-) diff --git a/webapp/components/views/Threads/Thread.tsx b/webapp/components/views/Threads/Thread.tsx index 08fdd106..1a57ae03 100644 --- a/webapp/components/views/Threads/Thread.tsx +++ b/webapp/components/views/Threads/Thread.tsx @@ -18,15 +18,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from 'react'; import { useRouter } from 'next/router'; import { useSearchParams } from 'next/navigation'; import { AppContext } from '@/context'; -import { - Asset, - Conversation, - AIService, - AIServiceType, - Message, - MessageStatus, - ProviderType, -} from '@/types'; +import { Asset, Conversation, AIService, AIServiceType, Message, MessageStatus } from 
'@/types'; import useTranslation from '@/hooks/useTranslation'; import logger from '@/utils/logger'; import { @@ -55,6 +47,7 @@ import { createMessage, changeMessageContent, getMessageRawContentAsString, + getMessageContentAsString, } from '@/utils/data/messages'; import { getCommandManager, preProcessingCommands } from '@/utils/commands'; import ContentView from '@/components/common/ContentView'; @@ -148,7 +141,9 @@ function Thread({ selectedConversation, ]); - const tempConversationName = messages?.[0]?.content as string; + const tempConversationName = messages?.[0] + ? getMessageContentAsString(messages?.[0]) + : 'Conversation'; const { modelItems, commandManager } = useMemo(() => { const selectedModelNameOrId = getConversationModelId(selectedConversation) || activeModel; @@ -191,18 +186,18 @@ function Thread({ parsePrompt({ text, caretStartIndex }, tokenValidator); const changeService = async ( - model?: string, - provider = ProviderType.opla, + model: string, + providerIdOrName: string, partial: Partial = {}, ) => { logger.info( - `ChangeService ${model} ${provider} activeModel=${typeof activeModel}`, + `ChangeService ${model} ${providerIdOrName} activeModel=${typeof activeModel}`, selectedConversation, ); const newService: AIService = { type: AIServiceType.Model, modelId: model as string, - providerType: provider, + providerIdOrName, }; if (model && selectedConversation) { const services = addService(selectedConversation.services, newService); @@ -225,15 +220,15 @@ function Thread({ conversation: Conversation, updatedConversations: Conversation[], prompt: ParsedPrompt, + modelName: string, ) => { const returnedMessage = { ...message }; const activeService = getActiveService( conversation, assistant, providers, - activeModel, backendContext, - message.author.name, + modelName, ); logger.info('sendMessage', activeService, conversation, presets); @@ -275,7 +270,7 @@ function Thread({ await updateMessagesAndConversation( [returnedMessage], conversationMessages, - conversation.name, + { name: conversation.name }, conversation.id, updatedConversations, ); @@ -324,7 +319,7 @@ function Thread({ clearPrompt(result.updatedConversation, result.updatedConversations); return; } - const { modelName } = result; + const { modelName = selectedModelNameOrId } = result; setErrorMessage({ ...errorMessage, [conversationId]: '' }); setIsProcessing({ ...isProcessing, [conversationId]: true }); @@ -334,10 +329,7 @@ function Thread({ currentPrompt.text, currentPrompt.raw, ); - let message = createMessage( - { role: 'assistant', name: modelName || selectedModelNameOrId }, - '...', - ); + let message = createMessage({ role: 'assistant', name: modelName }, '...'); message.status = MessageStatus.Pending; userMessage.sibling = message.id; message.sibling = userMessage.id; @@ -349,7 +341,7 @@ function Thread({ } = await updateMessagesAndConversation( [userMessage, message], getConversationMessages(conversationId), - tempConversationName, + { name: tempConversationName }, conversationId, ); let updatedConversations = uc; @@ -372,6 +364,7 @@ function Thread({ updatedConversation, updatedConversations, currentPrompt, + modelName, ); if (tempConversationId) { @@ -407,7 +400,7 @@ function Thread({ await updateMessagesAndConversation( [message], conversationMessages, - tempConversationName, + { name: tempConversationName }, conversationId, ); @@ -421,6 +414,7 @@ function Thread({ updatedConversation, updatedConversations, prompt, + selectedModelNameOrId, ); setIsProcessing({ ...isProcessing, [conversationId]: false }); 
@@ -491,7 +485,7 @@ function Thread({ const { updatedMessages } = await updateMessagesAndConversation( newMessages, conversationMessages, - tempConversationName, + { name: tempConversationName }, conversationId, conversations, ); diff --git a/webapp/context/index.tsx b/webapp/context/index.tsx index c00d1ab6..b5df7aa5 100644 --- a/webapp/context/index.tsx +++ b/webapp/context/index.tsx @@ -45,7 +45,7 @@ export type Context = { updateMessagesAndConversation: ( changedMessages: Message[], conversationMessages: Message[], - newConversationTitle: string, + partialConversation: Partial, selectedConversationId: string, selectedConversations?: Conversation[], ) => Promise<{ @@ -178,14 +178,14 @@ function AppContextProvider({ children }: { children: React.ReactNode }) { async ( changedMessages: Message[], conversationMessages: Message[], - newConversationTitle: string, + partialConversation: Partial, selectedConversationId: string, // = conversationId, selectedConversations = conversations, ) => { const updatedConversations = updateOrCreateConversation( selectedConversationId, selectedConversations, - newConversationTitle, // messages?.[0]?.content as string, + partialConversation, // messages?.[0]?.content as string, ); const updatedMessages = mergeMessages(conversationMessages, changedMessages); await updateConversations(updatedConversations); diff --git a/webapp/types/index.ts b/webapp/types/index.ts index b9723375..f86588bd 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -166,7 +166,7 @@ export type AIService = { | { type: AIServiceType.Model; modelId: string; - providerType?: ProviderType; + providerIdOrName?: string; } | { type: AIServiceType.Assistant; diff --git a/webapp/utils/conversations/index.ts b/webapp/utils/conversations/index.ts index 092cbe4c..eeaa291c 100644 --- a/webapp/utils/conversations/index.ts +++ b/webapp/utils/conversations/index.ts @@ -83,5 +83,8 @@ export const getConversationTitle = (conversation: Conversation) => { ? 
(conversation.currentPrompt as ParsedPrompt).text || '' : conversation.currentPrompt || ''; } - return conversation.name; + if (typeof conversation.name === 'string' && conversation.name.length > 0) { + return conversation.name; + } + return 'Conversation'; }; diff --git a/webapp/utils/data/conversations.ts b/webapp/utils/data/conversations.ts index 53da2e23..25792689 100644 --- a/webapp/utils/data/conversations.ts +++ b/webapp/utils/data/conversations.ts @@ -17,7 +17,6 @@ import { Conversation, AIService, AIServiceType, - ProviderType, Assistant, } from '@/types'; import { createBaseRecord, createBaseNamedRecord, updateRecord } from '.'; @@ -90,14 +89,15 @@ export const removeConversation = (conversationId: string, conversations: Conver export const updateOrCreateConversation = ( conversationId: string | undefined, conversations: Conversation[], - title = 'Conversation', + partial: Partial, ) => { let conversation = conversations.find((c) => c.id === conversationId); let updatedConversations; if (conversation) { - updatedConversations = updateConversation(conversation, conversations); + updatedConversations = updateConversation({ ...conversation, ...partial }, conversations); } else { - conversation = createConversation(title.trim().substring(0, 200)); + const name = partial.name || 'Conversation'; + conversation = createConversation(name.trim().substring(0, 200)); updatedConversations = [...conversations, conversation]; } return updatedConversations; @@ -130,7 +130,7 @@ export const getConversationService = ( service = { type: serviceType, modelId: conversation.model, - providerType: conversation.provider as ProviderType, + providerIdOrName: conversation.provider, }; } if (conversation.model && serviceType === AIServiceType.Assistant && assistantId) { @@ -180,7 +180,7 @@ export const getServiceModelId = (modelService: AIService | undefined) => { export const getServiceProvider = (modelService: AIService | undefined) => { if (modelService && modelService.type === AIServiceType.Model) { - return modelService.providerType; + return modelService.providerIdOrName; } return undefined; }; diff --git a/webapp/utils/services/index.ts b/webapp/utils/services/index.ts index 0b87b61d..95c83e58 100644 --- a/webapp/utils/services/index.ts +++ b/webapp/utils/services/index.ts @@ -26,6 +26,7 @@ import { import { getConversationService } from '../data/conversations'; import { findModel, findModelInAll } from '../data/models'; import { findProvider } from '../data/providers'; +import OplaProvider from '../providers/opla'; export const activeServiceFrom = (service: AIService): AIImplService => ({ ...service, @@ -37,48 +38,53 @@ export const getActiveService = ( conversation: Conversation, assistant: Assistant | undefined, providers: Provider[], - activeModel: string, backendContext: OplaContext, - _modelName: string | undefined, + modelName: string, ): AIImplService => { const type = assistant ? 
-  const activeService: AIService | undefined = getConversationService(
+  let activeService: AIService | undefined = getConversationService(
     conversation,
     type,
     assistant?.id,
   );
   let model: Model | undefined;
-  let providerName: string | undefined = model?.provider;
   let provider: Provider | undefined;
-
-  if (activeService && activeService.type === AIServiceType.Model) {
-    provider = findProvider(activeService.providerType, providers);
-    model = findModel(activeService.modelId, provider?.models || []);
-    if (provider) {
-      providerName = provider.name;
+  if (!activeService) {
+    if (assistant) {
+      activeService = activeServiceFrom({
+        type: AIServiceType.Assistant,
+        assistantId: assistant.id,
+        targetId: assistant.targets?.[0].id,
+      });
+    } else {
+      activeService = activeServiceFrom({
+        type: AIServiceType.Model,
+        modelId: modelName,
+        providerIdOrName: conversation.provider,
+      });
+    }
+  }
+  if (activeService.type === AIServiceType.Model) {
+    let { providerIdOrName } = activeService;
+    provider = findProvider(providerIdOrName, providers);
+    model = findModel(modelName || activeService.modelId, provider?.models || []);
+    if (!model) {
+      model = findModelInAll(activeService.modelId, providers, backendContext);
     }
-  } else if (activeService && activeService.type === AIServiceType.Assistant) {
+    if (!provider) {
+      providerIdOrName = model?.provider || OplaProvider.name;
+      provider = findProvider(providerIdOrName, providers);
+    }
+  } else if (activeService.type === AIServiceType.Assistant) {
     const { assistantId, targetId } = activeService;
     if (assistantId && targetId) {
       const target = assistant?.targets?.find((t) => t.id === targetId);
       if (target?.models && target.models.length > 0) {
         model = findModelInAll(target.models[0], providers, backendContext);
         provider = findProvider(target.provider, providers);
-        providerName = provider?.name;
       }
     }
   }
-  const modelName = _modelName || model?.name || conversation.model || activeModel;
-  if (!assistant && !provider) {
-    if (!model || model.name !== modelName) {
-      model = findModelInAll(modelName, providers, backendContext);
-    }
-    const modelProviderName = model?.provider || model?.creator;
-    if (modelProviderName && modelProviderName !== providerName) {
-      provider = findProvider(modelProviderName, providers);
-    }
-  }
   return { ...activeService, model, provider } as AIImplService;
 };
From aba8d0755d9ba5fc7610a46807da81b62f81e24d Mon Sep 17 00:00:00 2001
From: mikbry
Date: Fri, 8 Mar 2024 10:55:01 +0100
Subject: [PATCH 7/9] fix: add services at conversation creation

---
 webapp/components/views/Threads/Thread.tsx | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/webapp/components/views/Threads/Thread.tsx b/webapp/components/views/Threads/Thread.tsx
index 1a57ae03..c288701c 100644
--- a/webapp/components/views/Threads/Thread.tsx
+++ b/webapp/components/views/Threads/Thread.tsx
@@ -519,19 +519,20 @@ function Thread({
       updatedConversations = updateConversation(conversation, updatedConversations, true);
     } else {
       updatedConversations = conversations.filter((c) => !c.temp);
-      const newConversation = createConversation('Conversation');
-      updatedConversations.push(newConversation);
+      let newConversation = createConversation('Conversation');
+
       newConversation.temp = true;
       newConversation.name = conversationName;
       newConversation.currentPrompt = prompt;
       if (assistant) {
         const newService = getDefaultAssistantService(assistant);
-        addConversationService(newConversation, newService);
+        newConversation = 
addConversationService(newConversation, newService); setService(undefined); } else if (service) { - addConversationService(newConversation, service); + newConversation = addConversationService(newConversation, service); setService(undefined); } + updatedConversations.push(newConversation); setTempConversationId(newConversation.id); } updateConversations(updatedConversations); From feed4ca3bde260598e404385c8e43605053bd771 Mon Sep 17 00:00:00 2001 From: mikbry Date: Fri, 8 Mar 2024 11:00:29 +0100 Subject: [PATCH 8/9] feat: getDefaultConversationName() --- webapp/components/views/Threads/Thread.tsx | 9 ++++----- webapp/utils/conversations/index.ts | 3 ++- webapp/utils/data/conversations.ts | 4 +++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/webapp/components/views/Threads/Thread.tsx b/webapp/components/views/Threads/Thread.tsx index c288701c..90a172fb 100644 --- a/webapp/components/views/Threads/Thread.tsx +++ b/webapp/components/views/Threads/Thread.tsx @@ -30,6 +30,7 @@ import { getConversationModelId, addService, addConversationService, + getDefaultConversationName, } from '@/utils/data/conversations'; import useBackend from '@/hooks/useBackendContext'; import { completion } from '@/utils/providers'; @@ -143,7 +144,7 @@ function Thread({ const tempConversationName = messages?.[0] ? getMessageContentAsString(messages?.[0]) - : 'Conversation'; + : getDefaultConversationName(); const { modelItems, commandManager } = useMemo(() => { const selectedModelNameOrId = getConversationModelId(selectedConversation) || activeModel; @@ -500,7 +501,7 @@ function Thread({ }; const handleUpdatePrompt = useCallback( - (prompt: ParsedPrompt | undefined, conversationName = 'Conversation') => { + (prompt: ParsedPrompt | undefined, conversationName = getDefaultConversationName()) => { if (prompt?.raw === '' && tempConversationId) { setChangedPrompt(undefined); updateConversations(conversations.filter((c) => !c.temp)); @@ -519,10 +520,8 @@ function Thread({ updatedConversations = updateConversation(conversation, updatedConversations, true); } else { updatedConversations = conversations.filter((c) => !c.temp); - let newConversation = createConversation('Conversation'); - + let newConversation = createConversation(conversationName); newConversation.temp = true; - newConversation.name = conversationName; newConversation.currentPrompt = prompt; if (assistant) { const newService = getDefaultAssistantService(assistant); diff --git a/webapp/utils/conversations/index.ts b/webapp/utils/conversations/index.ts index eeaa291c..e34289a3 100644 --- a/webapp/utils/conversations/index.ts +++ b/webapp/utils/conversations/index.ts @@ -17,6 +17,7 @@ import { SafeParseReturnType, z } from 'zod'; import { Conversation, Metadata } from '@/types'; import { ParsedPrompt } from '../parsers'; +import { getDefaultConversationName } from '../data/conversations'; export const MetadataSchema: z.ZodSchema = z.lazy(() => z.record(z.union([z.string(), z.number(), z.boolean(), MetadataSchema])), @@ -86,5 +87,5 @@ export const getConversationTitle = (conversation: Conversation) => { if (typeof conversation.name === 'string' && conversation.name.length > 0) { return conversation.name; } - return 'Conversation'; + return getDefaultConversationName(); }; diff --git a/webapp/utils/data/conversations.ts b/webapp/utils/data/conversations.ts index 25792689..cfe0e710 100644 --- a/webapp/utils/data/conversations.ts +++ b/webapp/utils/data/conversations.ts @@ -21,6 +21,8 @@ import { } from '@/types'; import { createBaseRecord, 
createBaseNamedRecord, updateRecord } from '.';
 
+export const getDefaultConversationName = () => 'Conversation';
+
 export const getConversationAssets = (conversation: Conversation) =>
   !conversation.assets || Array.isArray(conversation.assets)
@@ -96,7 +98,7 @@ export const updateOrCreateConversation = (
   if (conversation) {
     updatedConversations = updateConversation({ ...conversation, ...partial }, conversations);
   } else {
-    const name = partial.name || 'Conversation';
+    const name = partial.name || getDefaultConversationName();
     conversation = createConversation(name.trim().substring(0, 200));
     updatedConversations = [...conversations, conversation];
   }
From 266c07e982c63e923f219b4c473b9902f785cd63 Mon Sep 17 00:00:00 2001
From: mikbry
Date: Fri, 8 Mar 2024 11:06:04 +0100
Subject: [PATCH 9/9] feat: getDefaultConversationName use translation

---
 webapp/components/views/Threads/Explorer/index.tsx | 4 ++--
 webapp/components/views/Threads/Thread.tsx | 8 ++++----
 webapp/utils/conversations/index.ts | 4 ++--
 webapp/utils/data/conversations.ts | 2 +-
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/webapp/components/views/Threads/Explorer/index.tsx b/webapp/components/views/Threads/Explorer/index.tsx
index 54f3ceb0..757dec59 100644
--- a/webapp/components/views/Threads/Explorer/index.tsx
+++ b/webapp/components/views/Threads/Explorer/index.tsx
@@ -314,11 +314,11 @@ export default function ThreadsExplorer({
             (c1, c2) => c2.updatedAt - c1.updatedAt || c2.createdAt - c1.createdAt,
           )}
           editable
-          getItemTitle={(c) => `${getConversationTitle(c)}${c.temp ? '...' : ''}`}
+          getItemTitle={(c) => `${getConversationTitle(c, t)}${c.temp ? '...' : ''}`}
           isEditable={(c) => !c.temp && c.id === selectedThreadId}
           renderItem={(c) => (
             <>
               {getConversationTitle(c, t).replaceAll(' ', '\u00a0')}
               {c.temp ? ... : ''}
             
           )}
diff --git a/webapp/components/views/Threads/Thread.tsx b/webapp/components/views/Threads/Thread.tsx
index 90a172fb..3f4b82f4 100644
--- a/webapp/components/views/Threads/Thread.tsx
+++ b/webapp/components/views/Threads/Thread.tsx
@@ -144,7 +144,7 @@ function Thread({
 
   const tempConversationName = messages?.[0]
     ? 
getMessageContentAsString(messages?.[0]) - : getDefaultConversationName(); + : getDefaultConversationName(t); const { modelItems, commandManager } = useMemo(() => { const selectedModelNameOrId = getConversationModelId(selectedConversation) || activeModel; @@ -347,7 +347,7 @@ function Thread({ ); let updatedConversations = uc; if (updatedConversation.temp) { - updatedConversation.name = getConversationTitle(updatedConversation); + updatedConversation.name = getConversationTitle(updatedConversation, t); } updatedConversations = clearPrompt(updatedConversation, updatedConversations); @@ -501,7 +501,7 @@ function Thread({ }; const handleUpdatePrompt = useCallback( - (prompt: ParsedPrompt | undefined, conversationName = getDefaultConversationName()) => { + (prompt: ParsedPrompt | undefined, conversationName = getDefaultConversationName(t)) => { if (prompt?.raw === '' && tempConversationId) { setChangedPrompt(undefined); updateConversations(conversations.filter((c) => !c.temp)); @@ -537,7 +537,7 @@ function Thread({ updateConversations(updatedConversations); setChangedPrompt(undefined); }, - [tempConversationId, conversationId, conversations, updateConversations, assistant, service], + [t, tempConversationId, conversationId, conversations, updateConversations, assistant, service], ); useDebounceFunc(handleUpdatePrompt, changedPrompt, 500); diff --git a/webapp/utils/conversations/index.ts b/webapp/utils/conversations/index.ts index e34289a3..04638fc2 100644 --- a/webapp/utils/conversations/index.ts +++ b/webapp/utils/conversations/index.ts @@ -78,7 +78,7 @@ export type Conversations = z.infer; export const validateConversations = (data: unknown): SafeParseReturnType => ConversationsSchema.safeParse(data); -export const getConversationTitle = (conversation: Conversation) => { +export const getConversationTitle = (conversation: Conversation, t: (value: string) => string) => { if (conversation.temp) { return conversation.currentPrompt && typeof conversation.currentPrompt !== 'string' ? (conversation.currentPrompt as ParsedPrompt).text || '' @@ -87,5 +87,5 @@ export const getConversationTitle = (conversation: Conversation) => { if (typeof conversation.name === 'string' && conversation.name.length > 0) { return conversation.name; } - return getDefaultConversationName(); + return getDefaultConversationName(t); }; diff --git a/webapp/utils/data/conversations.ts b/webapp/utils/data/conversations.ts index cfe0e710..2ea4da47 100644 --- a/webapp/utils/data/conversations.ts +++ b/webapp/utils/data/conversations.ts @@ -21,7 +21,7 @@ import { } from '@/types'; import { createBaseRecord, createBaseNamedRecord, updateRecord } from '.'; -export const getDefaultConversationName = () => 'Conversation'; +export const getDefaultConversationName = (t = (value: string) => value) => t('Conversation'); export const getConversationAssets = (conversation: Conversation) => !conversation.assets || Array.isArray(conversation.assets)
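
The two helpers touched by the last two patches combine as follows. This is a minimal TypeScript sketch for orientation only, not part of the patch series; the ConversationLike shape, the dictionary/translate pair, and the console checks are hypothetical stand-ins for the app's real Conversation type and the useTranslation() t function.

// Illustrative sketch: translated default conversation names (hypothetical stand-ins, see note above).
type ConversationLike = { name?: string; temp?: boolean; currentPrompt?: string };

const getDefaultConversationName = (t = (value: string) => value) => t('Conversation');

const getConversationTitle = (conversation: ConversationLike, t: (value: string) => string) => {
  if (conversation.temp) {
    // Temporary conversations show the prompt text currently being typed.
    return conversation.currentPrompt || '';
  }
  if (typeof conversation.name === 'string' && conversation.name.length > 0) {
    return conversation.name;
  }
  // Unnamed conversations fall back to the translated default name.
  return getDefaultConversationName(t);
};

// Hypothetical dictionary standing in for the t function returned by useTranslation().
const dictionary: Record<string, string> = { Conversation: 'Discussion' };
const translate = (value: string) => dictionary[value] ?? value;

console.log(getConversationTitle({ temp: false }, translate)); // "Discussion"
console.log(getDefaultConversationName()); // "Conversation" (identity fallback when no t is given)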