feat: refactor ImplProvider to handle context window policy and keep_system in Rust for Opla and OpenAI #410

Merged: 9 commits, Mar 8, 2024
4 changes: 2 additions & 2 deletions webapp/components/views/Threads/Explorer/index.tsx
@@ -314,11 +314,11 @@ export default function ThreadsExplorer({
(c1, c2) => c2.updatedAt - c1.updatedAt || c2.createdAt - c1.createdAt,
)}
editable
getItemTitle={(c) => `${getConversationTitle(c)}${c.temp ? '...' : ''}`}
getItemTitle={(c) => `${getConversationTitle(c, t)}${c.temp ? '...' : ''}`}
isEditable={(c) => !c.temp && c.id === selectedThreadId}
renderItem={(c) => (
<>
<span>{getConversationTitle(c).replaceAll(' ', '\u00a0')}</span>
<span>{getConversationTitle(c, t).replaceAll(' ', '\u00a0')}</span>
{c.temp ? <span className="ml-2 animate-pulse">...</span> : ''}
</>
)}
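Note on the change above: `getConversationTitle` now receives the translation function `t`, presumably so that an unnamed or temporary conversation falls back to a localized default title instead of a hard-coded string. A minimal sketch of what the updated helper might look like, assuming it lives in `@/utils/data/conversations` and reuses the `getDefaultConversationName` helper imported further down in this PR:

```ts
// Sketch only: the real helper is not shown in this diff, so the implementation
// below is an assumption for illustration.
import { Conversation } from '@/types';
import { getDefaultConversationName } from '@/utils/data/conversations';

const getConversationTitle = (conversation: Conversation, t: (key: string) => string): string =>
  // Use the stored name when present, otherwise a translated default.
  conversation.name?.trim() ? conversation.name : getDefaultConversationName(t);
```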
60 changes: 27 additions & 33 deletions webapp/components/views/Threads/Thread.tsx
@@ -18,15 +18,7 @@ import { useCallback, useContext, useEffect, useMemo, useState } from 'react';
import { useRouter } from 'next/router';
import { useSearchParams } from 'next/navigation';
import { AppContext } from '@/context';
import {
Asset,
Conversation,
AIService,
AIServiceType,
Message,
MessageStatus,
ProviderType,
} from '@/types';
import { Asset, Conversation, AIService, AIServiceType, Message, MessageStatus } from '@/types';
import useTranslation from '@/hooks/useTranslation';
import logger from '@/utils/logger';
import {
@@ -38,6 +30,7 @@ import {
getConversationModelId,
addService,
addConversationService,
getDefaultConversationName,
} from '@/utils/data/conversations';
import useBackend from '@/hooks/useBackendContext';
import { completion } from '@/utils/providers';
@@ -55,6 +48,7 @@ import {
createMessage,
changeMessageContent,
getMessageRawContentAsString,
getMessageContentAsString,
} from '@/utils/data/messages';
import { getCommandManager, preProcessingCommands } from '@/utils/commands';
import ContentView from '@/components/common/ContentView';
@@ -148,7 +142,9 @@ function Thread({
selectedConversation,
]);

const tempConversationName = messages?.[0]?.content as string;
const tempConversationName = messages?.[0]
? getMessageContentAsString(messages?.[0])
: getDefaultConversationName(t);

const { modelItems, commandManager } = useMemo(() => {
const selectedModelNameOrId = getConversationModelId(selectedConversation) || activeModel;
@@ -191,18 +187,18 @@
parsePrompt({ text, caretStartIndex }, tokenValidator);

const changeService = async (
model?: string,
provider = ProviderType.opla,
model: string,
providerIdOrName: string,
partial: Partial<Conversation> = {},
) => {
logger.info(
`ChangeService ${model} ${provider} activeModel=${typeof activeModel}`,
`ChangeService ${model} ${providerIdOrName} activeModel=${typeof activeModel}`,
selectedConversation,
);
const newService: AIService = {
type: AIServiceType.Model,
modelId: model as string,
providerType: provider,
providerIdOrName,
};
if (model && selectedConversation) {
const services = addService(selectedConversation.services, newService);
@@ -225,15 +221,15 @@
conversation: Conversation,
updatedConversations: Conversation[],
prompt: ParsedPrompt,
modelName: string,
) => {
const returnedMessage = { ...message };
const activeService = getActiveService(
conversation,
assistant,
providers,
activeModel,
backendContext,
message.author.name,
modelName,
);
logger.info('sendMessage', activeService, conversation, presets);

@@ -275,7 +271,7 @@
await updateMessagesAndConversation(
[returnedMessage],
conversationMessages,
conversation.name,
{ name: conversation.name },
conversation.id,
updatedConversations,
);
@@ -324,7 +320,7 @@
clearPrompt(result.updatedConversation, result.updatedConversations);
return;
}
const { modelName } = result;
const { modelName = selectedModelNameOrId } = result;

setErrorMessage({ ...errorMessage, [conversationId]: '' });
setIsProcessing({ ...isProcessing, [conversationId]: true });
@@ -334,10 +330,7 @@
currentPrompt.text,
currentPrompt.raw,
);
let message = createMessage(
{ role: 'assistant', name: modelName || selectedModelNameOrId },
'...',
);
let message = createMessage({ role: 'assistant', name: modelName }, '...');
message.status = MessageStatus.Pending;
userMessage.sibling = message.id;
message.sibling = userMessage.id;
@@ -349,12 +342,12 @@
} = await updateMessagesAndConversation(
[userMessage, message],
getConversationMessages(conversationId),
tempConversationName,
{ name: tempConversationName },
conversationId,
);
let updatedConversations = uc;
if (updatedConversation.temp) {
updatedConversation.name = getConversationTitle(updatedConversation);
updatedConversation.name = getConversationTitle(updatedConversation, t);
}

updatedConversations = clearPrompt(updatedConversation, updatedConversations);
@@ -372,6 +365,7 @@
updatedConversation,
updatedConversations,
currentPrompt,
modelName,
);

if (tempConversationId) {
@@ -407,7 +401,7 @@
await updateMessagesAndConversation(
[message],
conversationMessages,
tempConversationName,
{ name: tempConversationName },
conversationId,
);

@@ -421,6 +415,7 @@
updatedConversation,
updatedConversations,
prompt,
selectedModelNameOrId,
);

setIsProcessing({ ...isProcessing, [conversationId]: false });
@@ -491,7 +486,7 @@
const { updatedMessages } = await updateMessagesAndConversation(
newMessages,
conversationMessages,
tempConversationName,
{ name: tempConversationName },
conversationId,
conversations,
);
@@ -506,7 +501,7 @@
};

const handleUpdatePrompt = useCallback(
(prompt: ParsedPrompt | undefined, conversationName = 'Conversation') => {
(prompt: ParsedPrompt | undefined, conversationName = getDefaultConversationName(t)) => {
if (prompt?.raw === '' && tempConversationId) {
setChangedPrompt(undefined);
updateConversations(conversations.filter((c) => !c.temp));
@@ -525,25 +520,24 @@
updatedConversations = updateConversation(conversation, updatedConversations, true);
} else {
updatedConversations = conversations.filter((c) => !c.temp);
const newConversation = createConversation('Conversation');
updatedConversations.push(newConversation);
let newConversation = createConversation(conversationName);
newConversation.temp = true;
newConversation.name = conversationName;
newConversation.currentPrompt = prompt;
if (assistant) {
const newService = getDefaultAssistantService(assistant);
addConversationService(newConversation, newService);
newConversation = addConversationService(newConversation, newService);
setService(undefined);
} else if (service) {
addConversationService(newConversation, service);
newConversation = addConversationService(newConversation, service);
setService(undefined);
}
updatedConversations.push(newConversation);
setTempConversationId(newConversation.id);
}
updateConversations(updatedConversations);
setChangedPrompt(undefined);
},
[tempConversationId, conversationId, conversations, updateConversations, assistant, service],
[t, tempConversationId, conversationId, conversations, updateConversations, assistant, service],
);

useDebounceFunc<ParsedPrompt | undefined>(handleUpdatePrompt, changedPrompt, 500);
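Note on the `changeService` change above: the provider is now identified by a plain `providerIdOrName` string rather than a `ProviderType` enum value, which matches the PR's goal of routing completions through either Opla or OpenAI. A rough sketch of the service object a conversation now carries (field values are illustrative only):

```ts
import { AIService, AIServiceType } from '@/types';

// The service stored on a conversation now records the provider by id or name
// instead of a ProviderType enum value. The model id below is a placeholder.
const newService: AIService = {
  type: AIServiceType.Model,
  modelId: 'mistral-7b-instruct', // illustrative model id
  providerIdOrName: 'OpenAI',     // 'Opla' or 'OpenAI' per the PR title
};
```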
6 changes: 3 additions & 3 deletions webapp/context/index.tsx
@@ -45,7 +45,7 @@ export type Context = {
updateMessagesAndConversation: (
changedMessages: Message[],
conversationMessages: Message[],
newConversationTitle: string,
partialConversation: Partial<Conversation>,
selectedConversationId: string,
selectedConversations?: Conversation[],
) => Promise<{
@@ -178,14 +178,14 @@ function AppContextProvider({ children }: { children: React.ReactNode }) {
async (
changedMessages: Message[],
conversationMessages: Message[],
newConversationTitle: string,
partialConversation: Partial<Conversation>,
selectedConversationId: string, // = conversationId,
selectedConversations = conversations,
) => {
const updatedConversations = updateOrCreateConversation(
selectedConversationId,
selectedConversations,
newConversationTitle, // messages?.[0]?.content as string,
partialConversation, // messages?.[0]?.content as string,
);
const updatedMessages = mergeMessages(conversationMessages, changedMessages);
await updateConversations(updatedConversations);
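Note on the signature change above: `updateMessagesAndConversation` now accepts a `Partial<Conversation>` instead of a bare title string, so callers can update more than the conversation name in one pass. A hedged sketch of a call site inside the Thread view, reusing the variable names that appear in the diff above:

```ts
// Sketch of a call after the signature change; userMessage, assistantMessage,
// conversationId and tempConversationName are assumed to be in scope in the
// Thread component, as in the diff above.
const { updatedConversation, updatedConversations, updatedMessages } =
  await updateMessagesAndConversation(
    [userMessage, assistantMessage],          // messages that changed
    getConversationMessages(conversationId),  // current messages of the thread
    { name: tempConversationName },           // Partial<Conversation> instead of a bare title
    conversationId,
  );
```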
4 changes: 2 additions & 2 deletions webapp/hooks/useBackendContext.tsx
@@ -23,7 +23,7 @@ import {
OplaContext,
ServerStatus,
Settings,
LlmResponse,
LlmCompletionResponse,
LlmStreamResponse,
Download,
ServerParameters,
@@ -202,7 +202,7 @@ function BackendProvider({ children }: { children: React.ReactNode }) {
return;
}
logger.info('stream event', event, backendContext, context);
const response = (await mapKeys(event.payload, toCamelCase)) as LlmResponse;
const response = (await mapKeys(event.payload, toCamelCase)) as LlmCompletionResponse;
if (!response.conversationId) {
logger.error('stream event without conversationId', response);
return;
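Note on the rename above: the webapp type now matches the Rust `LlmCompletionResponse` struct introduced below. A rough, assumed shape of that type, inferred only from how it is used in this diff:

```ts
// Assumed shape, inferred from the stream handler above and the Rust struct
// below; not copied from the source, so treat every field as an assumption.
interface LlmCompletionResponse {
  content: string;
  conversationId?: string; // the stream handler above rejects events without it
  created?: number;
  status?: string;
  usage?: LlmUsage;
}

interface LlmUsage {
  // Token counts reported by the provider; exact field names are an assumption.
  promptTokens?: number;
  completionTokens?: number;
  totalTokens?: number;
}
```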
91 changes: 81 additions & 10 deletions webapp/native/src/llm/llama_cpp.rs
@@ -16,7 +16,9 @@ use crate::{ error::Error, llm::LlmQueryCompletion, store::ServerParameters };

use tauri::Runtime;
use serde::{ Deserialize, Serialize };
use crate::llm::{ LlmQuery, LlmResponse, LlmUsage };
use crate::llm::{ LlmQuery, LlmCompletionResponse, LlmUsage };

use super::{ LlmCompletionOptions, LlmTokenizeResponse };

#[serde_with::skip_serializing_none]
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -47,8 +49,23 @@ pub struct LlamaCppQueryCompletion {
}

impl LlmQueryCompletion {
fn to_llama_cpp_parameters(&self) -> LlamaCppQueryCompletion {
fn to_llama_cpp_parameters(
&self,
options: Option<LlmCompletionOptions>
) -> LlamaCppQueryCompletion {
let mut prompt = String::new();
match options {
Some(options) => {
match options.system {
Some(system) => {
prompt.push_str(&format!("{}\n", system));
}
None => {}
}
}
None => {}
}
// TODO: handle context_window_policy and keep_system
for message in &self.messages {
match message.role.as_str() {
"user" => {
@@ -127,8 +144,8 @@ pub struct LlamaCppChatCompletion {
}

impl LlamaCppChatCompletion {
pub fn to_llm_response(&self) -> LlmResponse {
LlmResponse {
pub fn to_llm_response(&self) -> LlmCompletionResponse {
LlmCompletionResponse {
created: None,
status: None,
content: self.content.clone(),
@@ -138,6 +155,22 @@ impl LlamaCppChatCompletion {
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LlamaCppQueryTokenize {
pub content: String,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LlamaCppTokenize {
pub tokens: Vec<u64>,
}
impl LlamaCppTokenize {
pub fn to_llm_response(&self) -> LlmTokenizeResponse {
LlmTokenizeResponse {
tokens: self.tokens.clone(),
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLamaCppServer {}

@@ -146,18 +179,23 @@ impl LLamaCppServer {
LLamaCppServer {}
}

fn get_api(&self, server_parameters: &ServerParameters, endpoint: String) -> String {
// TODO https support
format!("http://{:}:{:}/{}", server_parameters.host, server_parameters.port, endpoint)
}

pub async fn call_completion<R: Runtime>(
&mut self,
query: LlmQuery<LlmQueryCompletion>,
server_parameters: &ServerParameters
) -> Result<LlmResponse, Box<dyn std::error::Error>> {
let parameters = query.options.to_llama_cpp_parameters();
server_parameters: &ServerParameters,
completion_options: Option<LlmCompletionOptions>
) -> Result<LlmCompletionResponse, Box<dyn std::error::Error>> {
let parameters = query.options.to_llama_cpp_parameters(completion_options);

// TODO https support
let api = format!("http://{:}:{:}", server_parameters.host, server_parameters.port);
let api_url = self.get_api(server_parameters, query.command);
let client = reqwest::Client::new();
let res = client
.post(format!("{}/{}", api, query.command)) // TODO remove hardcoding
.post(api_url) // TODO remove hardcoding
.json(&parameters)
.send().await;
let response = match res {
@@ -178,4 +216,37 @@ impl LLamaCppServer {
};
Ok(response.to_llm_response())
}

pub async fn call_tokenize<R: Runtime>(
&mut self,
text: String,
server_parameters: &ServerParameters
) -> Result<LlmTokenizeResponse, Box<dyn std::error::Error>> {
let parameters = LlamaCppQueryTokenize {
content: text,
};
let api_url = self.get_api(server_parameters, "tokenize".to_owned());
let client = reqwest::Client::new();
let res = client
.post(api_url) // TODO remove hardcoding
.json(&parameters)
.send().await;
let response = match res {
Ok(res) => res,
Err(error) => {
println!("Failed to get Response: {}", error);
return Err(Box::new(Error::BadResponse));
}
};
let status = response.status();
println!("Response Status: {}", status);
let response = match response.json::<LlamaCppTokenize>().await {
Ok(r) => r,
Err(error) => {
println!("Failed to parse response: {}", error);
return Err(Box::new(Error::BadResponse));
}
};
Ok(response.to_llm_response())
}
}
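Note on the new `call_tokenize` above: it is a thin wrapper around the llama.cpp server's `/tokenize` endpoint, posting `{ content }` and reading back `{ tokens }`. The same round trip, sketched in TypeScript for webapp readers (host and port are illustrative; HTTPS support is still marked as a TODO in the Rust code):

```ts
// What call_tokenize does over HTTP, expressed as a standalone sketch.
// Host/port values are placeholders for the local llama.cpp server parameters.
async function tokenize(host: string, port: number, content: string): Promise<number[]> {
  const res = await fetch(`http://${host}:${port}/tokenize`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ content }),
  });
  if (!res.ok) {
    throw new Error(`tokenize failed with status ${res.status}`);
  }
  const { tokens } = (await res.json()) as { tokens: number[] };
  return tokens;
}
```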