Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: saved api_key to keychain for user #104

Merged
merged 7 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"ruamel-yaml>=0.18.6",
"click>=8.1.7",
"prompt-toolkit>=3.0.47",
"keyring>=25.4.1",
]
author = [{ name = "Block", email = "[email protected]" }]
packages = [{ include = "goose", from = "src" }]
Expand Down
54 changes: 9 additions & 45 deletions src/goose/cli/session.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
import sys
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional

from exchange import Message, ToolResult, ToolUse, Text, Exchange
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.invalid_choice_error import InvalidChoiceError
from exchange import Message, ToolResult, ToolUse, Text
from rich import print
from rich.console import RenderableType
from rich.live import Live
from rich.markdown import Markdown
from rich.panel import Panel
from rich.status import Status

from goose.build import build_exchange
from goose.cli.config import PROFILES_CONFIG_PATH, ensure_config, session_path, LOG_PATH
from goose.cli.config import ensure_config, session_path, LOG_PATH
from goose._logger import get_logger, setup_logging
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.notifier import Notifier
from goose.cli.session_notifier import SessionNotifier
from goose.profile import Profile
from goose.utils import droid, load_plugins
from goose.utils._cost_calculator import get_total_cost_message
from goose.utils._create_exchange import create_exchange
from goose.utils.session_file import read_or_create_file, save_latest_session

RESUME_MESSAGE = "I see we were interrupted. How can I help you?"
Expand Down Expand Up @@ -52,24 +47,6 @@ def load_profile(name: Optional[str]) -> Profile:
return profile


class SessionNotifier(Notifier):
Copy link
Collaborator Author

@lifeizhou-ap lifeizhou-ap Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved to a seperate file

def __init__(self, status_indicator: Status) -> None:
self.status_indicator = status_indicator
self.live = Live(self.status_indicator, refresh_per_second=8, transient=True)

def log(self, content: RenderableType) -> None:
print(content)

def status(self, status: str) -> None:
self.status_indicator.update(status)

def start(self) -> None:
self.live.start()

def stop(self) -> None:
self.live.stop()


class Session:
"""A session handler for managing interactions between a user and the Goose exchange

Expand All @@ -89,10 +66,12 @@ def __init__(
self.name = droid()
else:
self.name = name
self.profile = profile
self.profile_name = profile
self.prompt_session = GoosePromptSession()
self.status_indicator = Status("", spinner="dots")
self.notifier = SessionNotifier(self.status_indicator)
self.exchange = self._create_exchange()

self.exchange = create_exchange(profile=load_profile(profile), notifier=self.notifier)
setup_logging(log_file_directory=LOG_PATH, log_level=log_level)

self.exchange.messages.extend(self._get_initial_messages())
Expand All @@ -102,21 +81,6 @@ def __init__(

self.prompt_session = GoosePromptSession()

def _create_exchange(self) -> Exchange:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exacted to a util function

try:
return build_exchange(profile=load_profile(self.profile), notifier=self.notifier)
except MissingProviderEnvVariableError as e:
error_message = f"{e.message}. Please set the required environment variable to continue."
print(Panel(error_message, style="red"))
sys.exit(1)
except InvalidChoiceError as e:
error_message = (
f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n"
+ "Configuration doc: https://block-open-source.github.io/goose/configuration.html"
)
print(error_message)
sys.exit(1)

def _get_initial_messages(self) -> List[Message]:
messages = self.load_session()

Expand Down Expand Up @@ -162,7 +126,7 @@ def run(self) -> None:
Runs the main loop to handle user inputs and responses.
Continues until an empty string is returned from the prompt.
"""
print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile or 'default'}[/]")
print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]")
print(f"[dim]saving to {self.session_file_path}")
print()
message = self.process_first_message()
Expand Down
24 changes: 24 additions & 0 deletions src/goose/cli/session_notifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from rich.status import Status
from rich.live import Live
from rich.console import RenderableType
from rich import print

from goose.notifier import Notifier


class SessionNotifier(Notifier):
def __init__(self, status_indicator: Status) -> None:
self.status_indicator = status_indicator
self.live = Live(self.status_indicator, refresh_per_second=8, transient=True)

def log(self, content: RenderableType) -> None:
print(content)

def status(self, status: str) -> None:
self.status_indicator.update(status)

def start(self) -> None:
self.live.start()

def stop(self) -> None:
self.live.stop()
52 changes: 52 additions & 0 deletions src/goose/utils/_create_exchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
import sys
from typing import Optional
import keyring

from prompt_toolkit import prompt
from prompt_toolkit.shortcuts import confirm
from rich import print
from rich.panel import Panel

from goose.build import build_exchange
from goose.cli.config import PROFILES_CONFIG_PATH
from goose.cli.session_notifier import SessionNotifier
from goose.profile import Profile
from exchange import Exchange
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.providers.base import MissingProviderEnvVariableError


def create_exchange(profile: Profile, notifier: SessionNotifier) -> Exchange:
try:
return build_exchange(profile, notifier=notifier)
except InvalidChoiceError as e:
error_message = (
f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n"
+ "Configuration doc: https://block-open-source.github.io/goose/configuration.html"
)
print(error_message)
sys.exit(1)
except MissingProviderEnvVariableError as e:
api_key = _get_api_key_from_keychain(e.env_variable, e.provider)
if api_key is None or api_key == "":
error_message = f"{e.message}. Please set the required environment variable to continue."
print(Panel(error_message, style="red"))
sys.exit(1)
else:
os.environ[e.env_variable] = api_key
return build_exchange(profile=profile, notifier=notifier)


def _get_api_key_from_keychain(env_variable: str, provider: str) -> Optional[str]:
api_key = keyring.get_password("goose", env_variable)
if api_key is not None:
print(f"Using {env_variable} value for {provider} from your keychain")
else:
api_key = prompt(f"Enter {env_variable} value for {provider}:".strip())
if api_key is not None and len(api_key) > 0:
save_to_keyring = confirm(f"Would you like to save the {env_variable} value to your keychain?")
if save_to_keyring:
keyring.set_password("goose", env_variable, api_key)
print(f"Saved {env_variable} to your key_chain. service_name: goose, user_name: {env_variable}")
return api_key
35 changes: 1 addition & 34 deletions tests/cli/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import pytest
from exchange import Exchange, Message, ToolUse, ToolResult
from exchange.providers.base import MissingProviderEnvVariableError
from exchange.invalid_choice_error import InvalidChoiceError
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.cli.session import Session
Expand All @@ -22,7 +20,7 @@ def mock_specified_session_name():
@pytest.fixture
def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory):
with (
patch("goose.cli.session.build_exchange") as mock_exchange,
patch("goose.cli.session.create_exchange") as mock_exchange,
patch("goose.cli.session.load_profile", return_value=profile_factory()),
patch("goose.cli.session.SessionNotifier") as mock_session_notifier,
patch("goose.cli.session.load_provider", return_value="provider"),
Expand Down Expand Up @@ -158,34 +156,3 @@ def test_set_generated_session_name(create_session_with_mock_configs, mock_sessi
with patch("goose.cli.session.droid", return_value=generated_session_name):
session = create_session_with_mock_configs({"name": None})
assert session.name == generated_session_name


def test_create_exchange_exit_when_env_var_does_not_exist(create_session_with_mock_configs, mock_sessions_path):
session = create_session_with_mock_configs()
expected_error = MissingProviderEnvVariableError(env_variable="OPENAI_API_KEY", provider="openai")
with (
patch("goose.cli.session.build_exchange", side_effect=expected_error),
patch("goose.cli.session.print") as mock_print,
patch("sys.exit") as mock_exit,
):
session._create_exchange()
mock_print.call_args_list[0][0][0].renderable == (
"Missing environment variable OPENAI_API_KEY for provider openai. ",
"Please set the required environment variable to continue.",
)
mock_exit.assert_called_once_with(1)


def test_create_exchange_exit_when_configuration_is_incorrect(create_session_with_mock_configs, mock_sessions_path):
session = create_session_with_mock_configs()
expected_error = InvalidChoiceError(
attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"]
)
with (
patch("goose.cli.session.build_exchange", side_effect=expected_error),
patch("goose.cli.session.print") as mock_print,
patch("sys.exit") as mock_exit,
):
session._create_exchange()
assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0]
mock_exit.assert_called_once_with(1)
151 changes: 151 additions & 0 deletions tests/utils/test_create_exchange.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import os
from unittest.mock import MagicMock, patch

from exchange.exchange import Exchange
from exchange.invalid_choice_error import InvalidChoiceError
from exchange.providers.base import MissingProviderEnvVariableError
import pytest

from goose.notifier import Notifier
from goose.profile import Profile
from goose.utils._create_exchange import create_exchange

TEST_PROFILE = MagicMock(spec=Profile)
TEST_EXCHANGE = MagicMock(spec=Exchange)
TEST_NOTIFIER = MagicMock(spec=Notifier)


@pytest.fixture
def mock_print():
with patch("goose.utils._create_exchange.print") as mock_print:
yield mock_print


@pytest.fixture
def mock_prompt():
with patch("goose.utils._create_exchange.prompt") as mock_prompt:
yield mock_prompt


@pytest.fixture
def mock_confirm():
with patch("goose.utils._create_exchange.confirm") as mock_confirm:
yield mock_confirm


@pytest.fixture
def mock_sys_exit():
with patch("sys.exit") as mock_exit:
yield mock_exit


@pytest.fixture
def mock_keyring_get_password():
with patch("keyring.get_password") as mock_get_password:
yield mock_get_password


@pytest.fixture
def mock_keyring_set_password():
with patch("keyring.set_password") as mock_set_password:
yield mock_set_password


def test_create_exchange_success(mock_print):
with patch("goose.utils._create_exchange.build_exchange", return_value=TEST_EXCHANGE):
assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE


def test_create_exchange_fail_with_invalid_choice_error(mock_print, mock_sys_exit):
expected_error = InvalidChoiceError(
attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"]
)
with patch("goose.utils._create_exchange.build_exchange", side_effect=expected_error):
create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER)

assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0]
mock_sys_exit.assert_called_once_with(1)


class TestWhenProviderEnvVarNotFound:
API_KEY_ENV_VAR = "OPENAI_API_KEY"
API_KEY_ENV_VALUE = "api_key_value"
PROVIDER_NAME = "openai"
SERVICE_NAME = "goose"
EXPECTED_ERROR = MissingProviderEnvVariableError(env_variable=API_KEY_ENV_VAR, provider=PROVIDER_NAME)

def test_create_exchange_get_api_key_from_keychain(
self, mock_print, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password
):
self._clean_env()
with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]):
mock_keyring_get_password.return_value = self.API_KEY_ENV_VALUE

assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE

assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE
mock_keyring_get_password.assert_called_once_with(self.SERVICE_NAME, self.API_KEY_ENV_VAR)
mock_print.assert_called_once_with(
f"Using {self.API_KEY_ENV_VAR} value for {self.PROVIDER_NAME} from your keychain"
)
mock_sys_exit.assert_not_called()
mock_keyring_set_password.assert_not_called()

def test_create_exchange_ask_api_key_and_user_set_in_keychain(
self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password, mock_print
):
self._clean_env()
with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]):
mock_keyring_get_password.return_value = None
mock_prompt.return_value = self.API_KEY_ENV_VALUE
mock_confirm.return_value = True

assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE

assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE
mock_keyring_set_password.assert_called_once_with(
self.SERVICE_NAME, self.API_KEY_ENV_VAR, self.API_KEY_ENV_VALUE
)
mock_confirm.assert_called_once_with(
f"Would you like to save the {self.API_KEY_ENV_VAR} value to your keychain?"
)
mock_print.assert_called_once_with(
f"Saved {self.API_KEY_ENV_VAR} to your key_chain. "
+ f"service_name: goose, user_name: {self.API_KEY_ENV_VAR}"
)
mock_sys_exit.assert_not_called()

def test_create_exchange_ask_api_key_and_user_not_set_in_keychain(
self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password
):
self._clean_env()
with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]):
mock_keyring_get_password.return_value = None
mock_prompt.return_value = self.API_KEY_ENV_VALUE
mock_confirm.return_value = False

assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE

assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE
mock_keyring_set_password.assert_not_called()
mock_sys_exit.assert_not_called()

def test_create_exchange_fails_when_user_not_provide_api_key(
self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_print
):
self._clean_env()
with patch("goose.utils._create_exchange.build_exchange", side_effect=self.EXPECTED_ERROR):
mock_keyring_get_password.return_value = None
mock_prompt.return_value = None
mock_confirm.return_value = False

create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER)

assert (
"Please set the required environment variable to continue."
in mock_print.call_args_list[0][0][0].renderable
)
mock_sys_exit.assert_called_once_with(1)

def _clean_env(self):
os.environ.pop(self.API_KEY_ENV_VAR, None)