Skip to content

Commit

Permalink
feat: handle command actions and send/save /system message
Browse files Browse the repository at this point in the history
feat: refactor prompt commands management
  • Loading branch information
mikbry authored Feb 29, 2024
2 parents 5d88d1a + 8017363 commit 92af476
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 77 deletions.
8 changes: 4 additions & 4 deletions webapp/components/views/Threads/Prompt.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import { KeyBinding, ShortcutIds, defaultShortcuts } from '@/hooks/useShortcuts'
import logger from '@/utils/logger';
import { ParsedPrompt, TokenValidator, parsePrompt } from '@/utils/parsers';
import { getCaretPosition } from '@/utils/caretposition';
import { Ui } from '@/types';
import { CommandManager } from '@/utils/commands/types';
import { Button } from '../../ui/button';
import { Tooltip, TooltipContent, TooltipTrigger } from '../../ui/tooltip';
import { ShortcutBadge } from '../../common/ShortCut';
Expand All @@ -30,7 +30,7 @@ import PromptCommandInput from './PromptCommandInput';
export type PromptProps = {
conversationId: string;
prompt: ParsedPrompt;
commands: Ui.MenuItem[];
commandManager: CommandManager;
isLoading: boolean;
errorMessage: string;
disabled: boolean;
Expand All @@ -43,7 +43,7 @@ export type PromptProps = {
export default function Prompt({
conversationId,
prompt,
commands,
commandManager,
errorMessage,
disabled,
onUpdatePrompt,
Expand Down Expand Up @@ -119,7 +119,7 @@ export default function Prompt({
</Button>
<PromptCommandInput
value={prompt}
commands={commands}
commandManager={commandManager}
placeholder={t('Send a message...')}
className="m-0 max-h-[240px] min-h-[36px] "
onChange={handleUpdateMessage}
Expand Down
38 changes: 19 additions & 19 deletions webapp/components/views/Threads/PromptCommandInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@

import { ChangeEvent, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { isCommand, ParsedPrompt, parsePrompt, TokenValidator } from '@/utils/parsers';
import { PromptCommand } from '@/utils/parsers/promptCommand';
import { CommandManager } from '@/utils/commands/types';
import { getCaretCoordinates, getCurrentWord } from '@/utils/caretposition';
import { cn } from '@/lib/utils';
import logger from '@/utils/logger';
import useTranslation from '@/hooks/useTranslation';
import { getTokenColor } from '@/utils/ui';
import { getCommandType } from '@/utils/commands';
import { Textarea } from '../../ui/textarea';
import { Button } from '../../ui/button';

type PromptCommandProps = {
value?: ParsedPrompt;
placeholder?: string;
notFound?: string;
commands: PromptCommand[];
commandManager: CommandManager;
onChange?: (parsedPrompt: ParsedPrompt) => void;
onKeyDown: (event: KeyboardEvent) => void;
className?: string;
Expand All @@ -42,8 +42,7 @@ type PromptCommandProps = {
function PromptCommandInput({
value,
placeholder,
notFound = 'No command found.',
commands,
commandManager,
className,
onChange,
onFocus,
Expand Down Expand Up @@ -114,10 +113,8 @@ function PromptCommandInput({
if (textarea && dropdown) {
const { currentWord, caretStartIndex } = getCurrentWord(textarea);
valueChange(text, caretStartIndex);
const start = value?.text.trim().length || 0;
logger.info('value length', start);
if (isCommand(currentWord, start)) {
logger.info('isCommand', currentWord, start, commandValue);
const start = text.trim().length - currentWord.length;
if (value && !value.locked && isCommand(currentWord, start)) {
setCommandValue(currentWord);
positionDropdown();
toggleDropdown();
Expand All @@ -126,7 +123,7 @@ function PromptCommandInput({
}
}
},
[commandValue, positionDropdown, value?.text, valueChange],
[commandValue, positionDropdown, value, valueChange],
);

const handleCommandSelect = useCallback(
Expand Down Expand Up @@ -171,13 +168,13 @@ function PromptCommandInput({
if (textarea && dropdown) {
const { currentWord } = getCurrentWord(textarea);

const start = value?.text.trim().length || 0;
logger.info('isCommand selection change', currentWord, commandValue, start);
const start = textarea.value.trim().length - currentWord.length;

if (!isCommand(currentWord, start) && commandValue !== '') {
toggleDropdown(false);
}
}
}, [commandValue, value?.text]);
}, [commandValue]);

useEffect(() => {
const textarea = textareaRef.current;
Expand All @@ -195,12 +192,11 @@ function PromptCommandInput({
}, [handleBlur, handleKeyDown, handleMouseDown, handleSelectionChange]);

const filteredCommands = useMemo(
() =>
commands.filter(
(c) => !(!c.value || c.value?.toLowerCase().indexOf(commandValue.toLowerCase()) === -1),
),
[commands, commandValue],
() => commandManager.filterCommands(commandValue),
[commandManager, commandValue],
);
const commandType = getCommandType(commandValue);
const notFound = commandType ? `No ${commandType}s found.` : '';
return (
<div className="h-full w-full overflow-visible">
<Textarea
Expand Down Expand Up @@ -235,7 +231,11 @@ function PromptCommandInput({
)}
>
<div className="gap-2">
{filteredCommands.length === 0 && <div>{t(notFound)}</div>}
{filteredCommands.length === 0 && (
<div className="rounded-sm px-2 py-1.5 text-left text-sm outline-none aria-selected:bg-accent aria-selected:text-accent-foreground data-[disabled]:pointer-events-none data-[disabled]:opacity-50">
{t(notFound)}
</div>
)}
{filteredCommands.length > 0 &&
filteredCommands.map((item) => (
<Button
Expand Down
88 changes: 59 additions & 29 deletions webapp/components/views/Threads/Thread.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import {
PromptTemplate,
Provider,
ProviderType,
Ui,
} from '@/types';
import useTranslation from '@/hooks/useTranslation';
import logger from '@/utils/logger';
Expand Down Expand Up @@ -59,13 +58,13 @@ import {
PromptTokenType,
compareMentions,
comparePrompts,
getMentionName,
parsePrompt,
toPrompt,
} from '@/utils/parsers';
import { getConversationTitle } from '@/utils/conversations';
import validator from '@/utils/parsers/validator';
import { createMessage, changeMessageContent, mergeMessages } from '@/utils/data/messages';
import { getCommandManager } from '@/utils/commands';
import PromptArea from './Prompt';
import PromptsGrid from './PromptsGrid';
import ThreadMenu from './ThreadMenu';
Expand Down Expand Up @@ -165,26 +164,10 @@ function Thread({
const showEmptyChat = !conversationId;
const selectedModel = selectedConversation?.model || activeModel;

const { modelItems, commands } = useMemo(() => {
const { modelItems, commandManager } = useMemo(() => {
const items = getModelsAsItems(providers, backendContext, selectedModel);
const parameterItems: Ui.MenuItem[] = [
{ value: '#provider_key', label: 'Provider key', group: 'parameters-string' },
{ value: '#stream', label: 'Stream', group: 'parameters-boolean' },
];
const actionsItems: Ui.MenuItem[] = [{ value: '/system', label: 'System', group: 'actions' }];
const cmds = [
...items
.filter((item) => !item.selected)
.map((item) => ({
...item,
value: getMentionName(item.value as string),
group: 'models',
})),
...parameterItems,
...actionsItems,
];

return { modelItems: items, commands: cmds };
const manager = getCommandManager(items);
return { modelItems: items, commandManager: manager };
}, [backendContext, providers, selectedModel]);

useEffect(() => {
Expand All @@ -208,8 +191,8 @@ function Thread({
parsedPrompt: ParsedPrompt,
_previousToken: PromptToken | undefined,
): [PromptToken, PromptToken | undefined] =>
validator(commands, token, parsedPrompt, _previousToken),
[commands],
validator(commandManager, token, parsedPrompt, _previousToken),
[commandManager],
);

const currentPrompt = useMemo(
Expand Down Expand Up @@ -357,10 +340,58 @@ function Thread({
return returnedMessage;
};

const clearPrompt = (
conversation: Conversation | undefined,
newConversations = conversations,
) => {
setChangedPrompt(undefined);

let updatedConversations = newConversations;
if (conversation) {
updatedConversations = updateConversation(
{ ...conversation, currentPrompt: undefined, temp: false },
newConversations,
);
updateConversations(updatedConversations);
}
return updatedConversations;
};

const handleSendMessage = async () => {
if (conversationId === undefined) {
return;
}
const action = currentPrompt.tokens.find((to) => to.type === PromptTokenType.Action);

if (action) {
let updatedConversation = selectedConversation;
let updatedConversations = conversations;
const command = commandManager.getCommand(action.value, action.type);
// logger.info('command action', command, action.value, action.type, currentPrompt.text);
if (command) {
command.execute?.(action.value);
if (command.label === 'System') {
const message = createMessage(
{ role: 'system', name: 'system' },
currentPrompt.text,
currentPrompt.raw,
);
let updatedConversationId: string | undefined;
({ updatedConversationId, updatedConversations } = await updateMessagesAndConversation(
[message],
getConversationMessages(conversationId),
));

updatedConversation = getConversation(
updatedConversationId,
updatedConversations,
) as Conversation;
}
}
clearPrompt(updatedConversation, updatedConversations);
return;
}

const mentions = currentPrompt.tokens.filter((to) => to.type === PromptTokenType.Mention);
const modelItem =
mentions.length === 1
Expand Down Expand Up @@ -424,12 +455,14 @@ function Thread({
conversation.name = getConversationTitle(conversation);
}

conversation.currentPrompt = undefined;
/* conversation.currentPrompt = undefined;
setChangedPrompt(undefined);
conversation.temp = false;
updatedConversations = updateConversation(conversation, updatedConversations);
updateConversations(updatedConversations);
updateConversations(updatedConversations); */
updatedConversations = clearPrompt(conversation, updatedConversations);

logger.info('onSendMessage', updatedMessages, conversation);
message = await sendMessage(message, updatedMessages, conversation, updatedConversations);

Expand Down Expand Up @@ -523,9 +556,6 @@ function Thread({

const conversation = getConversation(conversationId, conversations);
if (conversation && message.content) {
/* const { contentHistory = [] } = message;
contentHistory.push(message.content);
const newMessage = { ...message, content: newContent, contentHistory }; */
const parsedContent = parsePrompt({ text: newContent }, tokenValidator);
const newMessage = changeMessageContent(message, parsedContent.text, parsedContent.raw);
const conversationMessages = getConversationMessages(conversationId);
Expand Down Expand Up @@ -717,7 +747,7 @@ function Thread({
<PromptArea
conversationId={conversationId as string}
disabled={disabled}
commands={commands}
commandManager={commandManager}
prompt={prompt}
isLoading={conversationId ? isProcessing[conversationId] : false}
errorMessage={conversationId ? errorMessage[conversationId] : ''}
Expand Down
97 changes: 97 additions & 0 deletions webapp/utils/commands/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright 2024 mik
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import { getMentionName } from '../parsers';
import { Command, CommandManager, CommandType } from './types';

const actionsItems: Command[] = [
{
value: '/system',
label: 'System',
group: 'actions',
type: CommandType.Action,
validate: () => true,
},
];
const parameterItems: Command[] = [
{
value: '#provider_key',
label: 'Provider key',
group: 'parameters-string',
type: CommandType.Parameter,
},
{ value: '#stream', label: 'Stream', group: 'parameters-boolean', type: CommandType.Parameter },
];

export const getActionCommands = (): Command[] => {
const actions = actionsItems.map((item) => item);
return actions;
};

export const getHashtagCommands = (): Command[] => {
const parameters = parameterItems.map((item) => item);
return parameters;
};

export const getCommandType = (value: string | undefined): CommandType | undefined => {
if (value?.startsWith('/')) {
return CommandType.Action;
}
if (value?.startsWith('#')) {
return CommandType.Parameter;
}
if (value?.startsWith('@')) {
return CommandType.Mention;
}
return undefined;
};

export const compareCommands = (
command1: string | undefined,
command2: string | string,
type?: CommandType,
): boolean => {
const type1 = getCommandType(command1);
const type2 = getCommandType(command2);
return (!type || (type === type1 && type === type2)) && command1 === command2;
};

export const getCommandManager = (mentionItems: Partial<Command>[]): CommandManager => {
const commands = [
...mentionItems
.filter((item) => !item.selected)
.map(
(item) =>
({
...item,
value: getMentionName(item.value as string),
group: 'models',
type: CommandType.Mention,
}) as Command,
),
...getHashtagCommands(),
...getActionCommands(),
];
return {
commands,
getCommand: (value: string, type: string) => {
const command = commands.find((m) => compareCommands(m.value, value, type as CommandType));
return command;
},
filterCommands: (commandValue: string): Command[] =>
commands.filter(
(c) => !(!c.value || c.value?.toLowerCase().indexOf(commandValue.toLowerCase()) === -1),
),
};
};
Loading

0 comments on commit 92af476

Please sign in to comment.