diff --git a/src/goose/cli/prompt/create.py b/src/goose/cli/prompt/create.py index 628c86ccb..899890824 100644 --- a/src/goose/cli/prompt/create.py +++ b/src/goose/cli/prompt/create.py @@ -6,10 +6,16 @@ from goose.cli.prompt.completer import GoosePromptCompleter from goose.cli.prompt.lexer import PromptLexer -from goose.command import get_commands +from goose.command.base import Command -def create_prompt() -> PromptSession: +def create_prompt(commands: dict[str, Command]) -> PromptSession: + """ + Create a prompt session with the given commands. + + Args: + commands (dict[str, Command]): A dictionary of command names, and instances of Command classes. + """ # Define custom style style = Style.from_dict( { @@ -52,12 +58,6 @@ def _(event: KeyPressEvent) -> None: # accept completion buffer.complete_state = None - # instantiate the commands available in the prompt - commands = dict() - command_plugins = get_commands() - for command, command_cls in command_plugins.items(): - commands[command] = command_cls() - return PromptSession( completer=GoosePromptCompleter(commands=commands), lexer=PromptLexer(list(commands.keys())), diff --git a/src/goose/cli/prompt/goose_prompt_session.py b/src/goose/cli/prompt/goose_prompt_session.py index cfcedd80a..14b2cef54 100644 --- a/src/goose/cli/prompt/goose_prompt_session.py +++ b/src/goose/cli/prompt/goose_prompt_session.py @@ -1,34 +1,84 @@ from typing import Optional from prompt_toolkit import PromptSession +from prompt_toolkit.document import Document from prompt_toolkit.formatted_text import FormattedText from prompt_toolkit.validation import DummyValidator from goose.cli.prompt.create import create_prompt +from goose.cli.prompt.lexer import PromptLexer from goose.cli.prompt.prompt_validator import PromptValidator from goose.cli.prompt.user_input import PromptAction, UserInput +from goose.command import get_commands class GoosePromptSession: - def __init__(self, prompt_session: PromptSession) -> None: - self.prompt_session = prompt_session + def __init__(self) -> None: + # instantiate the commands available in the prompt + self.commands = dict() + command_plugins = get_commands() + for command, command_cls in command_plugins.items(): + self.commands[command] = command_cls() + self.main_prompt_session = create_prompt(self.commands) + self.text_prompt_session = PromptSession() - @staticmethod - def create_prompt_session() -> "GoosePromptSession": - return GoosePromptSession(create_prompt()) + def get_message_after_commands(self, message: str) -> str: + lexer = PromptLexer(command_names=list(self.commands.keys())) + doc = Document(message) + lines = [] + # iterate through each line of the document + for line_num in range(len(doc.lines)): + classes_in_line = lexer.lex_document(doc)(line_num) + line_result = [] + i = 0 + while i < len(classes_in_line): + # if a command is found and it is not the last part of the line + if classes_in_line[i][0] == "class:command" and i + 1 < len(classes_in_line): + # extract the command name + command_name = classes_in_line[i][1].strip("/").strip(":") + # get the value following the command + if classes_in_line[i + 1][0] == "class:parameter": + command_value = classes_in_line[i + 1][1] + else: + command_value = "" + + # execute the command with the given argument, expecting a return value + value_after_execution = self.commands[command_name].execute(command_value, message) + + # if the command returns None, raise an error - this should never happen + # since the command should always return a string + if value_after_execution is None: + raise ValueError(f"Command {command_name} returned None") + + # append the result of the command execution to the line results + line_result.append(value_after_execution) + i += 1 + + # if the part is plain text, just append it to the line results + elif classes_in_line[i][0] == "class:text": + line_result.append(classes_in_line[i][1]) + i += 1 + + # join all processed parts of the current line and add it to the lines list + lines.append("".join(line_result)) + + # join all processed lines into a single string with newline characters and return + return "\n".join(lines) def get_user_input(self) -> "UserInput": try: text = FormattedText([("#00AEAE", "G❯ ")]) # Define the prompt style and text. - message = self.prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False) + message = self.main_prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False) if message.strip() in ("exit", ":q"): return UserInput(PromptAction.EXIT) + + message = self.get_message_after_commands(message) return UserInput(PromptAction.CONTINUE, message) except (EOFError, KeyboardInterrupt): return UserInput(PromptAction.EXIT) def get_save_session_name(self) -> Optional[str]: - return self.prompt_session.prompt( + return self.text_prompt_session.prompt( "Enter a name to save this session under. A name will be generated for you if empty: ", validator=DummyValidator(), - ) + ).strip(" ") diff --git a/src/goose/cli/prompt/lexer.py b/src/goose/cli/prompt/lexer.py index b21cae7ac..0e2bb0c91 100644 --- a/src/goose/cli/prompt/lexer.py +++ b/src/goose/cli/prompt/lexer.py @@ -5,6 +5,11 @@ from prompt_toolkit.lexers import Lexer +# These are lexers for the commands in the prompt. This is how we +# are extracting the different parts of a command (here, used for styling), +# but likely will be used to parse the command as well in the future. + + def completion_for_command(target_string: str) -> re.Pattern[str]: escaped_string = re.escape(target_string) vals = [f"(?:{escaped_string[:i]}(?=$))" for i in range(len(escaped_string), 0, -1)] @@ -13,22 +18,21 @@ def completion_for_command(target_string: str) -> re.Pattern[str]: def command_itself(target_string: str) -> re.Pattern[str]: escaped_string = re.escape(target_string) - return re.compile(rf"(? re.Pattern[str]: - escaped_string = re.escape(command_string) - return re.compile(rf"(?<=(? None: self.patterns = [] for command_name in command_names: - full_command = command_name + ":" - self.patterns.append((completion_for_command(full_command), "class:command")) - self.patterns.append((value_for_command(full_command), "class:parameter")) - self.patterns.append((command_itself(full_command), "class:command")) + self.patterns.append((completion_for_command(command_name), "class:command")) + self.patterns.append((value_for_command(command_name), "class:parameter")) + self.patterns.append((command_itself(command_name), "class:command")) def lex_document(self, document: Document) -> Callable[[int], list]: def get_line_tokens(line_number: int) -> Tuple[str, str]: diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 4888e6154..1f72a3020 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -121,7 +121,7 @@ def __init__( if len(self.exchange.messages) == 0 and plan: self.setup_plan(plan=plan) - self.prompt_session = GoosePromptSession.create_prompt_session() + self.prompt_session = GoosePromptSession() def setup_plan(self, plan: dict) -> None: if len(self.exchange.messages): diff --git a/src/goose/command/base.py b/src/goose/command/base.py index dbf2d19db..2ed4ac2d0 100644 --- a/src/goose/command/base.py +++ b/src/goose/command/base.py @@ -8,9 +8,20 @@ class Command(ABC): """A command that can be executed by the CLI.""" def get_completions(self, query: str) -> List[Completion]: - """Get completions for the command.""" + """ + Get completions for the command. + + Args: + query (str): The current query. + """ return [] - def execute(self, query: str) -> Optional[str]: - """Execute's the command and replaces it with the output.""" + def execute(self, query: str, surrounding_context: str) -> Optional[str]: + """ + Execute's the command and replaces it with the output. + + Args: + query (str): The query to execute. + surrounding_context (str): The full user message that the query is a part of. + """ return "" diff --git a/src/goose/command/file.py b/src/goose/command/file.py index 7bbf7d9e3..95b10d53f 100644 --- a/src/goose/command/file.py +++ b/src/goose/command/file.py @@ -56,6 +56,5 @@ def get_completions(self, query: str) -> List[Completion]: ) return completions - def execute(self, query: str) -> str | None: - # GOOSE-TODO: return the query - pass + def execute(self, query: str, _: str) -> str | None: + return query diff --git a/tests/cli/prompt/test_goose_prompt_session.py b/tests/cli/prompt/test_goose_prompt_session.py index eca44cc67..1c9578fa2 100644 --- a/tests/cli/prompt/test_goose_prompt_session.py +++ b/tests/cli/prompt/test_goose_prompt_session.py @@ -1,5 +1,6 @@ from unittest.mock import patch +from prompt_toolkit import PromptSession import pytest from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput @@ -7,41 +8,48 @@ @pytest.fixture def mock_prompt_session(): - with patch("prompt_toolkit.PromptSession") as mock_prompt_session: + with patch("goose.cli.prompt.goose_prompt_session.PromptSession") as mock_prompt_session: yield mock_prompt_session def test_get_save_session_name(mock_prompt_session): - mock_prompt_session.prompt.return_value = "my_session" - goose_prompt_session = GoosePromptSession(mock_prompt_session) + mock_prompt_session.return_value.prompt.return_value = "my_session" + goose_prompt_session = GoosePromptSession() assert goose_prompt_session.get_save_session_name() == "my_session" -def test_get_user_input_to_continue(mock_prompt_session): - mock_prompt_session.prompt.return_value = "input_value" - goose_prompt_session = GoosePromptSession(mock_prompt_session) +def test_get_save_session_name_with_space(mock_prompt_session): + mock_prompt_session.return_value.prompt.return_value = "my_session " + goose_prompt_session = GoosePromptSession() - user_input = goose_prompt_session.get_user_input() + assert goose_prompt_session.get_save_session_name() == "my_session" + + +def test_get_user_input_to_continue(): + with patch.object(PromptSession, "prompt", return_value="input_value"): + goose_prompt_session = GoosePromptSession() + + user_input = goose_prompt_session.get_user_input() - assert user_input == UserInput(PromptAction.CONTINUE, "input_value") + assert user_input == UserInput(PromptAction.CONTINUE, "input_value") @pytest.mark.parametrize("exit_input", ["exit", ":q"]) def test_get_user_input_to_exit(exit_input, mock_prompt_session): - mock_prompt_session.prompt.return_value = exit_input - goose_prompt_session = GoosePromptSession(mock_prompt_session) + with patch.object(PromptSession, "prompt", return_value=exit_input): + goose_prompt_session = GoosePromptSession() - user_input = goose_prompt_session.get_user_input() + user_input = goose_prompt_session.get_user_input() - assert user_input == UserInput(PromptAction.EXIT) + assert user_input == UserInput(PromptAction.EXIT) @pytest.mark.parametrize("error", [EOFError, KeyboardInterrupt]) def test_get_user_input_to_exit_when_error_occurs(error, mock_prompt_session): - mock_prompt_session.prompt.side_effect = error - goose_prompt_session = GoosePromptSession(mock_prompt_session) + with patch.object(PromptSession, "prompt", side_effect=error): + goose_prompt_session = GoosePromptSession() - user_input = goose_prompt_session.get_user_input() + user_input = goose_prompt_session.get_user_input() - assert user_input == UserInput(PromptAction.EXIT) + assert user_input == UserInput(PromptAction.EXIT) diff --git a/tests/cli/prompt/test_lexer.py b/tests/cli/prompt/test_lexer.py index 585bead9b..790bed40b 100644 --- a/tests/cli/prompt/test_lexer.py +++ b/tests/cli/prompt/test_lexer.py @@ -232,22 +232,45 @@ def test_lex_document_ending_char_of_parameter_is_symbol(): assert actual_tokens == expected_tokens -def test_command_itself(): - pattern = command_itself("file:") - matches = pattern.match("/file:example.txt") +def assert_pattern_matches(pattern, text, expected_group): + matches = pattern.search(text) assert matches is not None - assert matches.group(1) == "/file:" + assert matches.group() == expected_group + + +def test_command_itself(): + pattern = command_itself("file") + assert_pattern_matches(pattern, "/file:example.txt", "/file:") + assert_pattern_matches(pattern, "/file asdf", "/file") + assert_pattern_matches(pattern, "some /file", "/file") + assert_pattern_matches(pattern, "some /file:", "/file:") + assert_pattern_matches(pattern, "/file /file", "/file") + + assert pattern.search("file") is None + assert pattern.search("/anothercommand") is None def test_value_for_command(): - pattern = value_for_command("file:") - matches = pattern.search("/file:example.txt") - assert matches is not None - assert matches.group(1) == "example.txt" + pattern = value_for_command("file") + assert_pattern_matches(pattern, "/file:example.txt", "example.txt") + assert_pattern_matches(pattern, '/file:"example space.txt"', '"example space.txt"') + assert_pattern_matches(pattern, '/file:"example.txt" some other string', '"example.txt"') + assert_pattern_matches(pattern, "something before /file:example.txt", "example.txt") + + # assert no pattern matches when there is no value + assert pattern.search("/file:").group() == "" + assert pattern.search("/file: other").group() == "" + assert pattern.search("/file: ").group() == "" + assert pattern.search("/file other") is None def test_completion_for_command(): - pattern = completion_for_command("file:") - matches = pattern.search("/file:") - assert matches is not None - assert matches.group(1) == "file:" + pattern = completion_for_command("file") + assert_pattern_matches(pattern, "/file", "/file") + assert_pattern_matches(pattern, "/fi", "/fi") + assert_pattern_matches(pattern, "before /fi", "/fi") + assert_pattern_matches(pattern, "some /f", "/f") + + assert pattern.search("/file after") is None + assert pattern.search("/ file") is None + assert pattern.search("/file:") is None