diff --git a/webapp/components/views/Threads/Explorer/AssistantsList.tsx b/webapp/components/views/Threads/Explorer/AssistantsList.tsx index 991309a2..cb3c1042 100644 --- a/webapp/components/views/Threads/Explorer/AssistantsList.tsx +++ b/webapp/components/views/Threads/Explorer/AssistantsList.tsx @@ -34,7 +34,7 @@ export default function AssistantsList({ selectedId }: AssistantsListProps) { + @@ -43,6 +43,9 @@ export default function AssistantsList({ selectedId }: AssistantsListProps) { selectedId={selectedId || OplaAssistant.id} items={assistants} + getItemTitle={(assistant) => + assistant.id === OplaAssistant.id ? t('Use your local AI Models') : assistant.name + } renderLeftSide={(assistant) => assistant.id === OplaAssistant.id ? ( diff --git a/webapp/components/views/Threads/Thread.tsx b/webapp/components/views/Threads/Thread.tsx index 52f9e923..79619b81 100644 --- a/webapp/components/views/Threads/Thread.tsx +++ b/webapp/components/views/Threads/Thread.tsx @@ -21,6 +21,8 @@ import { AppContext } from '@/context'; import { Asset, Conversation, + ConversationConnector, + ConversationConnectorType, LlmParameters, Message, MessageStatus, @@ -38,6 +40,11 @@ import { updateOrCreateConversation, addAssetsToConversation, getConversationAssets, + getConnectorModelId, + getConversationModelId, + addConnector, + getConversationProvider, + addConversationConnector, } from '@/utils/data/conversations'; import useBackend from '@/hooks/useBackendContext'; import { buildContext, completion, getCompletionParametersDefinition } from '@/utils/providers'; @@ -105,10 +112,11 @@ function Thread({ } = useContext(AppContext); const { backendContext, setActiveModel } = useBackend(); const { activeModel: aModel } = backendContext.config.models; - const [tempModelProvider, setTempModelProvider] = useState<[string, ProviderType] | undefined>( + /* const [tempModelProvider, setTempModelProvider] = useState<[string, ProviderType] | undefined>( undefined, - ); - const activeModel = tempModelProvider?.[0] || aModel; + ); */ + const [connector, setConnector] = useState(undefined); + const activeModel = getConnectorModelId(connector) || aModel; const [tempConversationId, setTempConversationId] = useState(undefined); const conversationId = _conversationId || tempConversationId; const selectedConversation = conversations.find((c) => c.id === conversationId); @@ -162,7 +170,7 @@ function Thread({ ]); const showEmptyChat = !conversationId; - const selectedModel = selectedConversation?.model || activeModel; + const selectedModel = getConversationModelId(selectedConversation) || activeModel; const { modelItems, commandManager } = useMemo(() => { const items = getModelsAsItems(providers, backendContext, selectedModel); @@ -228,9 +236,15 @@ function Thread({ `handleSelectModel ${model} ${provider} activeModel=${typeof activeModel}`, selectedConversation, ); + const newConnector: ConversationConnector = { + type: ConversationConnectorType.Model, + modelId: model as string, + provider, + }; if (model && selectedConversation) { + const connectors = addConnector(selectedConversation.connectors, newConnector); const newConversations = updateConversation( - { ...selectedConversation, model, provider, parameters: {}, ...partial }, + { ...selectedConversation, connectors, parameters: {}, ...partial }, conversations, true, ); @@ -238,7 +252,8 @@ function Thread({ } else if (model && !activeModel) { await setActiveModel(model); } else if (model) { - setTempModelProvider([model, provider]); + // setTempModelProvider([model, provider]); + setConnector(newConnector); } }; @@ -252,14 +267,16 @@ function Thread({ let providerName: string | undefined = model?.provider; const returnedMessage = { ...message }; let provider: Provider | undefined; - if (conversation.provider && conversation.model) { - provider = findProvider(conversation.provider, providers); - model = findModel(conversation.model, provider?.models || []); + const conversationProvider = getConversationProvider(conversation); + const conversationModel = getConversationModelId(conversation); + if (conversationProvider && conversationModel) { + provider = findProvider(conversationProvider, providers); + model = findModel(conversationModel, provider?.models || []); if (provider) { providerName = provider.name; } } - const modelName = message.author.name || model?.name || conversation.model || activeModel; + const modelName = message.author.name || model?.name || conversationModel || activeModel; if (!model || model.name !== modelName) { model = findModelInAll(modelName, providers, backendContext); } @@ -628,16 +645,20 @@ function Thread({ newConversation.temp = true; newConversation.name = conversationName; newConversation.currentPrompt = prompt; - if (tempModelProvider) { + /* if (tempModelProvider) { [newConversation.model, newConversation.provider] = tempModelProvider; setTempModelProvider(undefined); + } */ + if (connector) { + addConversationConnector(newConversation, connector); + setConnector(undefined); } setTempConversationId(newConversation.id); } updateConversations(updatedConversations); setChangedPrompt(undefined); }, - [tempConversationId, conversationId, conversations, updateConversations, tempModelProvider], + [tempConversationId, conversationId, conversations, updateConversations, connector], ); useDebounceFunc(handleUpdatePrompt, changedPrompt, 500); diff --git a/webapp/components/views/Threads/index.tsx b/webapp/components/views/Threads/index.tsx index 7403ecec..c5033951 100644 --- a/webapp/components/views/Threads/index.tsx +++ b/webapp/components/views/Threads/index.tsx @@ -198,7 +198,8 @@ export default function Threads({ selectedThreadId, view = ViewName.Recent }: Th return ( diff --git a/webapp/types/index.ts b/webapp/types/index.ts index d33decb7..6e094c46 100644 --- a/webapp/types/index.ts +++ b/webapp/types/index.ts @@ -145,14 +145,39 @@ export type ConversationUsage = { totalPerSecond?: number; }; +export enum ConversationConnectorType { + Model = 'model', + Assistant = 'assistant', +} + +export type ConversationConnector = { + disabled?: boolean; +} & ( + | { + type: ConversationConnectorType.Model; + modelId: string; + provider?: ProviderType; + } + | { + type: ConversationConnectorType.Assistant; + assistantId: string; + targetId?: string; + } +); + export type Conversation = BaseNamedRecord & { messages: Message[] | undefined; pluginIds?: string[]; preset?: string; currentPrompt?: string | ParsedPrompt; note?: string; + + // Deprecated replaced by connectors model?: string; provider?: string; + + connectors?: ConversationConnector[]; + importedFrom?: string; temp?: boolean; diff --git a/webapp/utils/data/conversations.ts b/webapp/utils/data/conversations.ts index e9e0ceee..9953ca6a 100644 --- a/webapp/utils/data/conversations.ts +++ b/webapp/utils/data/conversations.ts @@ -11,7 +11,14 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -import { Asset, ContextWindowPolicy, Conversation } from '@/types'; +import { + Asset, + ContextWindowPolicy, + Conversation, + ConversationConnector, + ConversationConnectorType, + ProviderType, +} from '@/types'; import { createBaseRecord, createBaseNamedRecord, updateRecord } from '.'; export const getConversationAssets = (conversation: Conversation) => @@ -110,3 +117,81 @@ export const mergeConversations = ( }); return Array.from(conversationMap.values()); }; + +export const getConversationConnector = ( + conversation: Conversation, + connectorType: ConversationConnectorType, +) => { + let connector = conversation.connectors?.find((c) => c.type === connectorType); + if (!connector) { + if (conversation.model && connectorType === ConversationConnectorType.Model) { + connector = { + type: connectorType, + modelId: conversation.model, + provider: conversation.provider as ProviderType, + }; + } + } + return connector; +}; + +export const addConnector = ( + _connectors: ConversationConnector[] | undefined, + connector: ConversationConnector, +): ConversationConnector[] => { + const connectors = _connectors || []; + const index = connectors?.findIndex((c) => c.type === connector.type) ?? -1; + if (index !== -1) { + connectors[index] = connector; + } else { + connectors.push(connector); + } + return connectors; +}; + +export const addConversationConnector = ( + conversation: Conversation, + connector: ConversationConnector, +): Conversation => { + const index = conversation.connectors?.findIndex((c) => c.type === connector.type) ?? -1; + const connectors = conversation.connectors || []; + if (index !== -1) { + connectors[index] = connector; + } else { + connectors.push(connector); + } + return { + ...conversation, + connectors, + }; +}; + +export const getConnectorModelId = (modelConnector: ConversationConnector | undefined) => { + if (modelConnector && modelConnector.type === ConversationConnectorType.Model) { + return modelConnector.modelId; + } + return undefined; +}; + +export const getConnectorProvider = (modelConnector: ConversationConnector | undefined) => { + if (modelConnector && modelConnector.type === ConversationConnectorType.Model) { + return modelConnector.provider; + } + return undefined; +}; + +export const getConversationModelId = (conversation: Conversation | undefined) => { + if (!conversation) { + return undefined; + } + const modelConnector = getConversationConnector(conversation, ConversationConnectorType.Model); + return getConnectorModelId(modelConnector); +}; + +export const getConversationProvider = (conversation: Conversation | undefined) => { + if (!conversation) { + return undefined; + } + const modelConnector = getConversationConnector(conversation, ConversationConnectorType.Model); + return getConnectorProvider(modelConnector); +};