Skip to content

Commit

Permalink
feat(thread): cancel using prompt spinning #1243 (#1251)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikbry authored Sep 12, 2024
1 parent a12ae9f commit 8a739f4
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 44 deletions.
82 changes: 54 additions & 28 deletions webapp/features/Threads/Conversation/ConversationContext.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type Context = {
submit: boolean,
) => Promise<void>;
handleStartMessageEdit: (messageId: string, index: number) => void;
handleCancelSending: (messageId: string) => void;
handleCancelSending: (conversationId: string, messageId: string | undefined) => void;
};

type ConversationProviderProps = {
Expand Down Expand Up @@ -110,6 +110,7 @@ function ConversationProvider({
const {
conversations,
messages: messagesCache,
updateConversationMessages,
updateMessagesAndConversation,
} = useThreadStore();
const { parseAndValidatePrompt, clearPrompt } = useContext(PromptContext) || {};
Expand Down Expand Up @@ -580,37 +581,61 @@ function ConversationProvider({
);

const handleCancelSending = useCallback(
async (messageId: string) => {
if (selectedConversation && selectedModelId) {
try {
await cancelSending(
messageId,
selectedConversation,
selectedModelId,
assistant,
modelStorage,
activeService,
);
} catch (e) {
logger.error(e);
const conversationMessages = messagesCache[selectedConversation.id];
const previousMessage = conversationMessages.find((m) => m.id === messageId);
if (!previousMessage) {
logger.error(
"Can't find previous message",
messageId,
async (cId: string, messageId: string | undefined) => {
const cancelMessageSending = async (mId: string) => {
if (selectedConversation && selectedModelId) {
try {
await cancelSending(
mId,
selectedConversation,
conversationMessages,
selectedModelId,
assistant,
modelStorage,
activeService,
);
} catch (e) {
logger.error(e);
const conversationMessages = messagesCache[selectedConversation.id];
const previousMessage = conversationMessages.find((m) => m.id === mId);
if (!previousMessage) {
logger.error(
"Can't find previous message",
mId,
selectedConversation,
conversationMessages,
);
return;
}
const updatedMessage = changeMessageContent(
previousMessage,
t('Cancelled'),
t('Cancelled'),
MessageStatus.Delivered,
);
await updateConversationMessages(
cId,
messagesCache[cId].map((m) => (m.id === mId ? updatedMessage : m)),
);
return;
}
changeMessageContent(
previousMessage,
t('Cancelled'),
t('Cancelled'),
MessageStatus.Delivered,
);
}
};

if (messageId) {
await cancelMessageSending(messageId);
}
if (messagesCache[cId]) {
const promises = messagesCache[cId]
.map((message) => {
if (
messageId !== message.id &&
(message.status === 'pending' || message.status === 'stream')
) {
return cancelMessageSending(message.id);
}
return undefined;
})
.filter((p) => !!p);
await Promise.all(promises);
}
},
[
Expand All @@ -621,6 +646,7 @@ function ConversationProvider({
selectedModelId,
t,
messagesCache,
updateConversationMessages,
],
);

Expand Down
4 changes: 2 additions & 2 deletions webapp/features/Threads/Conversation/ConversationList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type ConversationListProps = {
onChangeMessageContent: (m: Message, newContent: string, submit: boolean) => void;
onStartMessageEdit: (messageId: string, index: number) => void;
onCopyMessage: (messageId: string, state: boolean) => void;
onCancelSending: (messageId: string) => void;
onCancelSending: (conversationId: string, messageId: string) => void;
};

function ConversationList({
Expand Down Expand Up @@ -102,7 +102,7 @@ function ConversationList({
}}
onCopyMessage={onCopyMessage}
onCancelSending={() => {
onCancelSending(m.id);
onCancelSending(conversation.id, m.id);
}}
/>
))}
Expand Down
53 changes: 39 additions & 14 deletions webapp/features/Threads/Prompt/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,12 @@ export default function Prompt({
isModelLoading,
}: PromptProps) {
const { t } = useTranslation();
const { isProcessing, errorMessages, handleSendMessage: sendMessage } = useConversationContext();
const {
isProcessing,
errorMessages,
handleSendMessage: sendMessage,
handleCancelSending,
} = useConversationContext();
const textareaRef = useRef<HTMLTextAreaElement>(null);
const [needFocus, setNeedFocus] = useState(false);
const { usage, changedPrompt, conversationPrompt, setChangedPrompt, tokenValidator } =
Expand All @@ -67,6 +72,11 @@ export default function Prompt({
const { conversations, updateConversations, messages, updateConversationMessages } =
useThreadStore();

const handleCancel = (e: MouseEvent) => {
e.preventDefault();
handleCancelSending(conversationId, undefined);
};

const handleSendMessage = (e: MouseEvent) => {
e.preventDefault();
if (prompt) {
Expand Down Expand Up @@ -172,7 +182,8 @@ export default function Prompt({
return undefined;
}

let isLoading = conversationId ? isProcessing[conversationId] || false : false;
const isConnected = conversationId ? isProcessing[conversationId] || false : false;
let isLoading = false;
let placeholder;
if (isModelLoading || isModelLoading === undefined) {
isLoading = true;
Expand Down Expand Up @@ -228,36 +239,50 @@ export default function Prompt({
onValueChange={handleValueChange}
onFocus={handleFocus}
onKeyDown={handleKeypress}
disabled={isLoading || disabled}
disabled={isLoading || isConnected || disabled}
/>
</PromptCommands>
<Tooltip>
<TooltipTrigger asChild>
<Button
disabled={isLoading || prompt?.raw?.length === 0}
disabled={isLoading || (prompt?.raw?.length === 0 && !isConnected)}
type="button"
aria-label={t('Send')}
onClick={handleSendMessage}
onClick={isConnected ? handleCancel : handleSendMessage}
className="ml-2"
size="icon"
variant="outline"
>
{isLoading ? (
{isLoading || isConnected ? (
<Loader2 strokeWidth={1.5} className="loading-icon h-4 w-4 animate-spin" />
) : (
<SendHorizontal className="strokeWidth={1.5} h-4 w-4" />
)}
</Button>
</TooltipTrigger>
<TooltipContent side="right" sideOffset={12} className="mt-1">
<div className="flex w-full flex-row items-center justify-between gap-2 pb-2">
<p>{t(shortcutSend.description)}</p>
<ShortcutBadge command={shortcutSend.command} />
</div>
<div className="flex w-full flex-row items-center justify-between gap-2">
<p>{t(shortcutNewLine.description)}</p>
<ShortcutBadge command={shortcutNewLine.command} />
</div>
{!isLoading && !isConnected && (
<>
<div className="flex w-full flex-row items-center justify-between gap-2 pb-2">
<p>{t(shortcutSend.description)}</p>
<ShortcutBadge command={shortcutSend.command} />
</div>
<div className="flex w-full flex-row items-center justify-between gap-2">
<p>{t(shortcutNewLine.description)}</p>
<ShortcutBadge command={shortcutNewLine.command} />
</div>
</>
)}
{isLoading && (
<div className="flex w-full flex-row items-center justify-between gap-2 pb-2">
<p>{t('Please wait...')}</p>
</div>
)}
{isProcessing && (
<div className="flex w-full flex-row items-center justify-between gap-2 pb-2">
<p>{t('Cancel')}</p>
</div>
)}
</TooltipContent>
</Tooltip>
</div>
Expand Down

0 comments on commit 8a739f4

Please sign in to comment.