Skip to content

Commit

Permalink
feat: use assistant and selected target when sending message
Browse files Browse the repository at this point in the history
  • Loading branch information
mikbry authored Mar 6, 2024
2 parents 365bf0f + 0021e19 commit 2ef100b
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 32 deletions.
18 changes: 9 additions & 9 deletions webapp/components/views/Threads/Menu/AssistantMenu.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ import { cn } from '@/lib/utils';
import { useAssistantStore } from '@/stores';
import AssistantIcon from '@/components/common/AssistantIcon';
import { getAssistantTargetsAsItems, getDefaultAssistantService } from '@/utils/data/assistants';
import { updateConversation } from '@/utils/data/conversations';
import { addConversationService, updateConversation } from '@/utils/data/conversations';
import { Badge } from '../../../ui/badge';
import { ShortcutBadge } from '../../../common/ShortCut';
import Pastille from '../../../common/Pastille';
Expand All @@ -88,10 +88,10 @@ export default function AssistantMenu({
const assistant = getAssistant(selectedAssistantId) as Assistant;
const service = conversation?.services?.[0] || getDefaultAssistantService(assistant);
const selectedTargetId =
service.type === AIServiceType.Assistant ? service.targetId : _selectedTargetId;
const target = assistant?.targets?.find(
(t) => t.id === selectedTargetId || assistant?.targets?.[0].id,
);
service.type === AIServiceType.Assistant
? service.targetId
: _selectedTargetId || assistant?.targets?.[0].id;
const target = assistant?.targets?.find((t) => t.id === selectedTargetId);
const targetState = target && !target.disabled ? Ui.BasicState.active : Ui.BasicState.disabled;
const [open, setOpen] = useState(false);
const { t } = useTranslation();
Expand Down Expand Up @@ -130,10 +130,10 @@ export default function AssistantMenu({

const handleSelectAssistantTarget = async (item: Ui.MenuItem) => {
const targetId = item.value as string;
const newConversation: Conversation = {
...conversation,
services: [{ ...service, targetId } as AIService],
};
const newConversation: Conversation = addConversationService(conversation, {
...service,
targetId,
} as AIService);
const newConversations = updateConversation(newConversation, conversations);
updateConversations(newConversations);
};
Expand Down
40 changes: 29 additions & 11 deletions webapp/components/views/Threads/Thread.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import {
import useBackend from '@/hooks/useBackendContext';
import { completion } from '@/utils/providers';
import { getModelsAsItems } from '@/utils/data/models';
import { getActiveService } from '@/utils/services';
import { getActiveService, getAssistantId } from '@/utils/services';
import { toast } from '@/components/ui/Toast';
import useDebounceFunc from '@/hooks/useDebounceFunc';
import { ModalData, ModalsContext } from '@/context/modals';
Expand All @@ -59,6 +59,7 @@ import {
import { getCommandManager, preProcessingCommands } from '@/utils/commands';
import ContentView from '@/components/common/ContentView';
import { useAssistantStore } from '@/stores';
import { getDefaultAssistantService } from '@/utils/data/assistants';
import PromptArea from './Prompt';
import { ConversationPanel } from './Conversation';
import ThreadMenu from './Menu';
Expand Down Expand Up @@ -90,14 +91,14 @@ function Thread({
} = useContext(AppContext);
const { backendContext, setActiveModel } = useBackend();
const searchParams = useSearchParams();
const assistantId = searchParams?.get('assistant') || undefined;
const { getAssistant } = useAssistantStore();
const assistant = getAssistant(assistantId);
const [service, setService] = useState<AIService | undefined>(undefined);
const activeModel = getServiceModelId(service) || backendContext.config.models.activeModel;
const [tempConversationId, setTempConversationId] = useState<string | undefined>(undefined);
const conversationId = _conversationId || tempConversationId;
const selectedConversation = conversations.find((c) => c.id === conversationId);
const assistantId = searchParams?.get('assistant') || getAssistantId(selectedConversation);
const { getAssistant } = useAssistantStore();
const assistant = getAssistant(assistantId);
const [changedPrompt, setChangedPrompt] = useState<ParsedPrompt | undefined>(undefined);
const { showModal } = useContext(ModalsContext);
const [messages, setMessages] = useState<Message[] | undefined>(undefined);
Expand Down Expand Up @@ -149,13 +150,12 @@ function Thread({

const tempConversationName = messages?.[0]?.content as string;

const selectedModelNameOrId = getConversationModelId(selectedConversation) || activeModel;

const { modelItems, commandManager } = useMemo(() => {
const selectedModelNameOrId = getConversationModelId(selectedConversation) || activeModel;
const items = getModelsAsItems(providers, backendContext, selectedModelNameOrId);
const manager = getCommandManager(items);
return { modelItems: items, commandManager: manager };
}, [backendContext, providers, selectedModelNameOrId]);
}, [activeModel, backendContext, providers, selectedConversation]);

useEffect(() => {
if (_conversationId && tempConversationId) {
Expand Down Expand Up @@ -304,6 +304,8 @@ function Thread({
return;
}

const selectedModelNameOrId =
getConversationModelId(selectedConversation, assistant) || activeModel;
const result = await preProcessingCommands(
conversationId,
currentPrompt,
Expand Down Expand Up @@ -357,7 +359,13 @@ function Thread({

updatedConversations = clearPrompt(updatedConversation, updatedConversations);

logger.info('onSendMessage', updatedMessages, updatedConversation);
logger.info(
'onSendMessage',
modelName,
selectedModelNameOrId,
updatedMessages,
updatedConversation,
);
message = await sendMessage(
message,
updatedMessages,
Expand All @@ -383,13 +391,18 @@ function Thread({
setErrorMessage({ ...errorMessage, [conversationId]: '' });
setIsProcessing({ ...isProcessing, [conversationId]: true });

const selectedModelNameOrId =
getConversationModelId(selectedConversation, assistant) || activeModel;

let message: Message = changeMessageContent(
previousMessage,
'...',
'...',
MessageStatus.Pending,
);

if (selectedModelNameOrId && message.author.name !== selectedModelNameOrId) {
message.author.name = selectedModelNameOrId;
}
const { updatedConversation, updatedConversations, updatedMessages } =
await updateMessagesAndConversation(
[message],
Expand Down Expand Up @@ -517,7 +530,11 @@ function Thread({
newConversation.temp = true;
newConversation.name = conversationName;
newConversation.currentPrompt = prompt;
if (service) {
if (assistant) {
const newService = getDefaultAssistantService(assistant);
addConversationService(newConversation, newService);
setService(undefined);
} else if (service) {
addConversationService(newConversation, service);
setService(undefined);
}
Expand All @@ -526,7 +543,7 @@ function Thread({
updateConversations(updatedConversations);
setChangedPrompt(undefined);
},
[tempConversationId, conversationId, conversations, updateConversations, service],
[tempConversationId, conversationId, conversations, updateConversations, assistant, service],
);

useDebounceFunc<ParsedPrompt | undefined>(handleUpdatePrompt, changedPrompt, 500);
Expand All @@ -538,6 +555,7 @@ function Thread({
};

const prompt = changedPrompt === undefined ? currentPrompt : changedPrompt;
const selectedModelNameOrId = getConversationModelId(selectedConversation) || activeModel;
return (
<ContentView
header={
Expand Down
7 changes: 5 additions & 2 deletions webapp/components/views/Threads/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { ModalIds } from '@/modals';
import { ModalsContext } from '@/context/modals';
import { AppContext } from '@/context';
import { MenuAction, Page, ViewName } from '@/types/ui';
import { getAssistantId } from '@/utils/services';
import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from '../../ui/resizable';
import Explorer from './Explorer';
import Settings from './Settings';
Expand All @@ -45,8 +46,6 @@ type ThreadsProps = {
export default function Threads({ selectedThreadId, view = ViewName.Recent }: ThreadsProps) {
const router = useRouter();
const { id } = router.query;
const searchParams = useSearchParams();
const assistantId = searchParams?.get('assistant') || undefined;
const [errors, setError] = useState<string[]>([]);
const handleError = (error: string) => {
setError([...errors, error]);
Expand All @@ -64,6 +63,10 @@ export default function Threads({ selectedThreadId, view = ViewName.Recent }: Th
} = useContext(AppContext);
const { backendContext, setSettings } = useBackend();

const searchParams = useSearchParams();
const selectedConversation = conversations.find((c) => c.id === selectedThreadId);
const assistantId = searchParams?.get('assistant') || getAssistantId(selectedConversation);

useShortcuts(ShortcutIds.DELETE_MESSAGE, (event) => {
event.preventDefault();
logger.info('TODO delete Message');
Expand Down
19 changes: 17 additions & 2 deletions webapp/utils/data/conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
AIService,
AIServiceType,
ProviderType,
Assistant,
} from '@/types';
import { createBaseRecord, createBaseNamedRecord, updateRecord } from '.';

Expand Down Expand Up @@ -184,12 +185,26 @@ export const getServiceProvider = (modelService: AIService | undefined) => {
return undefined;
};

export const getConversationModelId = (conversation: Conversation | undefined) => {
export const getConversationModelId = (
conversation: Conversation | undefined,
assistant?: Assistant,
) => {
if (!conversation) {
return undefined;
}
const modelService = getConversationService(conversation, AIServiceType.Model);
return getServiceModelId(modelService);
let modelId = getServiceModelId(modelService);
if (!modelId) {
const assistantService = getConversationService(conversation, AIServiceType.Assistant);
if (assistantService?.type === AIServiceType.Assistant) {
const { assistantId, targetId } = assistantService;
if (assistantId && assistantId === assistant?.id) {
const target = assistant.targets?.find((t) => t.id === targetId);
modelId = target?.models?.[0];
}
}
}
return modelId;
};

export const getConversationProvider = (conversation: Conversation | undefined) => {
Expand Down
5 changes: 3 additions & 2 deletions webapp/utils/providers/opla/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {
} from '@/types';
import { mapKeys } from '@/utils/data';
import logger from '@/utils/logger';
import { toCamelCase } from '@/utils/string';
import { toCamelCase, toSnakeCase } from '@/utils/string';
import { invokeTauri } from '@/utils/backend/tauri';
import { z } from 'zod';

Expand Down Expand Up @@ -233,9 +233,10 @@ const completion = async (
parameters,
};

const llmProvider = mapKeys(provider, toSnakeCase);
const response: LlmResponse = (await invokeTauri('llm_call_completion', {
model: model.name,
llmProvider: provider,
llmProvider,
query: { command: 'completion', options },
})) as LlmResponse;

Expand Down
33 changes: 27 additions & 6 deletions webapp/utils/services/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,36 @@ export const getActiveService = (
if (provider) {
providerName = provider.name;
}
} else if (activeService && activeService.type === AIServiceType.Assistant) {
const { assistantId, targetId } = activeService;
if (assistantId && targetId) {
const target = assistant?.targets?.find((t) => t.id === targetId);
if (target?.models && target.models.length > 0) {
model = findModelInAll(target.models[0], providers, backendContext);
provider = findProvider(target.provider, providers);
providerName = provider?.name;
}
}
}
const modelName = _modelName || model?.name || conversation.model || activeModel;
if (!model || model.name !== modelName) {
model = findModelInAll(modelName, providers, backendContext);
}
const name = model?.provider || model?.creator;
if (name && name !== providerName) {
provider = findProvider(name, providers);
if (!assistant && !provider) {
if (!model || model.name !== modelName) {
model = findModelInAll(modelName, providers, backendContext);
}
const modelProviderName = model?.provider || model?.creator;
if (modelProviderName && modelProviderName !== providerName) {
provider = findProvider(modelProviderName, providers);
}
}

return { ...activeService, model, provider } as AIImplService;
};

export const getAssistantId = (conversation: Conversation | undefined): string | undefined => {
let assistantId: string | undefined;
if (conversation?.services) {
const service = conversation.services.find((c) => c.type === AIServiceType.Assistant);
if (service?.type === AIServiceType.Assistant) assistantId = service?.assistantId;
}
return assistantId;
};

0 comments on commit 2ef100b

Please sign in to comment.