Skip to content

Commit

Permalink
fix initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lukealvoeiro committed Sep 26, 2024
1 parent 6f0a944 commit 7e72a84
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/goose/cli/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def run(self) -> None:
user_input = self.prompt_session.get_user_input()
message = Message.user(text=user_input.text) if user_input.to_continue() else None

self.save_session()
self.save_session()

def reply(self) -> None:
"""Reply to the last user message, calling tools as needed
Expand Down
52 changes: 35 additions & 17 deletions src/goose/toolkit/language_server.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,16 @@
import functools
import types
from typing import Any, Callable, List, Optional, Tuple, Type, TypeVar
from typing import Callable, List, Optional, Tuple, Type

from goose.language_server.client import LanguageServerClient
from goose.language_server.config import LanguageServerConfig
from goose.language_server.logger import LanguageServerLogger
from rich.prompt import Confirm
from rich import print
from goose.notifier import Notifier
from goose.toolkit.base import Requirements, Toolkit, tool
from goose.utils import load_plugins


def method_changes_file(func: Callable) -> Callable:
@functools.wraps(func)
def wrap_method(*args: list, **kwargs: dict) -> Callable:
print(f"Before calling {func.__name__}")
result = func(*args, **kwargs)
print(f"After calling {func.__name__}")
return result

wrap_method._is_method = True
return wrap_method


class LanguageServerCoordinator(Toolkit):
_instance: Optional["LanguageServerCoordinator"] = None

Expand All @@ -43,11 +33,33 @@ def __init__(self, notifier: Notifier, requires: Optional[Requirements] = None)
language_server_config = LanguageServerConfig(trace_lsp_communication=False)
self.language_server_client = LanguageServerClient()

for language_server_cls in load_plugins("goose.language_server").values():
ls = language_server_cls.from_env(config=language_server_config, logger=language_server_logger)
self.language_server_client.register_language_server(ls)
for name, language_server_cls in load_plugins("goose.language_server").items():
try:
ls = language_server_cls.from_env(config=language_server_config, logger=language_server_logger)
is_enabled = Confirm.ask(
f"Would you like to enable the [blue bold]{name}[/] language server?", default=True
)
if is_enabled:
self.language_server_client.register_language_server(ls)
except Exception:
print(f"[red]Failed to initialize the {name} language server[/]")

if not self.language_server_client.language_servers:
self.language_server_client = None
return

developer_toolkit_instance = requires.get("developer")

def method_changes_file(func: Callable) -> Callable:
@functools.wraps(func)
def wrap_method(*args: list, **kwargs: dict) -> Callable:
print()
result = func(*args, **kwargs)
return result

wrap_method._is_method = True
return wrap_method

for method in ["write_file", "patch_file"]:
decorated_method = method_changes_file(getattr(developer_toolkit_instance, method))
setattr(
Expand All @@ -66,6 +78,8 @@ def request_definition(self, file_path: str, line: int, column: int) -> List[Tup
line (int): The line number of the symbol.
column (int): The column number of the symbol.
"""
if not self.language_server_client:
NotImplementedError("No language server is available.")
results = self.language_server_client.request_definition(file_path, line, column)

return [dict(path=result.absolute_path, line_num=result.range.start) for result in results]
Expand All @@ -80,6 +94,8 @@ def request_references(self, file_path: str, line: int, column: int) -> List[Tup
line (int): The line number of the symbol.
column (int): The column number of the symbol.
"""
if not self.language_server_client:
NotImplementedError("No language server is available.")
results = self.language_server_client.request_references(file_path, line, column)
return [dict(path=result.absolute_path, line_num=result.range.start) for result in results]

Expand All @@ -93,5 +109,7 @@ def request_hover(self, file_path: str, line: int, column: int) -> str | None:
line (int): The line number of the symbol.
column (int): The column number of the symbol.
"""
if not self.language_server_client:
NotImplementedError("No language server is available.")
result = self.language_server_client.request_hover(file_path, line, column).value
return result.value if result is not None else None

0 comments on commit 7e72a84

Please sign in to comment.