Skip to content

Commit

Permalink
wrapper working now!
Browse files Browse the repository at this point in the history
  • Loading branch information
lukealvoeiro committed Sep 26, 2024
1 parent cf3190f commit 6f0a944
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 8 deletions.
9 changes: 5 additions & 4 deletions src/goose/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ def ensure_config(name: str) -> Profile:
+ pretty_diff(before.read(), after.read())
)
)
should_update = Confirm.ask(
"Do you want to update your profile to use the latest?",
default=False,
)
# should_update = Confirm.ask(
# "Do you want to update your profile to use the latest?",
# default=False,
# )
should_update = False
if should_update:
profiles[name] = profile
write_config(profiles)
Expand Down
2 changes: 1 addition & 1 deletion src/goose/language_server/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from goose.language_server.core.lsp_constants import LSPConstants
from goose.language_server.core import lsp_types

import goose.language_server.types as language_server_types
from goose.language_server import language_server_types
from goose.language_server.logger import LanguageServerLogger
from goose.language_server.core.server import (
LanguageServerHandler,
Expand Down
2 changes: 1 addition & 1 deletion src/goose/language_server/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from goose.language_server.base import LanguageServer
from goose.language_server.type_helpers import ensure_all_methods_implemented
import goose.language_server.types as language_server_types
from goose.language_server import language_server_types
from goose.utils.language import Language
from typing import Any, Callable, Iterator, List, Tuple, TypeVar, Union

Expand Down
File renamed without changes.
6 changes: 5 additions & 1 deletion src/goose/toolkit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,9 @@ def tools(self) -> Tuple[Tool, ...]:
This default method looks for functions on the toolkit annotated
with @tool.
"""
candidates = inspect.getmembers(self, predicate=inspect.ismethod)

def predicate(obj: object) -> bool:
return inspect.ismethod(obj) or (hasattr(obj, "_is_method") and obj._is_method)

candidates = inspect.getmembers(self, predicate=predicate)
return (Tool.from_function(candidate) for _, candidate in candidates if getattr(candidate, "_is_tool", None))
25 changes: 24 additions & 1 deletion src/goose/toolkit/language_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Optional, Tuple, Type
import functools
import types
from typing import Any, Callable, List, Optional, Tuple, Type, TypeVar
from goose.language_server.client import LanguageServerClient
from goose.language_server.config import LanguageServerConfig
from goose.language_server.logger import LanguageServerLogger
Expand All @@ -7,6 +9,18 @@
from goose.utils import load_plugins


def method_changes_file(func: Callable) -> Callable:
@functools.wraps(func)
def wrap_method(*args: list, **kwargs: dict) -> Callable:
print(f"Before calling {func.__name__}")
result = func(*args, **kwargs)
print(f"After calling {func.__name__}")
return result

wrap_method._is_method = True
return wrap_method


class LanguageServerCoordinator(Toolkit):
_instance: Optional["LanguageServerCoordinator"] = None

Expand All @@ -33,6 +47,15 @@ def __init__(self, notifier: Notifier, requires: Optional[Requirements] = None)
ls = language_server_cls.from_env(config=language_server_config, logger=language_server_logger)
self.language_server_client.register_language_server(ls)

developer_toolkit_instance = requires.get("developer")
for method in ["write_file", "patch_file"]:
decorated_method = method_changes_file(getattr(developer_toolkit_instance, method))
setattr(
developer_toolkit_instance,
method,
decorated_method,
)

@tool
def request_definition(self, file_path: str, line: int, column: int) -> List[Tuple[str, int]]:
"""
Expand Down

0 comments on commit 6f0a944

Please sign in to comment.