diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 811ed7ac8..78767115d 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -182,7 +182,7 @@ def run(self) -> None: user_input = self.prompt_session.get_user_input() message = Message.user(text=user_input.text) if user_input.to_continue() else None - self.save_session() + self.save_session() def reply(self) -> None: """Reply to the last user message, calling tools as needed diff --git a/src/goose/toolkit/language_server.py b/src/goose/toolkit/language_server.py index a6c21e6e0..65011c835 100644 --- a/src/goose/toolkit/language_server.py +++ b/src/goose/toolkit/language_server.py @@ -1,26 +1,16 @@ import functools -import types -from typing import Any, Callable, List, Optional, Tuple, Type, TypeVar +from typing import Callable, List, Optional, Tuple, Type + from goose.language_server.client import LanguageServerClient from goose.language_server.config import LanguageServerConfig from goose.language_server.logger import LanguageServerLogger +from rich.prompt import Confirm +from rich import print from goose.notifier import Notifier from goose.toolkit.base import Requirements, Toolkit, tool from goose.utils import load_plugins -def method_changes_file(func: Callable) -> Callable: - @functools.wraps(func) - def wrap_method(*args: list, **kwargs: dict) -> Callable: - print(f"Before calling {func.__name__}") - result = func(*args, **kwargs) - print(f"After calling {func.__name__}") - return result - - wrap_method._is_method = True - return wrap_method - - class LanguageServerCoordinator(Toolkit): _instance: Optional["LanguageServerCoordinator"] = None @@ -43,11 +33,33 @@ def __init__(self, notifier: Notifier, requires: Optional[Requirements] = None) language_server_config = LanguageServerConfig(trace_lsp_communication=False) self.language_server_client = LanguageServerClient() - for language_server_cls in load_plugins("goose.language_server").values(): - ls = language_server_cls.from_env(config=language_server_config, logger=language_server_logger) - self.language_server_client.register_language_server(ls) + for name, language_server_cls in load_plugins("goose.language_server").items(): + try: + ls = language_server_cls.from_env(config=language_server_config, logger=language_server_logger) + is_enabled = Confirm.ask( + f"Would you like to enable the [blue bold]{name}[/] language server?", default=True + ) + if is_enabled: + self.language_server_client.register_language_server(ls) + except Exception: + print(f"[red]Failed to initialize the {name} language server[/]") + + if not self.language_server_client.language_servers: + self.language_server_client = None + return developer_toolkit_instance = requires.get("developer") + + def method_changes_file(func: Callable) -> Callable: + @functools.wraps(func) + def wrap_method(*args: list, **kwargs: dict) -> Callable: + print() + result = func(*args, **kwargs) + return result + + wrap_method._is_method = True + return wrap_method + for method in ["write_file", "patch_file"]: decorated_method = method_changes_file(getattr(developer_toolkit_instance, method)) setattr( @@ -66,6 +78,8 @@ def request_definition(self, file_path: str, line: int, column: int) -> List[Tup line (int): The line number of the symbol. column (int): The column number of the symbol. """ + if not self.language_server_client: + NotImplementedError("No language server is available.") results = self.language_server_client.request_definition(file_path, line, column) return [dict(path=result.absolute_path, line_num=result.range.start) for result in results] @@ -80,6 +94,8 @@ def request_references(self, file_path: str, line: int, column: int) -> List[Tup line (int): The line number of the symbol. column (int): The column number of the symbol. """ + if not self.language_server_client: + NotImplementedError("No language server is available.") results = self.language_server_client.request_references(file_path, line, column) return [dict(path=result.absolute_path, line_num=result.range.start) for result in results] @@ -93,5 +109,7 @@ def request_hover(self, file_path: str, line: int, column: int) -> str | None: line (int): The line number of the symbol. column (int): The column number of the symbol. """ + if not self.language_server_client: + NotImplementedError("No language server is available.") result = self.language_server_client.request_hover(file_path, line, column).value return result.value if result is not None else None