From c321c8cf8a8769b3a21305be3fca8c96141b0d2e Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 1 Oct 2024 13:19:38 +1000 Subject: [PATCH 1/6] exit the goose and show the error message when provider environment is not set --- src/goose/cli/config.py | 5 ----- src/goose/cli/session.py | 17 ++++++++++++++--- tests/cli/test_session.py | 15 +++++++++++++++ 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index 2005c4689..bef458a69 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -87,11 +87,6 @@ def default_model_configuration() -> Tuple[str, str, str]: break except Exception: pass - else: - raise ValueError( - "Could not detect an available provider," - + " make sure to plugin a provider or set an env var such as OPENAI_API_KEY" - ) recommended = { "ollama": (OLLAMA_MODEL, OLLAMA_MODEL), diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index bfa869a3c..e9ab7294f 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,8 +1,10 @@ +import sys import traceback from pathlib import Path from typing import Any, Dict, List, Optional -from exchange import Message, ToolResult, ToolUse, Text +from exchange import Message, ToolResult, ToolUse, Text, Exchange +from exchange.providers.base import MissingProviderEnvVariableError from rich import print from rich.console import RenderableType from rich.live import Live @@ -89,8 +91,7 @@ def __init__( self.profile = profile self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) - - self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier) + self.exchange = self._create_exchange() setup_logging(log_file_directory=LOG_PATH, log_level=log_level) self.exchange.messages.extend(self._get_initial_messages()) @@ -100,6 +101,16 @@ def __init__( self.prompt_session = GoosePromptSession() + def _create_exchange(self) -> Exchange: + try: + return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) + except MissingProviderEnvVariableError as e: + error_message = ( + f"Missing environment variable: {e.message}. Please set the required environment variable to continue." + ) + print(Panel(error_message, style="red")) + sys.exit(1) + def _get_initial_messages(self) -> List[Message]: messages = self.load_session() diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 6d9086bcd..ae71b62e4 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -2,6 +2,7 @@ import pytest from exchange import Exchange, Message, ToolUse, ToolResult +from exchange.providers.base import MissingProviderEnvVariableError from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -151,3 +152,17 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi with patch("goose.cli.session.droid", return_value=generated_session_name): session = create_session_with_mock_configs({"name": None}) assert session.name == generated_session_name + + +def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path): + session = create_session_with_mock_configs() + expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai") + with patch("goose.cli.session.build_exchange", side_effect=expected_error), patch( + "goose.cli.session.print" + ) as mock_print, patch("sys.exit") as mock_exit: + session._create_exchange() + mock_print.call_args_list[0][0][0].renderable == ( + "Missing environment variable OPENAI_API_KEY for provider openai. ", + "Please set the required environment variable to continue.", + ) + mock_exit.assert_called_once_with(1) From 6cb7cfcba7c98ec86b38d1064f0369f758281e0e Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 1 Oct 2024 13:42:02 +1000 Subject: [PATCH 2/6] fixed the message content --- src/goose/cli/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index e9ab7294f..c58aa5f8e 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -106,7 +106,7 @@ def _create_exchange(self) -> Exchange: return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) except MissingProviderEnvVariableError as e: error_message = ( - f"Missing environment variable: {e.message}. Please set the required environment variable to continue." + f"{e.message}. Please set the required environment variable to continue." ) print(Panel(error_message, style="red")) sys.exit(1) From 03d0eaca3afe2eabe79938ceb574711479ec6524 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Tue, 1 Oct 2024 16:15:35 +1000 Subject: [PATCH 3/6] save to api key to keychain --- src/goose/cli/prompt/goose_prompt_session.py | 3 ++ src/goose/cli/session.py | 48 +++----------------- src/goose/cli/session_notifier.py | 24 ++++++++++ src/goose/utils/_create_exchange.py | 44 ++++++++++++++++++ 4 files changed, 78 insertions(+), 41 deletions(-) create mode 100644 src/goose/cli/session_notifier.py create mode 100644 src/goose/utils/_create_exchange.py diff --git a/src/goose/cli/prompt/goose_prompt_session.py b/src/goose/cli/prompt/goose_prompt_session.py index 5ba54427f..e6406ac2d 100644 --- a/src/goose/cli/prompt/goose_prompt_session.py +++ b/src/goose/cli/prompt/goose_prompt_session.py @@ -86,3 +86,6 @@ def get_save_session_name(self) -> Optional[str]: "Enter a name to save this session under. A name will be generated for you if empty: ", validator=DummyValidator(), ).strip(" ") + + def get_text_prompt(self, prompt: str) -> str: + return self.text_prompt_session.prompt(prompt) diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index c58aa5f8e..fc39f0c17 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,25 +1,21 @@ -import sys import traceback from pathlib import Path from typing import Any, Dict, List, Optional -from exchange import Message, ToolResult, ToolUse, Text, Exchange -from exchange.providers.base import MissingProviderEnvVariableError +from exchange import Message, ToolResult, ToolUse, Text from rich import print -from rich.console import RenderableType -from rich.live import Live from rich.markdown import Markdown from rich.panel import Panel from rich.status import Status -from goose.build import build_exchange from goose.cli.config import ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession -from goose.notifier import Notifier +from goose.cli.session_notifier import SessionNotifier from goose.profile import Profile from goose.utils import droid, load_plugins from goose.utils._cost_calculator import get_total_cost_message +from goose.utils._create_exchange import create_exchange from goose.utils.session_file import read_or_create_file, save_latest_session RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -50,25 +46,6 @@ def load_profile(name: Optional[str]) -> Profile: _, profile = ensure_config(name) return profile - -class SessionNotifier(Notifier): - def __init__(self, status_indicator: Status) -> None: - self.status_indicator = status_indicator - self.live = Live(self.status_indicator, refresh_per_second=8, transient=True) - - def log(self, content: RenderableType) -> None: - print(content) - - def status(self, status: str) -> None: - self.status_indicator.update(status) - - def start(self) -> None: - self.live.start() - - def stop(self) -> None: - self.live.stop() - - class Session: """A session handler for managing interactions between a user and the Goose exchange @@ -88,10 +65,11 @@ def __init__( self.name = droid() else: self.name = name - self.profile = profile + self.profile_name = profile + self.prompt_session = GoosePromptSession() self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) - self.exchange = self._create_exchange() + self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier) setup_logging(log_file_directory=LOG_PATH, log_level=log_level) self.exchange.messages.extend(self._get_initial_messages()) @@ -99,18 +77,6 @@ def __init__( if len(self.exchange.messages) == 0 and plan: self.setup_plan(plan=plan) - self.prompt_session = GoosePromptSession() - - def _create_exchange(self) -> Exchange: - try: - return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) - except MissingProviderEnvVariableError as e: - error_message = ( - f"{e.message}. Please set the required environment variable to continue." - ) - print(Panel(error_message, style="red")) - sys.exit(1) - def _get_initial_messages(self) -> List[Message]: messages = self.load_session() @@ -156,7 +122,7 @@ def run(self) -> None: Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. """ - print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile or 'default'}[/]") + print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]") print(f"[dim]saving to {self.session_file_path}") print() message = self.process_first_message() diff --git a/src/goose/cli/session_notifier.py b/src/goose/cli/session_notifier.py new file mode 100644 index 000000000..41366d14d --- /dev/null +++ b/src/goose/cli/session_notifier.py @@ -0,0 +1,24 @@ +from rich.status import Status +from rich.live import Live +from rich.console import RenderableType +from rich import print + +from goose.notifier import Notifier + + +class SessionNotifier(Notifier): + def __init__(self, status_indicator: Status) -> None: + self.status_indicator = status_indicator + self.live = Live(self.status_indicator, refresh_per_second=8, transient=True) + + def log(self, content: RenderableType) -> None: + print(content) + + def status(self, status: str) -> None: + self.status_indicator.update(status) + + def start(self) -> None: + self.live.start() + + def stop(self) -> None: + self.live.stop() \ No newline at end of file diff --git a/src/goose/utils/_create_exchange.py b/src/goose/utils/_create_exchange.py new file mode 100644 index 000000000..94cb530ea --- /dev/null +++ b/src/goose/utils/_create_exchange.py @@ -0,0 +1,44 @@ +import os +import sys +from typing import Optional +import keyring + +from prompt_toolkit import prompt +from prompt_toolkit.shortcuts import confirm +from rich.panel import Panel +from rich import print + +from goose.build import build_exchange +from goose.cli.session_notifier import SessionNotifier +from goose.profile import Profile +from exchange import Exchange +from exchange.providers.base import MissingProviderEnvVariableError + + + +def create_exchange(profile: Profile, notifier: SessionNotifier) -> Exchange: + try: + return build_exchange(profile, notifier=notifier) + except MissingProviderEnvVariableError as e: + api_key = _get_api_key_from_keychain(e.env_variable, e.provider) + if api_key is None or api_key == "": + error_message = f"{e.message}. Please set the required environment variable to continue." + print(error_message) + print(Panel(error_message, style="red")) + sys.exit(1) + else: + os.environ[e.env_variable] = api_key + return build_exchange(profile=profile, notifier=notifier) + +def _get_api_key_from_keychain(env_variable: str, provider: str) -> Optional[str]: + api_key = keyring.get_password("goose", env_variable) + if api_key is not None: + print(f"Using {env_variable} value for {provider} from your keychain") + else: + api_key = prompt(f"Enter {env_variable} value for {provider}:".strip()) + if api_key is not None and len(api_key) > 0: + save_to_keyring = confirm(f"Would you like to save the {env_variable} value to your keychain?") + if save_to_keyring: + keyring.set_password("goose", env_variable, api_key) + print(f"Saved {env_variable} to your key_chain. service_name: goose, user_name: {env_variable}") + return api_key From b32b1ead0790457827668c2556c987121b0d9f03 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 14:59:07 +1000 Subject: [PATCH 4/6] fixed merge issue and added test cases --- pyproject.toml | 1 + src/goose/cli/prompt/goose_prompt_session.py | 2 +- src/goose/cli/session.py | 11 +- src/goose/cli/session_notifier.py | 2 +- src/goose/utils/_create_exchange.py | 21 ++- tests/cli/test_session.py | 35 +---- tests/utils/test_create_exchange.py | 150 +++++++++++++++++++ 7 files changed, 172 insertions(+), 50 deletions(-) create mode 100644 tests/utils/test_create_exchange.py diff --git a/pyproject.toml b/pyproject.toml index 4db2570f8..3d29cc59c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "ruamel-yaml>=0.18.6", "click>=8.1.7", "prompt-toolkit>=3.0.47", + "keyring>=25.4.1", ] author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] packages = [{ include = "goose", from = "src" }] diff --git a/src/goose/cli/prompt/goose_prompt_session.py b/src/goose/cli/prompt/goose_prompt_session.py index e6406ac2d..5f11beae3 100644 --- a/src/goose/cli/prompt/goose_prompt_session.py +++ b/src/goose/cli/prompt/goose_prompt_session.py @@ -86,6 +86,6 @@ def get_save_session_name(self) -> Optional[str]: "Enter a name to save this session under. A name will be generated for you if empty: ", validator=DummyValidator(), ).strip(" ") - + def get_text_prompt(self, prompt: str) -> str: return self.text_prompt_session.prompt(prompt) diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 15651c66f..c838f4178 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,18 +1,14 @@ -import sys import traceback from pathlib import Path from typing import Any, Dict, List, Optional -from exchange import Message, ToolResult, ToolUse, Text, Exchange -from exchange.providers.base import MissingProviderEnvVariableError -from exchange.invalid_choice_error import InvalidChoiceError +from exchange import Message, ToolResult, ToolUse, Text from rich import print from rich.markdown import Markdown from rich.panel import Panel from rich.status import Status -from goose.build import build_exchange -from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH +from goose.cli.config import ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.session_notifier import SessionNotifier @@ -50,6 +46,7 @@ def load_profile(name: Optional[str]) -> Profile: _, profile = ensure_config(name) return profile + class Session: """A session handler for managing interactions between a user and the Goose exchange @@ -74,7 +71,7 @@ def __init__( self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) - self.exchange = build_exchange(profile=load_profile(profile), notifier=self.notifier) + self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier) setup_logging(log_file_directory=LOG_PATH, log_level=log_level) self.exchange.messages.extend(self._get_initial_messages()) diff --git a/src/goose/cli/session_notifier.py b/src/goose/cli/session_notifier.py index 41366d14d..d29ce944a 100644 --- a/src/goose/cli/session_notifier.py +++ b/src/goose/cli/session_notifier.py @@ -21,4 +21,4 @@ def start(self) -> None: self.live.start() def stop(self) -> None: - self.live.stop() \ No newline at end of file + self.live.stop() diff --git a/src/goose/utils/_create_exchange.py b/src/goose/utils/_create_exchange.py index 967dea9bb..d1aa318f5 100644 --- a/src/goose/utils/_create_exchange.py +++ b/src/goose/utils/_create_exchange.py @@ -6,31 +6,38 @@ from prompt_toolkit import prompt from prompt_toolkit.shortcuts import confirm from rich import print +from rich.panel import Panel from goose.build import build_exchange +from goose.cli.config import PROFILES_CONFIG_PATH from goose.cli.session_notifier import SessionNotifier from goose.profile import Profile from exchange import Exchange -from exchange.providers.base import InvalidChoiceError - +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers.base import MissingProviderEnvVariableError def create_exchange(profile: Profile, notifier: SessionNotifier) -> Exchange: try: return build_exchange(profile, notifier=notifier) except InvalidChoiceError as e: + error_message = ( + f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" + + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" + ) + print(error_message) + sys.exit(1) + except MissingProviderEnvVariableError as e: api_key = _get_api_key_from_keychain(e.env_variable, e.provider) if api_key is None or api_key == "": - error_message = ( - f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" - + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" - ) - print(error_message) + error_message = f"{e.message}. Please set the required environment variable to continue." + print(Panel(error_message, style="red")) sys.exit(1) else: os.environ[e.env_variable] = api_key return build_exchange(profile=profile, notifier=notifier) + def _get_api_key_from_keychain(env_variable: str, provider: str) -> Optional[str]: api_key = keyring.get_password("goose", env_variable) if api_key is not None: diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 81519512d..f2437462c 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -2,8 +2,6 @@ import pytest from exchange import Exchange, Message, ToolUse, ToolResult -from exchange.providers.base import MissingProviderEnvVariableError -from exchange.invalid_choice_error import InvalidChoiceError from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -22,7 +20,7 @@ def mock_specified_session_name(): @pytest.fixture def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory): with ( - patch("goose.cli.session.build_exchange") as mock_exchange, + patch("goose.cli.session.create_exchange") as mock_exchange, patch("goose.cli.session.load_profile", return_value=profile_factory()), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch("goose.cli.session.load_provider", return_value="provider"), @@ -158,34 +156,3 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi with patch("goose.cli.session.droid", return_value=generated_session_name): session = create_session_with_mock_configs({"name": None}) assert session.name == generated_session_name - - -def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path): - session = create_session_with_mock_configs() - expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai") - with ( - patch("goose.cli.session.build_exchange", side_effect=expected_error), - patch("goose.cli.session.print") as mock_print, - patch("sys.exit") as mock_exit, - ): - session._create_exchange() - mock_print.call_args_list[0][0][0].renderable == ( - "Missing environment variable OPENAI_API_KEY for provider openai. ", - "Please set the required environment variable to continue.", - ) - mock_exit.assert_called_once_with(1) - - -def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path): - session = create_session_with_mock_configs() - expected_error = InvalidChoiceError( - attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] - ) - with ( - patch("goose.cli.session.build_exchange", side_effect=expected_error), - patch("goose.cli.session.print") as mock_print, - patch("sys.exit") as mock_exit, - ): - session._create_exchange() - assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] - mock_exit.assert_called_once_with(1) diff --git a/tests/utils/test_create_exchange.py b/tests/utils/test_create_exchange.py new file mode 100644 index 000000000..480a53099 --- /dev/null +++ b/tests/utils/test_create_exchange.py @@ -0,0 +1,150 @@ +import os +from unittest.mock import MagicMock, patch + +from exchange.exchange import Exchange +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers.base import MissingProviderEnvVariableError +import pytest + +from goose.profile import Profile +from goose.utils._create_exchange import create_exchange + +TEST_PROFILE = MagicMock(spec=Profile) +TEST_EXCHANGE = MagicMock(spec=Exchange) +TEST_NOTIFIER = MagicMock(spec=Exchange) + + +@pytest.fixture +def mock_print(): + with patch("goose.utils._create_exchange.print") as mock_print: + yield mock_print + + +@pytest.fixture +def mock_prompt(): + with patch("goose.utils._create_exchange.prompt") as mock_prompt: + yield mock_prompt + + +@pytest.fixture +def mock_confirm(): + with patch("goose.utils._create_exchange.confirm") as mock_confirm: + yield mock_confirm + + +@pytest.fixture +def mock_sys_exit(): + with patch("sys.exit") as mock_exit: + yield mock_exit + + +@pytest.fixture +def mock_keyring_get_password(): + with patch("keyring.get_password") as mock_get_password: + yield mock_get_password + + +@pytest.fixture +def mock_keyring_set_password(): + with patch("keyring.set_password") as mock_set_password: + yield mock_set_password + + +def test_create_exchange_success(mock_print): + with patch("goose.utils._create_exchange.build_exchange", return_value=TEST_EXCHANGE): + assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + +def test_create_exchange_fail_with_invalid_choice_error(mock_print, mock_sys_exit): + expected_error = InvalidChoiceError( + attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] + ) + with patch("goose.utils._create_exchange.build_exchange", side_effect=expected_error): + create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) + + assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] + mock_sys_exit.assert_called_once_with(1) + + +class TestWhenProviderEnvVarNotFound: + API_KEY_ENV_VAR = "OPENAI_API_KEY" + API_KEY_ENV_VALUE = "api_key_value" + PROVIDER_NAME = "openai" + SERVICE_NAME = "goose" + EXPECTED_ERROR = MissingProviderEnvVariableError(env_variable=API_KEY_ENV_VAR, provider=PROVIDER_NAME) + + def test_create_exchange_get_api_key_from_keychain( + self, mock_print, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = self.API_KEY_ENV_VALUE + + assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_get_password.assert_called_once_with(self.SERVICE_NAME, self.API_KEY_ENV_VAR) + mock_print.assert_called_once_with( + f"Using {self.API_KEY_ENV_VAR} value for {self.PROVIDER_NAME} from your keychain" + ) + mock_sys_exit.assert_not_called() + mock_keyring_set_password.assert_not_called() + + def test_create_exchange_ask_api_key_and_user_set_in_keychain( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password, mock_print + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = self.API_KEY_ENV_VALUE + mock_confirm.return_value = True + + assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_set_password.assert_called_once_with( + self.SERVICE_NAME, self.API_KEY_ENV_VAR, self.API_KEY_ENV_VALUE + ) + mock_confirm.assert_called_once_with( + f"Would you like to save the {self.API_KEY_ENV_VAR} value to your keychain?" + ) + mock_print.assert_called_once_with( + f"Saved {self.API_KEY_ENV_VAR} to your key_chain. " + + f"service_name: goose, user_name: {self.API_KEY_ENV_VAR}" + ) + mock_sys_exit.assert_not_called() + + def test_create_exchange_ask_api_key_and_user_not_set_in_keychain( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = self.API_KEY_ENV_VALUE + mock_confirm.return_value = False + + assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_set_password.assert_not_called() + mock_sys_exit.assert_not_called() + + def test_create_exchange_fails_when_user_not_provide_api_key( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_print + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = None + mock_confirm.return_value = False + + create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) + + assert ( + "Please set the required environment variable to continue." + in mock_print.call_args_list[0][0][0].renderable + ) + mock_sys_exit.assert_called_once_with(1) + + def _clean_env(self): + os.environ.pop(self.API_KEY_ENV_VAR, None) From f1b3692876d80883748ef010bebe43f257343d06 Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Fri, 4 Oct 2024 15:06:11 +1000 Subject: [PATCH 5/6] removed unused function --- src/goose/cli/prompt/goose_prompt_session.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/goose/cli/prompt/goose_prompt_session.py b/src/goose/cli/prompt/goose_prompt_session.py index 5f11beae3..5ba54427f 100644 --- a/src/goose/cli/prompt/goose_prompt_session.py +++ b/src/goose/cli/prompt/goose_prompt_session.py @@ -86,6 +86,3 @@ def get_save_session_name(self) -> Optional[str]: "Enter a name to save this session under. A name will be generated for you if empty: ", validator=DummyValidator(), ).strip(" ") - - def get_text_prompt(self, prompt: str) -> str: - return self.text_prompt_session.prompt(prompt) From 5226cb0fcd683ba5d18f0ecacc58d2289ea1fffb Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Sat, 5 Oct 2024 08:47:41 +1000 Subject: [PATCH 6/6] minor fix in test --- tests/utils/test_create_exchange.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_create_exchange.py b/tests/utils/test_create_exchange.py index 480a53099..62fdde5f2 100644 --- a/tests/utils/test_create_exchange.py +++ b/tests/utils/test_create_exchange.py @@ -6,12 +6,13 @@ from exchange.providers.base import MissingProviderEnvVariableError import pytest +from goose.notifier import Notifier from goose.profile import Profile from goose.utils._create_exchange import create_exchange TEST_PROFILE = MagicMock(spec=Profile) TEST_EXCHANGE = MagicMock(spec=Exchange) -TEST_NOTIFIER = MagicMock(spec=Exchange) +TEST_NOTIFIER = MagicMock(spec=Notifier) @pytest.fixture @@ -133,7 +134,7 @@ def test_create_exchange_fails_when_user_not_provide_api_key( self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_print ): self._clean_env() - with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + with patch("goose.utils._create_exchange.build_exchange", side_effect=self.EXPECTED_ERROR): mock_keyring_get_password.return_value = None mock_prompt.return_value = None mock_confirm.return_value = False