From fadeba9ef933bcb621ee6fa11953d086bf95b2ef Mon Sep 17 00:00:00 2001 From: Luke Alvoeiro Date: Mon, 2 Sep 2024 20:44:49 -0700 Subject: [PATCH] fix: resuming sessions (#35) --- pyproject.toml | 2 +- src/goose/cli/session.py | 18 +++++++++++++++++- tests/cli/test_session.py | 39 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df1202451..f731bc3f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "goose-ai" description = "a programming agent that runs on your machine" -version = "0.8.4" +version = "0.8.5" readme = "README.md" requires-python = ">=3.10" dependencies = [ diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 3da620d33..713bd1f8c 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from exchange import Message, ToolResult, ToolUse +from exchange import Message, ToolResult, ToolUse, Text from prompt_toolkit.shortcuts import confirm from rich import print from rich.console import RenderableType @@ -24,6 +24,8 @@ from goose.utils import droid, load_plugins from goose.utils.session_file import read_from_file, write_to_file +RESUME_MESSAGE = "I see we were interrupted. How can I help you?" + def load_provider() -> str: # We try to infer a provider, by going in order of what will auth @@ -91,8 +93,22 @@ def __init__( if name is not None and self.session_file_path.exists(): messages = self.load_session() + if messages and messages[-1].role == "user": + if type(messages[-1].content[-1]) is Text: + # remove the last user message + messages.pop() + elif type(messages[-1].content[-1]) is ToolResult: + # if we remove this message, we would need to remove + # the previous assistant message as well. instead of doing + # that, we just add a new assistant message to prompt the user + messages.append(Message.assistant(RESUME_MESSAGE)) + if messages and type(messages[-1].content[-1]) is ToolUse: + # remove the last request for a tool to be used messages.pop() + + # add a new assistant text message to prompt the user + messages.append(Message.assistant(RESUME_MESSAGE)) self.exchange.messages.extend(messages) if len(self.exchange.messages) == 0 and plan: diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 83dd6ebaf..79a7c4a2b 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from exchange import Message +from exchange import Message, ToolUse, ToolResult from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -32,7 +32,7 @@ def create_session(session_attributes: dict = {}): yield create_session -def test_session_does_not_extend_last_user_message_on_init( +def test_session_does_not_extend_last_user_text_message_on_init( create_session_with_mock_configs, mock_sessions_path, create_session_file ): messages = [Message.user("Hello"), Message.assistant("Hi"), Message.user("Last should be removed")] @@ -44,6 +44,41 @@ def test_session_does_not_extend_last_user_message_on_init( assert [message.text for message in session.exchange.messages] == ["Hello", "Hi"] +def test_session_adds_resume_message_if_last_message_is_tool_result( + create_session_with_mock_configs, mock_sessions_path, create_session_file +): + messages = [ + Message.user("Hello"), + Message(role="assistant", content=[ToolUse(id="1", name="first_tool", parameters={})]), + Message(role="user", content=[ToolResult(tool_use_id="1", output="output")]), + ] + create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl") + + session = create_session_with_mock_configs({"name": SESSION_NAME}) + print("Messages after session init:", session.exchange.messages) # Debugging line + assert len(session.exchange.messages) == 4 + assert session.exchange.messages[-1].role == "assistant" + assert session.exchange.messages[-1].text == "I see we were interrupted. How can I help you?" + + +def test_session_removes_tool_use_and_adds_resume_message_if_last_message_is_tool_use( + create_session_with_mock_configs, mock_sessions_path, create_session_file +): + messages = [ + Message.user("Hello"), + Message(role="assistant", content=[ToolUse(id="1", name="first_tool", parameters={})]), + ] + create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl") + + session = create_session_with_mock_configs({"name": SESSION_NAME}) + print("Messages after session init:", session.exchange.messages) # Debugging line + assert len(session.exchange.messages) == 2 + assert [message.text for message in session.exchange.messages] == [ + "Hello", + "I see we were interrupted. How can I help you?", + ] + + def test_save_session_create_session(mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name): session = create_session_with_mock_configs() session.exchange.messages.append(Message.assistant("Hello"))