Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add execute method to commands #17

Closed
wants to merge 14 commits into from
16 changes: 8 additions & 8 deletions src/goose/cli/prompt/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@

from goose.cli.prompt.completer import GoosePromptCompleter
from goose.cli.prompt.lexer import PromptLexer
from goose.command import get_commands
from goose.command.base import Command


def create_prompt() -> PromptSession:
def create_prompt(commands: dict[str, Command]) -> PromptSession:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is an old function but can we get a docstring since it's part of your PR 🙏

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for sure, good call!

"""
Create a prompt session with the given commands.

Args:
commands (dict[str, Command]): A dictionary of command names, and instances of Command classes.
"""
# Define custom style
style = Style.from_dict(
{
Expand Down Expand Up @@ -52,12 +58,6 @@ def _(event: KeyPressEvent) -> None:
# accept completion
buffer.complete_state = None

# instantiate the commands available in the prompt
commands = dict()
command_plugins = get_commands()
for command, command_cls in command_plugins.items():
commands[command] = command_cls()

return PromptSession(
completer=GoosePromptCompleter(commands=commands),
lexer=PromptLexer(list(commands.keys())),
Expand Down
66 changes: 58 additions & 8 deletions src/goose/cli/prompt/goose_prompt_session.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,84 @@
from typing import Optional

from prompt_toolkit import PromptSession
from prompt_toolkit.document import Document
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.validation import DummyValidator

from goose.cli.prompt.create import create_prompt
from goose.cli.prompt.lexer import PromptLexer
from goose.cli.prompt.prompt_validator import PromptValidator
from goose.cli.prompt.user_input import PromptAction, UserInput
from goose.command import get_commands


class GoosePromptSession:
def __init__(self, prompt_session: PromptSession) -> None:
self.prompt_session = prompt_session
def __init__(self) -> None:
# instantiate the commands available in the prompt
self.commands = dict()
command_plugins = get_commands()
for command, command_cls in command_plugins.items():
self.commands[command] = command_cls()
self.main_prompt_session = create_prompt(self.commands)
self.text_prompt_session = PromptSession()

@staticmethod
def create_prompt_session() -> "GoosePromptSession":
return GoosePromptSession(create_prompt())
def get_message_after_commands(self, message: str) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some comments to make the flow of logic clearer to the uninitiated but feel free to ignore!

lexer = PromptLexer(command_names=list(self.commands.keys()))
doc = Document(message)
lines = []
# iterate through each line of the document
for line_num in range(len(doc.lines)):
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved
classes_in_line = lexer.lex_document(doc)(line_num)
line_result = []
i = 0
while i < len(classes_in_line):
# if a command is found and it is not the last part of the line
if classes_in_line[i][0] == "class:command" and i + 1 < len(classes_in_line):
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved
# extract the command name
command_name = classes_in_line[i][1].strip("/").strip(":")
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved
# get the value following the command
if classes_in_line[i + 1][0] == "class:parameter":
command_value = classes_in_line[i + 1][1]
else:
command_value = ""

# execute the command with the given argument, expecting a return value
value_after_execution = self.commands[command_name].execute(command_value, message)
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved

# if the command returns None, raise an error - this should never happen
# since the command should always return a string
if value_after_execution is None:
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Command {command_name} returned None")

# append the result of the command execution to the line results
line_result.append(value_after_execution)
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved
i += 1

# if the part is plain text, just append it to the line results
elif classes_in_line[i][0] == "class:text":
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved
line_result.append(classes_in_line[i][1])
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved
i += 1

# join all processed parts of the current line and add it to the lines list
lines.append("".join(line_result))

# join all processed lines into a single string with newline characters and return
return "\n".join(lines)
lukealvoeiro marked this conversation as resolved.
Show resolved Hide resolved

def get_user_input(self) -> "UserInput":
try:
text = FormattedText([("#00AEAE", "G❯ ")]) # Define the prompt style and text.
message = self.prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
message = self.main_prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False)
if message.strip() in ("exit", ":q"):
return UserInput(PromptAction.EXIT)

message = self.get_message_after_commands(message)
return UserInput(PromptAction.CONTINUE, message)
except (EOFError, KeyboardInterrupt):
return UserInput(PromptAction.EXIT)

def get_save_session_name(self) -> Optional[str]:
return self.prompt_session.prompt(
return self.text_prompt_session.prompt(
"Enter a name to save this session under. A name will be generated for you if empty: ",
validator=DummyValidator(),
)
).strip(" ")
18 changes: 11 additions & 7 deletions src/goose/cli/prompt/lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from prompt_toolkit.lexers import Lexer


# These are lexers for the commands in the prompt. This is how we
# are extracting the different parts of a command (here, used for styling),
# but likely will be used to parse the command as well in the future.


def completion_for_command(target_string: str) -> re.Pattern[str]:
escaped_string = re.escape(target_string)
vals = [f"(?:{escaped_string[:i]}(?=$))" for i in range(len(escaped_string), 0, -1)]
Expand All @@ -13,22 +18,21 @@ def completion_for_command(target_string: str) -> re.Pattern[str]:

def command_itself(target_string: str) -> re.Pattern[str]:
escaped_string = re.escape(target_string)
return re.compile(rf"(?<!\S)(\/{escaped_string})")
return re.compile(rf"(?<!\S)(\/{escaped_string}:?)")


def value_for_command(command_string: str) -> re.Pattern[str]:
escaped_string = re.escape(command_string)
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})([^\s]*)")
escaped_string = re.escape(command_string + ":")
return re.compile(rf"(?<=(?<!\S)\/{escaped_string})(?:(?:\"(.*?)(\"|$))|([^\s]*))")


class PromptLexer(Lexer):
def __init__(self, command_names: List[str]) -> None:
self.patterns = []
for command_name in command_names:
full_command = command_name + ":"
self.patterns.append((completion_for_command(full_command), "class:command"))
self.patterns.append((value_for_command(full_command), "class:parameter"))
self.patterns.append((command_itself(full_command), "class:command"))
self.patterns.append((completion_for_command(command_name), "class:command"))
self.patterns.append((value_for_command(command_name), "class:parameter"))
self.patterns.append((command_itself(command_name), "class:command"))

def lex_document(self, document: Document) -> Callable[[int], list]:
def get_line_tokens(line_number: int) -> Tuple[str, str]:
Expand Down
2 changes: 1 addition & 1 deletion src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(
if len(self.exchange.messages) == 0 and plan:
self.setup_plan(plan=plan)

self.prompt_session = GoosePromptSession.create_prompt_session()
self.prompt_session = GoosePromptSession()

def setup_plan(self, plan: dict) -> None:
if len(self.exchange.messages):
Expand Down
17 changes: 14 additions & 3 deletions src/goose/command/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,20 @@ class Command(ABC):
"""A command that can be executed by the CLI."""

def get_completions(self, query: str) -> List[Completion]:
"""Get completions for the command."""
"""
Get completions for the command.

Args:
query (str): The current query.
"""
return []

def execute(self, query: str) -> Optional[str]:
"""Execute's the command and replaces it with the output."""
def execute(self, query: str, surrounding_context: str) -> Optional[str]:
"""
Execute's the command and replaces it with the output.

Args:
query (str): The query to execute.
surrounding_context (str): The full user message that the query is a part of.
"""
return ""
lily-de marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 2 additions & 3 deletions src/goose/command/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,5 @@ def get_completions(self, query: str) -> List[Completion]:
)
return completions

def execute(self, query: str) -> str | None:
# GOOSE-TODO: return the query
pass
def execute(self, query: str, _: str) -> str | None:
return query
40 changes: 24 additions & 16 deletions tests/cli/prompt/test_goose_prompt_session.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,55 @@
from unittest.mock import patch

from prompt_toolkit import PromptSession
import pytest
from goose.cli.prompt.goose_prompt_session import GoosePromptSession
from goose.cli.prompt.user_input import PromptAction, UserInput


@pytest.fixture
def mock_prompt_session():
with patch("prompt_toolkit.PromptSession") as mock_prompt_session:
with patch("goose.cli.prompt.goose_prompt_session.PromptSession") as mock_prompt_session:
yield mock_prompt_session


def test_get_save_session_name(mock_prompt_session):
mock_prompt_session.prompt.return_value = "my_session"
goose_prompt_session = GoosePromptSession(mock_prompt_session)
mock_prompt_session.return_value.prompt.return_value = "my_session"
goose_prompt_session = GoosePromptSession()

assert goose_prompt_session.get_save_session_name() == "my_session"


def test_get_user_input_to_continue(mock_prompt_session):
mock_prompt_session.prompt.return_value = "input_value"
goose_prompt_session = GoosePromptSession(mock_prompt_session)
def test_get_save_session_name_with_space(mock_prompt_session):
mock_prompt_session.return_value.prompt.return_value = "my_session "
goose_prompt_session = GoosePromptSession()

user_input = goose_prompt_session.get_user_input()
assert goose_prompt_session.get_save_session_name() == "my_session"


def test_get_user_input_to_continue():
with patch.object(PromptSession, "prompt", return_value="input_value"):
goose_prompt_session = GoosePromptSession()

user_input = goose_prompt_session.get_user_input()

assert user_input == UserInput(PromptAction.CONTINUE, "input_value")
assert user_input == UserInput(PromptAction.CONTINUE, "input_value")


@pytest.mark.parametrize("exit_input", ["exit", ":q"])
def test_get_user_input_to_exit(exit_input, mock_prompt_session):
mock_prompt_session.prompt.return_value = exit_input
goose_prompt_session = GoosePromptSession(mock_prompt_session)
with patch.object(PromptSession, "prompt", return_value=exit_input):
goose_prompt_session = GoosePromptSession()

user_input = goose_prompt_session.get_user_input()
user_input = goose_prompt_session.get_user_input()

assert user_input == UserInput(PromptAction.EXIT)
assert user_input == UserInput(PromptAction.EXIT)


@pytest.mark.parametrize("error", [EOFError, KeyboardInterrupt])
def test_get_user_input_to_exit_when_error_occurs(error, mock_prompt_session):
mock_prompt_session.prompt.side_effect = error
goose_prompt_session = GoosePromptSession(mock_prompt_session)
with patch.object(PromptSession, "prompt", side_effect=error):
goose_prompt_session = GoosePromptSession()

user_input = goose_prompt_session.get_user_input()
user_input = goose_prompt_session.get_user_input()

assert user_input == UserInput(PromptAction.EXIT)
assert user_input == UserInput(PromptAction.EXIT)
47 changes: 35 additions & 12 deletions tests/cli/prompt/test_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,22 +232,45 @@ def test_lex_document_ending_char_of_parameter_is_symbol():
assert actual_tokens == expected_tokens


def test_command_itself():
pattern = command_itself("file:")
matches = pattern.match("/file:example.txt")
def assert_pattern_matches(pattern, text, expected_group):
matches = pattern.search(text)
assert matches is not None
assert matches.group(1) == "/file:"
assert matches.group() == expected_group


def test_command_itself():
pattern = command_itself("file")
assert_pattern_matches(pattern, "/file:example.txt", "/file:")
assert_pattern_matches(pattern, "/file asdf", "/file")
assert_pattern_matches(pattern, "some /file", "/file")
assert_pattern_matches(pattern, "some /file:", "/file:")
assert_pattern_matches(pattern, "/file /file", "/file")

assert pattern.search("file") is None
assert pattern.search("/anothercommand") is None


def test_value_for_command():
pattern = value_for_command("file:")
matches = pattern.search("/file:example.txt")
assert matches is not None
assert matches.group(1) == "example.txt"
pattern = value_for_command("file")
assert_pattern_matches(pattern, "/file:example.txt", "example.txt")
assert_pattern_matches(pattern, '/file:"example space.txt"', '"example space.txt"')
assert_pattern_matches(pattern, '/file:"example.txt" some other string', '"example.txt"')
assert_pattern_matches(pattern, "something before /file:example.txt", "example.txt")

# assert no pattern matches when there is no value
assert pattern.search("/file:").group() == ""
assert pattern.search("/file: other").group() == ""
assert pattern.search("/file: ").group() == ""
assert pattern.search("/file other") is None


def test_completion_for_command():
pattern = completion_for_command("file:")
matches = pattern.search("/file:")
assert matches is not None
assert matches.group(1) == "file:"
pattern = completion_for_command("file")
assert_pattern_matches(pattern, "/file", "/file")
assert_pattern_matches(pattern, "/fi", "/fi")
assert_pattern_matches(pattern, "before /fi", "/fi")
assert_pattern_matches(pattern, "some /f", "/f")

assert pattern.search("/file after") is None
assert pattern.search("/ file") is None
assert pattern.search("/file:") is None