From 8344b6ab2a52968bfa24bfb1047981331fd083ff Mon Sep 17 00:00:00 2001 From: Lifei Zhou Date: Mon, 7 Oct 2024 09:40:50 +1100 Subject: [PATCH] feat: saved api_key to keychain for user (#104) --- pyproject.toml | 1 + src/goose/cli/session.py | 54 ++-------- src/goose/cli/session_notifier.py | 24 +++++ src/goose/utils/_create_exchange.py | 52 ++++++++++ tests/cli/test_session.py | 35 +------ tests/utils/test_create_exchange.py | 151 ++++++++++++++++++++++++++++ 6 files changed, 238 insertions(+), 79 deletions(-) create mode 100644 src/goose/cli/session_notifier.py create mode 100644 src/goose/utils/_create_exchange.py create mode 100644 tests/utils/test_create_exchange.py diff --git a/pyproject.toml b/pyproject.toml index 4db2570f8..3d29cc59c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "ruamel-yaml>=0.18.6", "click>=8.1.7", "prompt-toolkit>=3.0.47", + "keyring>=25.4.1", ] author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] packages = [{ include = "goose", from = "src" }] diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index f9fe13b9a..c838f4178 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,26 +1,21 @@ -import sys import traceback from pathlib import Path from typing import Any, Dict, List, Optional -from exchange import Message, ToolResult, ToolUse, Text, Exchange -from exchange.providers.base import MissingProviderEnvVariableError -from exchange.invalid_choice_error import InvalidChoiceError +from exchange import Message, ToolResult, ToolUse, Text from rich import print -from rich.console import RenderableType -from rich.live import Live from rich.markdown import Markdown from rich.panel import Panel from rich.status import Status -from goose.build import build_exchange -from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH +from goose.cli.config import ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession -from goose.notifier import Notifier +from goose.cli.session_notifier import SessionNotifier from goose.profile import Profile from goose.utils import droid, load_plugins from goose.utils._cost_calculator import get_total_cost_message +from goose.utils._create_exchange import create_exchange from goose.utils.session_file import read_or_create_file, save_latest_session RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -52,24 +47,6 @@ def load_profile(name: Optional[str]) -> Profile: return profile -class SessionNotifier(Notifier): - def __init__(self, status_indicator: Status) -> None: - self.status_indicator = status_indicator - self.live = Live(self.status_indicator, refresh_per_second=8, transient=True) - - def log(self, content: RenderableType) -> None: - print(content) - - def status(self, status: str) -> None: - self.status_indicator.update(status) - - def start(self) -> None: - self.live.start() - - def stop(self) -> None: - self.live.stop() - - class Session: """A session handler for managing interactions between a user and the Goose exchange @@ -89,10 +66,12 @@ def __init__( self.name = droid() else: self.name = name - self.profile = profile + self.profile_name = profile + self.prompt_session = GoosePromptSession() self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) - self.exchange = self._create_exchange() + + self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier) setup_logging(log_file_directory=LOG_PATH, log_level=log_level) self.exchange.messages.extend(self._get_initial_messages()) @@ -102,21 +81,6 @@ def __init__( self.prompt_session = GoosePromptSession() - def _create_exchange(self) -> Exchange: - try: - return build_exchange(profile=load_profile(self.profile), notifier=self.notifier) - except MissingProviderEnvVariableError as e: - error_message = f"{e.message}. Please set the required environment variable to continue." - print(Panel(error_message, style="red")) - sys.exit(1) - except InvalidChoiceError as e: - error_message = ( - f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" - + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" - ) - print(error_message) - sys.exit(1) - def _get_initial_messages(self) -> List[Message]: messages = self.load_session() @@ -162,7 +126,7 @@ def run(self) -> None: Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. """ - print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile or 'default'}[/]") + print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]") print(f"[dim]saving to {self.session_file_path}") print() message = self.process_first_message() diff --git a/src/goose/cli/session_notifier.py b/src/goose/cli/session_notifier.py new file mode 100644 index 000000000..d29ce944a --- /dev/null +++ b/src/goose/cli/session_notifier.py @@ -0,0 +1,24 @@ +from rich.status import Status +from rich.live import Live +from rich.console import RenderableType +from rich import print + +from goose.notifier import Notifier + + +class SessionNotifier(Notifier): + def __init__(self, status_indicator: Status) -> None: + self.status_indicator = status_indicator + self.live = Live(self.status_indicator, refresh_per_second=8, transient=True) + + def log(self, content: RenderableType) -> None: + print(content) + + def status(self, status: str) -> None: + self.status_indicator.update(status) + + def start(self) -> None: + self.live.start() + + def stop(self) -> None: + self.live.stop() diff --git a/src/goose/utils/_create_exchange.py b/src/goose/utils/_create_exchange.py new file mode 100644 index 000000000..d1aa318f5 --- /dev/null +++ b/src/goose/utils/_create_exchange.py @@ -0,0 +1,52 @@ +import os +import sys +from typing import Optional +import keyring + +from prompt_toolkit import prompt +from prompt_toolkit.shortcuts import confirm +from rich import print +from rich.panel import Panel + +from goose.build import build_exchange +from goose.cli.config import PROFILES_CONFIG_PATH +from goose.cli.session_notifier import SessionNotifier +from goose.profile import Profile +from exchange import Exchange +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers.base import MissingProviderEnvVariableError + + +def create_exchange(profile: Profile, notifier: SessionNotifier) -> Exchange: + try: + return build_exchange(profile, notifier=notifier) + except InvalidChoiceError as e: + error_message = ( + f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" + + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" + ) + print(error_message) + sys.exit(1) + except MissingProviderEnvVariableError as e: + api_key = _get_api_key_from_keychain(e.env_variable, e.provider) + if api_key is None or api_key == "": + error_message = f"{e.message}. Please set the required environment variable to continue." + print(Panel(error_message, style="red")) + sys.exit(1) + else: + os.environ[e.env_variable] = api_key + return build_exchange(profile=profile, notifier=notifier) + + +def _get_api_key_from_keychain(env_variable: str, provider: str) -> Optional[str]: + api_key = keyring.get_password("goose", env_variable) + if api_key is not None: + print(f"Using {env_variable} value for {provider} from your keychain") + else: + api_key = prompt(f"Enter {env_variable} value for {provider}:".strip()) + if api_key is not None and len(api_key) > 0: + save_to_keyring = confirm(f"Would you like to save the {env_variable} value to your keychain?") + if save_to_keyring: + keyring.set_password("goose", env_variable, api_key) + print(f"Saved {env_variable} to your key_chain. service_name: goose, user_name: {env_variable}") + return api_key diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 81519512d..f2437462c 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -2,8 +2,6 @@ import pytest from exchange import Exchange, Message, ToolUse, ToolResult -from exchange.providers.base import MissingProviderEnvVariableError -from exchange.invalid_choice_error import InvalidChoiceError from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -22,7 +20,7 @@ def mock_specified_session_name(): @pytest.fixture def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory): with ( - patch("goose.cli.session.build_exchange") as mock_exchange, + patch("goose.cli.session.create_exchange") as mock_exchange, patch("goose.cli.session.load_profile", return_value=profile_factory()), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch("goose.cli.session.load_provider", return_value="provider"), @@ -158,34 +156,3 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi with patch("goose.cli.session.droid", return_value=generated_session_name): session = create_session_with_mock_configs({"name": None}) assert session.name == generated_session_name - - -def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path): - session = create_session_with_mock_configs() - expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai") - with ( - patch("goose.cli.session.build_exchange", side_effect=expected_error), - patch("goose.cli.session.print") as mock_print, - patch("sys.exit") as mock_exit, - ): - session._create_exchange() - mock_print.call_args_list[0][0][0].renderable == ( - "Missing environment variable OPENAI_API_KEY for provider openai. ", - "Please set the required environment variable to continue.", - ) - mock_exit.assert_called_once_with(1) - - -def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path): - session = create_session_with_mock_configs() - expected_error = InvalidChoiceError( - attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] - ) - with ( - patch("goose.cli.session.build_exchange", side_effect=expected_error), - patch("goose.cli.session.print") as mock_print, - patch("sys.exit") as mock_exit, - ): - session._create_exchange() - assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] - mock_exit.assert_called_once_with(1) diff --git a/tests/utils/test_create_exchange.py b/tests/utils/test_create_exchange.py new file mode 100644 index 000000000..62fdde5f2 --- /dev/null +++ b/tests/utils/test_create_exchange.py @@ -0,0 +1,151 @@ +import os +from unittest.mock import MagicMock, patch + +from exchange.exchange import Exchange +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers.base import MissingProviderEnvVariableError +import pytest + +from goose.notifier import Notifier +from goose.profile import Profile +from goose.utils._create_exchange import create_exchange + +TEST_PROFILE = MagicMock(spec=Profile) +TEST_EXCHANGE = MagicMock(spec=Exchange) +TEST_NOTIFIER = MagicMock(spec=Notifier) + + +@pytest.fixture +def mock_print(): + with patch("goose.utils._create_exchange.print") as mock_print: + yield mock_print + + +@pytest.fixture +def mock_prompt(): + with patch("goose.utils._create_exchange.prompt") as mock_prompt: + yield mock_prompt + + +@pytest.fixture +def mock_confirm(): + with patch("goose.utils._create_exchange.confirm") as mock_confirm: + yield mock_confirm + + +@pytest.fixture +def mock_sys_exit(): + with patch("sys.exit") as mock_exit: + yield mock_exit + + +@pytest.fixture +def mock_keyring_get_password(): + with patch("keyring.get_password") as mock_get_password: + yield mock_get_password + + +@pytest.fixture +def mock_keyring_set_password(): + with patch("keyring.set_password") as mock_set_password: + yield mock_set_password + + +def test_create_exchange_success(mock_print): + with patch("goose.utils._create_exchange.build_exchange", return_value=TEST_EXCHANGE): + assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + +def test_create_exchange_fail_with_invalid_choice_error(mock_print, mock_sys_exit): + expected_error = InvalidChoiceError( + attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] + ) + with patch("goose.utils._create_exchange.build_exchange", side_effect=expected_error): + create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) + + assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] + mock_sys_exit.assert_called_once_with(1) + + +class TestWhenProviderEnvVarNotFound: + API_KEY_ENV_VAR = "OPENAI_API_KEY" + API_KEY_ENV_VALUE = "api_key_value" + PROVIDER_NAME = "openai" + SERVICE_NAME = "goose" + EXPECTED_ERROR = MissingProviderEnvVariableError(env_variable=API_KEY_ENV_VAR, provider=PROVIDER_NAME) + + def test_create_exchange_get_api_key_from_keychain( + self, mock_print, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = self.API_KEY_ENV_VALUE + + assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_get_password.assert_called_once_with(self.SERVICE_NAME, self.API_KEY_ENV_VAR) + mock_print.assert_called_once_with( + f"Using {self.API_KEY_ENV_VAR} value for {self.PROVIDER_NAME} from your keychain" + ) + mock_sys_exit.assert_not_called() + mock_keyring_set_password.assert_not_called() + + def test_create_exchange_ask_api_key_and_user_set_in_keychain( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password, mock_print + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = self.API_KEY_ENV_VALUE + mock_confirm.return_value = True + + assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_set_password.assert_called_once_with( + self.SERVICE_NAME, self.API_KEY_ENV_VAR, self.API_KEY_ENV_VALUE + ) + mock_confirm.assert_called_once_with( + f"Would you like to save the {self.API_KEY_ENV_VAR} value to your keychain?" + ) + mock_print.assert_called_once_with( + f"Saved {self.API_KEY_ENV_VAR} to your key_chain. " + + f"service_name: goose, user_name: {self.API_KEY_ENV_VAR}" + ) + mock_sys_exit.assert_not_called() + + def test_create_exchange_ask_api_key_and_user_not_set_in_keychain( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = self.API_KEY_ENV_VALUE + mock_confirm.return_value = False + + assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_set_password.assert_not_called() + mock_sys_exit.assert_not_called() + + def test_create_exchange_fails_when_user_not_provide_api_key( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_print + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=self.EXPECTED_ERROR): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = None + mock_confirm.return_value = False + + create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) + + assert ( + "Please set the required environment variable to continue." + in mock_print.call_args_list[0][0][0].renderable + ) + mock_sys_exit.assert_called_once_with(1) + + def _clean_env(self): + os.environ.pop(self.API_KEY_ENV_VAR, None)