Skip to content

Commit

Permalink
fix: resuming sessions (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke Alvoeiro authored Sep 3, 2024
1 parent 3c930e1 commit fadeba9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "goose-ai"
description = "a programming agent that runs on your machine"
version = "0.8.4"
version = "0.8.5"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
Expand Down
18 changes: 17 additions & 1 deletion src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from exchange import Message, ToolResult, ToolUse
from exchange import Message, ToolResult, ToolUse, Text
from prompt_toolkit.shortcuts import confirm
from rich import print
from rich.console import RenderableType
Expand All @@ -24,6 +24,8 @@
from goose.utils import droid, load_plugins
from goose.utils.session_file import read_from_file, write_to_file

RESUME_MESSAGE = "I see we were interrupted. How can I help you?"


def load_provider() -> str:
# We try to infer a provider, by going in order of what will auth
Expand Down Expand Up @@ -91,8 +93,22 @@ def __init__(

if name is not None and self.session_file_path.exists():
messages = self.load_session()

if messages and messages[-1].role == "user":
if type(messages[-1].content[-1]) is Text:
# remove the last user message
messages.pop()
elif type(messages[-1].content[-1]) is ToolResult:
# if we remove this message, we would need to remove
# the previous assistant message as well. instead of doing
# that, we just add a new assistant message to prompt the user
messages.append(Message.assistant(RESUME_MESSAGE))
if messages and type(messages[-1].content[-1]) is ToolUse:
# remove the last request for a tool to be used
messages.pop()

# add a new assistant text message to prompt the user
messages.append(Message.assistant(RESUME_MESSAGE))
self.exchange.messages.extend(messages)

if len(self.exchange.messages) == 0 and plan:
Expand Down
39 changes: 37 additions & 2 deletions tests/cli/test_session.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch

import pytest
from exchange import Message
from exchange import Message, ToolUse, ToolResult
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.cli.session import Session
Expand Down Expand Up @@ -32,7 +32,7 @@ def create_session(session_attributes: dict = {}):
yield create_session


def test_session_does_not_extend_last_user_message_on_init(
def test_session_does_not_extend_last_user_text_message_on_init(
create_session_with_mock_configs, mock_sessions_path, create_session_file
):
messages = [Message.user("Hello"), Message.assistant("Hi"), Message.user("Last should be removed")]
Expand All @@ -44,6 +44,41 @@ def test_session_does_not_extend_last_user_message_on_init(
assert [message.text for message in session.exchange.messages] == ["Hello", "Hi"]


def test_session_adds_resume_message_if_last_message_is_tool_result(
create_session_with_mock_configs, mock_sessions_path, create_session_file
):
messages = [
Message.user("Hello"),
Message(role="assistant", content=[ToolUse(id="1", name="first_tool", parameters={})]),
Message(role="user", content=[ToolResult(tool_use_id="1", output="output")]),
]
create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl")

session = create_session_with_mock_configs({"name": SESSION_NAME})
print("Messages after session init:", session.exchange.messages) # Debugging line
assert len(session.exchange.messages) == 4
assert session.exchange.messages[-1].role == "assistant"
assert session.exchange.messages[-1].text == "I see we were interrupted. How can I help you?"


def test_session_removes_tool_use_and_adds_resume_message_if_last_message_is_tool_use(
create_session_with_mock_configs, mock_sessions_path, create_session_file
):
messages = [
Message.user("Hello"),
Message(role="assistant", content=[ToolUse(id="1", name="first_tool", parameters={})]),
]
create_session_file(messages, mock_sessions_path / f"{SESSION_NAME}.jsonl")

session = create_session_with_mock_configs({"name": SESSION_NAME})
print("Messages after session init:", session.exchange.messages) # Debugging line
assert len(session.exchange.messages) == 2
assert [message.text for message in session.exchange.messages] == [
"Hello",
"I see we were interrupted. How can I help you?",
]


def test_save_session_create_session(mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name):
session = create_session_with_mock_configs()
session.exchange.messages.append(Message.assistant("Hello"))
Expand Down

0 comments on commit fadeba9

Please sign in to comment.