From a43e65ee9faad2df84632265c6dc22c89c4e9687 Mon Sep 17 00:00:00 2001 From: Lam Chau Date: Fri, 27 Sep 2024 03:02:19 -0700 Subject: [PATCH] refactor: move static functions to utils --- src/goose/cli/session.py | 16 ++++------------ src/goose/utils/session_file.py | 10 +++++++++- tests/cli/test_session.py | 23 ++--------------------- tests/utils/test_session_file.py | 21 +++++++++++++++++++++ 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 665f248fa..20934db90 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -18,7 +18,7 @@ from goose.profile import Profile from goose.utils import droid, load_plugins from goose.utils._cost_calculator import get_total_cost_message -from goose.utils.session_file import read_or_create_file, save_latest_session +from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, save_latest_session RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -157,7 +157,7 @@ def prompt_overwrite_session(self) -> None: case "n" | "no": new_session_name = input("enter a new session name: ") - while Session.is_existing_session(session_path(new_session_name)): + while is_existing_session(session_path(new_session_name)): print(f"[yellow]session '{new_session_name}' already exists[/]") new_session_name = input("enter a new session name: ") self.name = new_session_name @@ -172,7 +172,7 @@ def run(self) -> None: Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. """ - if Session.is_existing_session(self.session_file_path): + if is_existing_session(self.session_file_path): self.prompt_overwrite_session() print( @@ -206,7 +206,7 @@ def run(self) -> None: # prevents cluttering the `sessions` with empty files, which # can be confusing when resuming a session - if Session.is_empty_session(self.session_file_path): + if is_empty_session(self.session_file_path): try: self.session_file_path.unlink() except FileNotFoundError: @@ -273,14 +273,6 @@ def interrupt_reply(self) -> None: def session_file_path(self) -> Path: return session_path(self.name) - @staticmethod - def is_existing_session(path: Path) -> bool: - return path.is_file() and path.stat().st_size > 0 - - @staticmethod - def is_empty_session(path: Path) -> bool: - return path.is_file() and path.stat().st_size == 0 - def load_session(self) -> List[Message]: return read_or_create_file(self.session_file_path) diff --git a/src/goose/utils/session_file.py b/src/goose/utils/session_file.py index 435186ce5..e367dcf1f 100644 --- a/src/goose/utils/session_file.py +++ b/src/goose/utils/session_file.py @@ -1,7 +1,7 @@ import json import os -from pathlib import Path import tempfile +from pathlib import Path from typing import Dict, Iterator, List from exchange import Message @@ -9,6 +9,14 @@ from goose.cli.config import SESSION_FILE_SUFFIX +def is_existing_session(path: Path) -> bool: + return path.is_file() and path.stat().st_size > 0 + + +def is_empty_session(path: Path) -> bool: + return path.is_file() and path.stat().st_size == 0 + + def write_to_file(file_path: Path, messages: List[Message]) -> None: with open(file_path, "w") as f: _write_messages_to_file(f, messages) diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 2c5817310..9ea95db28 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,5 +1,4 @@ import json -from pathlib import Path from unittest.mock import MagicMock, patch import pytest @@ -171,6 +170,7 @@ def save_latest_session(file, messages): assert saved["role"] == expected.role assert saved["content"][0]["text"] == expected.text + def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path): generated_session_name = "generated_session_name" with patch("goose.cli.session.droid", return_value=generated_session_name): @@ -178,28 +178,9 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi assert session.name == generated_session_name -def test_is_empty_session(): - with patch("pathlib.Path.is_file", return_value=True): - with patch("pathlib.Path.stat") as mock_stat: - mock_stat.return_value.st_size = 0 - assert Session.is_empty_session(Path("empty_file.json")) - - -def test_is_not_empty_session(): - with patch("pathlib.Path.is_file", return_value=True): - with patch("pathlib.Path.stat") as mock_stat: - mock_stat.return_value.st_size = 100 - assert not Session.is_empty_session(Path("non_empty_file.json")) - - -def test_is_not_empty_session_file_not_found(): - with patch("pathlib.Path.is_file", return_value=False): - assert not Session.is_empty_session(Path("non_existent_file.json")) - - def test_existing_session_prompt(create_session_with_mock_configs): with ( - patch("goose.cli.session.Session.is_existing_session", return_value=True) as mock_is_existing, + patch("goose.cli.session.is_existing_session", return_value=True) as mock_is_existing, patch("goose.cli.session.Session.prompt_overwrite_session") as mock_prompt, ): session = create_session_with_mock_configs({"name": SESSION_NAME}) diff --git a/tests/utils/test_session_file.py b/tests/utils/test_session_file.py index 6a2a64981..65fa75471 100644 --- a/tests/utils/test_session_file.py +++ b/tests/utils/test_session_file.py @@ -1,9 +1,11 @@ import os from pathlib import Path +from unittest.mock import patch import pytest from exchange import Message from goose.utils.session_file import ( + is_empty_session, list_sorted_session_files, read_from_file, read_or_create_file, @@ -115,3 +117,22 @@ def create_session_file(file_path, file_name) -> Path: file = file_path / f"{file_name}.jsonl" file.touch() return file + + +def test_is_empty_session(): + with patch("pathlib.Path.is_file", return_value=True): + with patch("pathlib.Path.stat") as mock_stat: + mock_stat.return_value.st_size = 0 + assert is_empty_session(Path("empty_file.json")) + + +def test_is_not_empty_session(): + with patch("pathlib.Path.is_file", return_value=True): + with patch("pathlib.Path.stat") as mock_stat: + mock_stat.return_value.st_size = 100 + assert not is_empty_session(Path("non_empty_file.json")) + + +def test_is_not_empty_session_file_not_found(): + with patch("pathlib.Path.is_file", return_value=False): + assert not is_empty_session(Path("non_existent_file.json"))