Skip to content

Commit

Permalink
feat: auto save sessions before next user input (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
lifeizhou-ap authored Sep 26, 2024
1 parent d56c0d6 commit 6065125
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 134 deletions.
39 changes: 21 additions & 18 deletions src/goose/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,32 @@ def list_toolkits() -> None:
print(f" - [bold]{toolkit_name}[/bold]: {first_line_of_doc}")


def autocomplete_session_files(ctx: click.Context, args: str, incomplete: str) -> None:
return [
f"{session_name}"
for session_name in sorted(get_session_files().keys(), reverse=True, key=lambda x: x.lower())
if session_name.startswith(incomplete)
]


def get_session_files() -> dict[str, Path]:
return list_sorted_session_files(SESSIONS_PATH)


@session.command(name="start")
@click.argument("name", required=False, shell_complete=autocomplete_session_files)
@click.option("--profile")
@click.option("--plan", type=click.Path(exists=True))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
def session_start(profile: str, log_level: str, plan: Optional[str] = None) -> None:
def session_start(name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None) -> None:
"""Start a new goose session"""
if plan:
yaml = YAML()
with open(plan, "r") as f:
_plan = yaml.load(f)
else:
_plan = None
session = Session(profile=profile, plan=_plan, log_level=log_level)
session = Session(name=name, profile=profile, plan=_plan, log_level=log_level)
session.run()


Expand All @@ -126,30 +139,20 @@ def parse_args(ctx: click.Context, param: click.Parameter, value: str) -> dict[s

@session.command(name="planned")
@click.option("--plan", type=click.Path(exists=True))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
@click.option("-a", "--args", callback=parse_args, help="Args in the format arg1:value1,arg2:value2")
def session_planned(plan: str, args: Optional[dict[str, str]]) -> None:
def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -> None:
plan_templated = render_template(Path(plan), context=args)
_plan = parse_plan(plan_templated)
session = Session(plan=_plan)
session = Session(plan=_plan, log_level=log_level)
session.run()


def autocomplete_session_files(ctx: click.Context, args: str, incomplete: str) -> None:
return [
f"{session_name}"
for session_name in sorted(get_session_files().keys(), reverse=True, key=lambda x: x.lower())
if session_name.startswith(incomplete)
]


def get_session_files() -> dict[str, Path]:
return list_sorted_session_files(SESSIONS_PATH)


@session.command(name="resume")
@click.argument("name", required=False, shell_complete=autocomplete_session_files)
@click.option("--profile")
def session_resume(name: Optional[str], profile: str) -> None:
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO")
def session_resume(name: Optional[str], profile: str, log_level: str) -> None:
"""Resume an existing goose session"""
session_files = get_session_files()
if name is None:
Expand All @@ -164,7 +167,7 @@ def session_resume(name: Optional[str], profile: str) -> None:
print(f"Resuming session: {name}")
else:
print(f"Creating new session: {name}")
session = Session(name=name, profile=profile)
session = Session(name=name, profile=profile, log_level=log_level)
session.run()


Expand Down
80 changes: 34 additions & 46 deletions src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any, Dict, List, Optional

from exchange import Message, ToolResult, ToolUse, Text
from prompt_toolkit.shortcuts import confirm
from rich import print
from rich.console import RenderableType
from rich.live import Live
Expand All @@ -19,7 +18,7 @@
from goose.profile import Profile
from goose.utils import droid, load_plugins
from goose.utils._cost_calculator import get_total_cost_message
from goose.utils.session_file import read_from_file, write_to_file
from goose.utils.session_file import read_or_create_file, save_latest_session

RESUME_MESSAGE = "I see we were interrupted. How can I help you?"

Expand Down Expand Up @@ -84,38 +83,44 @@ def __init__(
log_level: Optional[str] = "INFO",
**kwargs: Dict[str, Any],
) -> None:
self.name = name
if name is None:
self.name = droid()
print(Panel(f"Session name not provided, using generated name: {self.name}"))
else:
self.name = name
self.status_indicator = Status("", spinner="dots")
self.notifier = SessionNotifier(self.status_indicator)

self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier)
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)

if name is not None and self.session_file_path.exists():
messages = self.load_session()

if messages and messages[-1].role == "user":
if type(messages[-1].content[-1]) is Text:
# remove the last user message
messages.pop()
elif type(messages[-1].content[-1]) is ToolResult:
# if we remove this message, we would need to remove
# the previous assistant message as well. instead of doing
# that, we just add a new assistant message to prompt the user
messages.append(Message.assistant(RESUME_MESSAGE))
if messages and type(messages[-1].content[-1]) is ToolUse:
# remove the last request for a tool to be used
messages.pop()

# add a new assistant text message to prompt the user
messages.append(Message.assistant(RESUME_MESSAGE))
self.exchange.messages.extend(messages)
self.exchange.messages.extend(self._get_initial_messages())

if len(self.exchange.messages) == 0 and plan:
self.setup_plan(plan=plan)

self.prompt_session = GoosePromptSession()

def _get_initial_messages(self) -> List[Message]:
messages = self.load_session()

if messages and messages[-1].role == "user":
if type(messages[-1].content[-1]) is Text:
# remove the last user message
messages.pop()
elif type(messages[-1].content[-1]) is ToolResult:
# if we remove this message, we would need to remove
# the previous assistant message as well. instead of doing
# that, we just add a new assistant message to prompt the user
messages.append(Message.assistant(RESUME_MESSAGE))
if messages and type(messages[-1].content[-1]) is ToolUse:
# remove the last request for a tool to be used
messages.pop()

# add a new assistant text message to prompt the user
messages.append(Message.assistant(RESUME_MESSAGE))
return messages

def setup_plan(self, plan: dict) -> None:
if len(self.exchange.messages):
raise ValueError("The plan can only be set on an empty session.")
Expand Down Expand Up @@ -160,12 +165,11 @@ def run(self) -> None:
+ " - [yellow]depending on the error you may be able to continue[/]"
)
self.notifier.stop()

save_latest_session(self.session_file_path, self.exchange.messages)
print() # Print a newline for separation.
user_input = self.prompt_session.get_user_input()
message = Message.user(text=user_input.text) if user_input.to_continue() else None

self.save_session()
self._log_cost()

def reply(self) -> None:
Expand Down Expand Up @@ -226,29 +230,13 @@ def interrupt_reply(self) -> None:
def session_file_path(self) -> Path:
return session_path(self.name)

def save_session(self) -> None:
"""Save the current session to a file in JSON format."""
if self.name is None:
self.generate_session_name()

try:
if self.session_file_path.exists():
if not confirm(f"Session {self.name} exists in {self.session_file_path}, overwrite?"):
self.generate_session_name()
write_to_file(self.session_file_path, self.exchange.messages)
except PermissionError as e:
raise RuntimeError(f"Failed to save session due to permissions: {e}")
except (IOError, OSError) as e:
raise RuntimeError(f"Failed to save session due to I/O error: {e}")

def load_session(self) -> List[Message]:
"""Load a session from a JSON file."""
return read_from_file(self.session_file_path)

def generate_session_name(self) -> None:
user_entered_session_name = self.prompt_session.get_save_session_name()
self.name = user_entered_session_name if user_entered_session_name else droid()
print(f"Saving to [bold cyan]{self.session_file_path}[/bold cyan]")
message = (
f"session is going to be saved to [bold cyan]{self.session_file_path}[/bold cyan]."
+ " You can view it anytime."
)
print(Panel(message))
return read_or_create_file(self.session_file_path)

def _log_cost(self) -> None:
get_logger().info(get_total_cost_message(self.exchange.get_token_usage()))
Expand Down
28 changes: 25 additions & 3 deletions src/goose/utils/session_file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import os
from pathlib import Path
import tempfile
from typing import Dict, Iterator, List

from exchange import Message
Expand All @@ -9,9 +11,15 @@

def write_to_file(file_path: Path, messages: List[Message]) -> None:
with open(file_path, "w") as f:
for m in messages:
json.dump(m.to_dict(), f)
f.write("\n")
_write_messages_to_file(f, messages)


def read_or_create_file(file_path: Path) -> List[Message]:
if file_path.exists():
return read_from_file(file_path)
with open(file_path, "w"):
pass
return []


def read_from_file(file_path: Path) -> List[Message]:
Expand All @@ -37,3 +45,17 @@ def session_file_exists(session_files_directory: Path) -> bool:
if not session_files_directory.exists():
return False
return any(list_session_files(session_files_directory))


def save_latest_session(file_path: Path, messages: List[Message]) -> None:
with tempfile.NamedTemporaryFile("w", delete=False) as temp_file:
_write_messages_to_file(temp_file, messages)
temp_file_path = temp_file.name

os.replace(temp_file_path, file_path)


def _write_messages_to_file(file: any, messages: List[Message]) -> None:
for m in messages:
json.dump(m.to_dict(), file)
file.write("\n")
14 changes: 11 additions & 3 deletions tests/cli/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,19 @@ def mock_session():
yield mock_session_class, mock_session_instance


def test_session_start_command_with_session_name(mock_session):
mock_session_class, mock_session_instance = mock_session
runner = CliRunner()
runner.invoke(goose_cli, ["session", "start", "session1", "--profile", "default"])
mock_session_class.assert_called_once_with(name="session1", profile="default", plan=None, log_level="INFO")
mock_session_instance.run.assert_called_once()


def test_session_resume_command_with_session_name(mock_session):
mock_session_class, mock_session_instance = mock_session
runner = CliRunner()
runner.invoke(goose_cli, ["session", "resume", "session1", "--profile", "default"])
mock_session_class.assert_called_once_with(name="session1", profile="default")
mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO")
mock_session_instance.run.assert_called_once()


Expand All @@ -59,7 +67,7 @@ def test_session_resume_command_without_session_name_use_latest_session(

second_file_path = mock_session_files_path / "second.jsonl"
mock_print.assert_called_once_with(f"Resuming most recent session: second from {second_file_path}")
mock_session_class.assert_called_once_with(name="second", profile="default")
mock_session_class.assert_called_once_with(name="second", profile="default", log_level="INFO")
mock_session_instance.run.assert_called_once()


Expand Down Expand Up @@ -121,7 +129,7 @@ def test_combined_group_commands(mock_session):
mock_session_class, mock_session_instance = mock_session
runner = CliRunner()
runner.invoke(cli, ["session", "resume", "session1", "--profile", "default"])
mock_session_class.assert_called_once_with(name="session1", profile="default")
mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO")
mock_session_instance.run.assert_called_once()


Expand Down
Loading

0 comments on commit 6065125

Please sign in to comment.