From b4ee269cf1b59e7d2c3797e1f3d7143f8554a686 Mon Sep 17 00:00:00 2001 From: Lam Chau Date: Fri, 27 Sep 2024 02:25:44 -0700 Subject: [PATCH] test: add tests to validate content --- tests/cli/test_session.py | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 89b76b1ec..2c5817310 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,3 +1,4 @@ +import json from pathlib import Path from unittest.mock import MagicMock, patch @@ -133,10 +134,19 @@ def custom_exchange_generate(self, *args, **kwargs): UserInput(action=PromptAction.EXIT), ] + def save_latest_session(file, messages): + file.write_text("\n".join(json.dumps(m.to_dict()) for m in messages)) + session = create_session_with_mock_configs({"name": SESSION_NAME}) - with patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), patch.object( - Exchange, "generate" - ) as mock_generate, patch("goose.cli.session.save_latest_session") as mock_save_latest_session: + + with ( + patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), + patch.object(Exchange, "generate") as mock_generate, + patch( + "goose.cli.session.save_latest_session", + side_effect=save_latest_session, + ) as mock_save_latest_session, + ): mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs) session.run() @@ -146,6 +156,20 @@ def custom_exchange_generate(self, *args, **kwargs): assert mock_save_latest_session.call_args_list[0][0][0] == session_file assert session_file.exists() + with open(session_file, "r") as f: + saved_messages = [json.loads(line) for line in f] + + expected_messages = [ + Message.user("Question1"), + Message.assistant("Response"), + Message.user("Question2"), + Message.assistant("Response"), + ] + + assert len(saved_messages) == len(expected_messages) + for saved, expected in zip(saved_messages, expected_messages): + assert saved["role"] == expected.role + assert saved["content"][0]["text"] == expected.text def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path): generated_session_name = "generated_session_name"