diff --git a/.gitignore b/.gitignore index f799b7221..807d19890 100644 --- a/.gitignore +++ b/.gitignore @@ -103,6 +103,9 @@ celerybeat.pid .env.* .venv +# exception for local langfuse init vars +!**/packages/exchange/.env.langfuse.local + # Spyder project settings .spyderproject .spyproject diff --git a/.goosehints b/.goosehints index 8b6535a63..bb6acad17 100644 --- a/.goosehints +++ b/.goosehints @@ -1,3 +1,13 @@ This is a python CLI app that uses UV. Read CONTRIBUTING.md for information on how to build and test it as needed. -Some key concepts are that it is run as a command line interface, dependes on the "ai-exchange" package, and has the concept of toolkits which are ways that its behavior can be extended. Look in src/goose and tests. -Once the user has UV installed it should be able to be used effectively along with uvx to run tasks as needed + +Some key concepts are that it is run as a command line interface, dependes on the "ai-exchange" package (which is in packages/exchange in this repo), and has the concept of toolkits which are ways that its behavior can be extended. Look in src/goose and tests. + +Assume the user has UV installed and ensure UV is used to run any python related commands. + +To run tests: + +```sh +uv sync && uv run pytest tests -m 'not integration' +``` + +ideally after each change \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 18cef1199..76feea9e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,44 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.5] - 2024-10-15 +- chore: updates ollama default model from mistral-nemo to qwen2.5 (#150) +- feat: add vision support for Google (#141) +- fix: session resume with arg handled incorrectly (#145) +- docs: add release instructions to CONTRIBUTING.md (#143) +- docs: add link to action, IDE words (#140) +- docs: goosehints doc fix only (#142) + +## [0.9.4] - 2024-10-10 + +- revert: "feat: add local langfuse tracing option (#106)" +- feat: add local langfuse tracing option (#106) +- feat: add groq provider (#134) +- feat: add a deep thinking reasoner model (o1-preview/mini) (#68) +- fix: use concrete SessionNotifier (#135) +- feat: add guards to session management (#101) +- fix: Set default model configuration for the Google provider. (#131) +- test: convert Google Gemini tests to VCR (#118) +- chore: Add goose providers list command (#116) +- docs: working ollama for desktop (#125) +- docs: format and clean up warnings/errors (#120) +- docs: update deploy workflow (#124) +- feat: Implement a goose run command (#121) +- feat: saved api_key to keychain for user (#104) +- docs: add callout plugin (#119) +- chore: add a page to docs for Goose application examples (#117) +- fix: exit the goose and show the error message when provider environment variable is not set (#103) +- fix: Update OpenAI pricing per https://openai.com/api/pricing/ (#110) +- fix: update developer tool prompts to use plan task status to match allowable statuses update_plan tool call (#107) +- fix: removed the panel in the output so that the user won't have unnecessary pane borders in the copied content (#109) +- docs: update links to exchange to the new location (#108) +- chore: setup workspace for exchange (#105) +- fix: resolve uvx when using a git client or IDE (#98) +- ci: add include-markdown for mkdocs (#100) +- chore: fix broken badge on readme (#102) +- feat: add global optional user goosehints file (#73) +- docs: update docs (#99) + ## [0.9.3] - 2024-09-25 - feat: auto save sessions before next user input (#94) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 783a010a6..deea38211 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,6 +48,21 @@ or, as a shortcut, just test ``` +### Enable traces in Goose with [locally hosted Langfuse](https://langfuse.com/docs/deployment/self-host) +> [!NOTE] +> This integration is experimental and we don't currently have integration tests for it. + +Developers can use locally hosted Langfuse tracing by applying the custom `observe_wrapper` decorator defined in `packages/exchange/src/langfuse_wrapper.py` to functions for automatic integration with Langfuse. + +- Run `just langfuse-server` to start your local Langfuse server. It requires Docker. +- Go to http://localhost:3000 and log in with the default email/password output by the shell script (values can also be found in the `.env.langfuse.local` file). +- Run Goose with the --tracing flag enabled i.e., `goose session start --tracing` +- View your traces at http://localhost:3000 + +To extend tracing to additional functions, import `from exchange.langfuse_wrapper import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator. + +Read more about Langfuse's decorator-based tracing [here](https://langfuse.com/docs/sdk/python/decorators). + ## Exchange The lower level generation behind goose is powered by the [`exchange`][ai-exchange] package, also in this repo. @@ -73,6 +88,16 @@ Additions to the [developer toolkit][developer] change the core performance, and This project follows the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification for PR titles. Conventional Commits make it easier to understand the history of a project and facilitate automation around versioning and changelog generation. +## Release + +In order to release a new version of goose, you need to do the following: +1. Update CHANGELOG.md. To get the commit messages since last release, run: `just release-notes` +2. Update version in `pyproject.toml` for `goose` and package dependencies such as `exchange` +3. Create a PR and merge it into main branch +4. Tag the HEAD commit in main branch. To do this, switch to main branch and run: `just tag-push` +5. Publish a new release from the [Github Release UI](https://github.com/block-open-source/goose/releases) + + [issues]: https://github.com/block-open-source/goose/issues [goose-plugins]: https://github.com/block-open-source/goose-plugins [ai-exchange]: https://github.com/block-open-source/goose/tree/main/packages/exchange diff --git a/README.md b/README.md index 6f980ff76..0cbaf35e2 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,6 @@ To install Goose, use `pipx`. First ensure [pipx][pipx] is installed: brew install pipx pipx ensurepath ``` -You can also place `.goosehints` in `~/.config/goose/.goosehints` if you like for always loaded hints personal to you. Then install Goose: @@ -131,7 +130,21 @@ You will see the Goose prompt `G❯`: G❯ type your instructions here exactly as you would tell a developer. ``` -Now you are interacting with Goose in conversational sessions - something like a natural language driven code interpreter. The default toolkit allows Goose to take actions through shell commands and file edits. You can interrupt Goose with `CTRL+D` or `ESC+Enter` at any time to help redirect its efforts. +Now you are interacting with Goose in conversational sessions - think of it as like giving direction to a junior developer. The default toolkit allows Goose to take actions through shell commands and file edits. You can interrupt Goose with `CTRL+D` or `ESC+Enter` at any time to help redirect its efforts. + +> [!TIP] +> You can place a `.goosehints` text file in any directory you launch goose from to give it some background info for new sessions in plain language (eg how to test, what instructions to read to get started or just tell it to read the README!) You can also put a global one `~/.config/goose/.goosehints` if you like for always loaded hints personal to you. + +### Running a goose tasks (one off) + +You can run goose to do things just as a one off, such as tidying up, and then exiting: + +```sh +goose run instructions.md +``` + +This will run until completion as best it can. You can also pass `--resume-session` and it will re-use the first session it finds for context + #### Exit the session @@ -147,16 +160,55 @@ goose session resume To see more documentation on the CLI commands currently available to Goose check out the documentation [here][cli]. If you’d like to develop your own CLI commands for Goose, check out the [Contributing document][contributing]. +### Tracing with Langfuse +> [!NOTE] +> This Langfuse integration is experimental and we don't currently have integration tests for it. + +The exchange package provides a [Langfuse](https://langfuse.com/) wrapper module. The wrapper serves to initialize Langfuse appropriately if the Langfuse server is running locally and otherwise to skip applying the Langfuse observe descorators. + +#### Start your local Langfuse server + +Run `just langfuse-server` to start your local Langfuse server. It requires Docker. + +Read more about local Langfuse deployments [here](https://langfuse.com/docs/deployment/local). + +#### Exchange and Goose integration + +Import `from exchange.langfuse_wrapper import observe_wrapper` and use the `observe_wrapper()` decorator on functions you wish to enable tracing for. `observe_wrapper` functions the same way as Langfuse's observe decorator. + +Read more about Langfuse's decorator-based tracing [here](https://langfuse.com/docs/sdk/python/decorators). + +In Goose, initialization requires certain environment variables to be present: + +- `LANGFUSE_PUBLIC_KEY`: Your Langfuse public key +- `LANGFUSE_SECRET_KEY`: Your Langfuse secret key +- `LANGFUSE_BASE_URL`: The base URL of your Langfuse instance + +By default your local deployment and Goose will use the values in `.env.langfuse.local`. + + + ### Next steps Learn how to modify your Goose profiles.yaml file to add and remove functionality (toolkits) and providing context to get the most out of Goose in our [Getting Started Guide][getting-started]. +## Other ways to run goose + **Want to move out of the terminal and into an IDE?** We have some experimental IDE integrations for VSCode and JetBrains IDEs: * https://github.com/square/goose-vscode * https://github.com/Kvadratni/goose-intellij +**Goose as a Github Action** + +There is also an experimental Github action to run goose as part of your workflow (for example if you ask it to fix an issue): +https://github.com/marketplace/actions/goose-ai-developer-agent + +**With Docker** + +There is also a `Dockerfile` in the root of this project you can use if you want to run goose in a sandboxed fashion. + ## Getting involved! There is a lot to do! If you're interested in contributing, a great place to start is picking a `good-first-issue`-labelled ticket from our [issues list][gh-issues]. More details on how to develop Goose can be found in our [Contributing Guide][contributing]. We are a friendly, collaborative group and look forward to working together![^1] diff --git a/docs/plugins/cli.md b/docs/plugins/cli.md index 5d27563c9..7a3566d34 100644 --- a/docs/plugins/cli.md +++ b/docs/plugins/cli.md @@ -19,11 +19,13 @@ Lists the version of Goose and any associated plugins. **Usage:** ```sh - goose session start [--profile PROFILE] [--plan PLAN] + goose session start [--profile PROFILE] [--plan PLAN] [--log-level [DEBUG|INFO|WARNING|ERROR|CRITICAL]] [--tracing] ``` Starts a new Goose session. +If you want to enable locally hosted Langfuse tracing, pass the --tracing flag after starting your local Langfuse server as outlined in the [Contributing Guide's][contributing] Development guidelines. + #### `resume` **Usage:** diff --git a/docs/plugins/providers.md b/docs/plugins/providers.md index 7527f798a..63e5e80f1 100644 --- a/docs/plugins/providers.md +++ b/docs/plugins/providers.md @@ -8,6 +8,7 @@ Providers in Goose mean "LLM providers" that Goose can interact with. Providers * Azure * Bedrock * Databricks +* Google * Ollama * OpenAI diff --git a/justfile b/justfile index 8b2ca2338..96f51d6b8 100644 --- a/justfile +++ b/justfile @@ -70,3 +70,10 @@ tag: tag-push: just tag git push origin tag v$(just tag_version) + +# get commit messages for a release +release-notes: + git log --pretty=format:"- %s" v$(just tag_version)..HEAD + +langfuse-server: + ./scripts/setup_langfuse.sh diff --git a/packages/exchange/.env.langfuse.local b/packages/exchange/.env.langfuse.local new file mode 100644 index 000000000..cdebcd7a1 --- /dev/null +++ b/packages/exchange/.env.langfuse.local @@ -0,0 +1,16 @@ +# These variables are default initialization variables for locally hosted Langfuse server +LANGFUSE_INIT_PROJECT_NAME=goose-local +LANGFUSE_INIT_PROJECT_PUBLIC_KEY=publickey-local +LANGFUSE_INIT_PROJECT_SECRET_KEY=secretkey-local +LANGFUSE_INIT_USER_EMAIL=local@block.xyz +LANGFUSE_INIT_USER_NAME=localdev +LANGFUSE_INIT_USER_PASSWORD=localpwd + +LANGFUSE_INIT_ORG_ID=local-id +LANGFUSE_INIT_ORG_NAME=local-org +LANGFUSE_INIT_PROJECT_ID=goose + +# These variables are used by Goose +LANGFUSE_PUBLIC_KEY=publickey-local +LANGFUSE_SECRET_KEY=secretkey-local +LANGFUSE_HOST=http://localhost:3000 diff --git a/packages/exchange/pyproject.toml b/packages/exchange/pyproject.toml index 83a9e3c25..a25782f9f 100644 --- a/packages/exchange/pyproject.toml +++ b/packages/exchange/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ai-exchange" -version = "0.9.3" +version = "0.9.5" description = "a uniform python SDK for message generation with LLMs" readme = "README.md" requires-python = ">=3.10" @@ -13,6 +13,8 @@ dependencies = [ "tiktoken>=0.7.0", "httpx>=0.27.0", "tenacity>=9.0.0", + "python-dotenv>=1.0.1", + "langfuse>=2.38.2" ] [tool.hatch.build.targets.wheel] @@ -33,6 +35,7 @@ anthropic = "exchange.providers.anthropic:AnthropicProvider" bedrock = "exchange.providers.bedrock:BedrockProvider" ollama = "exchange.providers.ollama:OllamaProvider" google = "exchange.providers.google:GoogleProvider" +groq = "exchange.providers.groq:GroqProvider" [project.entry-points."exchange.moderator"] passive = "exchange.moderators.passive:PassiveModerator" diff --git a/packages/exchange/src/exchange/checkpoint.py b/packages/exchange/src/exchange/checkpoint.py index f355dd0a2..063ef35d1 100644 --- a/packages/exchange/src/exchange/checkpoint.py +++ b/packages/exchange/src/exchange/checkpoint.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import List from attrs import define, field @@ -31,7 +30,7 @@ class CheckpointData: total_token_count: int = field(default=0) # in order list of individual checkpoints in the exchange - checkpoints: List[Checkpoint] = field(factory=list) + checkpoints: list[Checkpoint] = field(factory=list) # the offset to apply to the message index when calculating the last message index # this is useful because messages on the exchange behave like a queue, where you can only diff --git a/packages/exchange/src/exchange/content.py b/packages/exchange/src/exchange/content.py index b9cc986fc..66957b7c6 100644 --- a/packages/exchange/src/exchange/content.py +++ b/packages/exchange/src/exchange/content.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Optional from attrs import define, asdict @@ -7,11 +7,11 @@ class Content: - def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None: + def __init_subclass__(cls, **kwargs: dict[str, any]) -> None: super().__init_subclass__(**kwargs) CONTENT_TYPES[cls.__name__] = cls - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, any]: data = asdict(self, recurse=True) data["type"] = self.__class__.__name__ return data @@ -26,7 +26,7 @@ class Text(Content): class ToolUse(Content): id: str name: str - parameters: Any + parameters: any is_error: bool = False error_message: Optional[str] = None diff --git a/packages/exchange/src/exchange/exchange.py b/packages/exchange/src/exchange/exchange.py index b2fdbc5ec..942bf78c6 100644 --- a/packages/exchange/src/exchange/exchange.py +++ b/packages/exchange/src/exchange/exchange.py @@ -1,9 +1,9 @@ import json import traceback from copy import deepcopy -from typing import Any, Dict, List, Mapping, Tuple - +from typing import Mapping from attrs import define, evolve, field, Factory +from exchange.langfuse_wrapper import observe_wrapper from tiktoken import get_encoding from exchange.checkpoint import Checkpoint, CheckpointData @@ -41,8 +41,8 @@ class Exchange: model: str system: str moderator: Moderator = field(default=ContextTruncate()) - tools: Tuple[Tool] = field(factory=tuple, converter=tuple) - messages: List[Message] = field(factory=list) + tools: tuple[Tool, ...] = field(factory=tuple, converter=tuple) + messages: list[Message] = field(factory=list) checkpoint_data: CheckpointData = field(factory=CheckpointData) generation_args: dict = field(default=Factory(dict)) @@ -50,7 +50,7 @@ class Exchange: def _toolmap(self) -> Mapping[str, Tool]: return {tool.name: tool for tool in self.tools} - def replace(self, **kwargs: Dict[str, Any]) -> "Exchange": + def replace(self, **kwargs: dict[str, any]) -> "Exchange": """Make a copy of the exchange, replacing any passed arguments""" # TODO: ensure that the checkpoint data is updated correctly. aka, # if we replace the messages, we need to update the checkpoint data @@ -127,6 +127,7 @@ def reply(self, max_tool_use: int = 128) -> Message: return response + @observe_wrapper() def call_function(self, tool_use: ToolUse) -> ToolResult: """Call the function indicated by the tool use""" tool = self._toolmap.get(tool_use.name) @@ -264,7 +265,7 @@ def pop_first_message(self) -> Message: # we've removed all the checkpoints, so we need to reset the message index offset self.checkpoint_data.message_index_offset = 0 - def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + def pop_last_checkpoint(self) -> tuple[Checkpoint, list[Message]]: """ Reverts the exchange back to the last checkpoint, removing associated messages """ @@ -275,7 +276,7 @@ def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: messages.append(self.messages.pop()) return removed_checkpoint, messages - def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + def pop_first_checkpoint(self) -> tuple[Checkpoint, list[Message]]: """ Pop the first checkpoint from the exchange, removing associated messages """ @@ -332,5 +333,6 @@ def is_allowed_to_call_llm(self) -> bool: # this to be a required method of the provider instead. return len(self.messages) > 0 and self.messages[-1].role == "user" - def get_token_usage(self) -> Dict[str, Usage]: + @staticmethod + def get_token_usage() -> dict[str, Usage]: return _token_usage_collector.get_token_usage_group_by_model() diff --git a/packages/exchange/src/exchange/invalid_choice_error.py b/packages/exchange/src/exchange/invalid_choice_error.py index ffbb9899f..def35bbc0 100644 --- a/packages/exchange/src/exchange/invalid_choice_error.py +++ b/packages/exchange/src/exchange/invalid_choice_error.py @@ -1,8 +1,5 @@ -from typing import List - - class InvalidChoiceError(Exception): - def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: + def __init__(self, attribute_name: str, attribute_value: str, available_values: list[str]) -> None: self.attribute_name = attribute_name self.attribute_value = attribute_value self.available_values = available_values diff --git a/packages/exchange/src/exchange/langfuse_wrapper.py b/packages/exchange/src/exchange/langfuse_wrapper.py new file mode 100644 index 000000000..c8cec23ee --- /dev/null +++ b/packages/exchange/src/exchange/langfuse_wrapper.py @@ -0,0 +1,84 @@ +""" +Langfuse Integration Module + +This module provides integration with Langfuse, a tool for monitoring and tracing LLM applications. + +Usage: + Import this module to enable Langfuse integration. + It automatically checks for Langfuse credentials in the .env.langfuse file and for a running Langfuse server. + If these are found, it will set up the necessary client and context for tracing. + +Note: + Run setup_langfuse.sh which automates the steps for running local Langfuse. +""" + +import os +from typing import Callable +from dotenv import load_dotenv +from langfuse.decorators import langfuse_context +import sys +from io import StringIO +from pathlib import Path +from functools import wraps # Add this import + + +def find_package_root(start_path: Path, marker_file: str = "pyproject.toml") -> Path: + while start_path != start_path.parent: + if (start_path / marker_file).exists(): + return start_path + start_path = start_path.parent + return None + + +def auth_check() -> bool: + # Temporarily redirect stdout and stderr to suppress print statements from Langfuse + temp_stderr = StringIO() + sys.stderr = temp_stderr + + # Load environment variables + load_dotenv(LANGFUSE_ENV_FILE, override=True) + + auth_val = langfuse_context.auth_check() + + # Restore stderr + sys.stderr = sys.__stderr__ + return auth_val + + +CURRENT_DIR = Path(__file__).parent +PACKAGE_ROOT = find_package_root(CURRENT_DIR) + +LANGFUSE_ENV_FILE = os.path.join(PACKAGE_ROOT, ".env.langfuse.local") +HAS_LANGFUSE_CREDENTIALS = False +load_dotenv(LANGFUSE_ENV_FILE, override=True) + +HAS_LANGFUSE_CREDENTIALS = auth_check() + + +def observe_wrapper(*args, **kwargs) -> Callable: # noqa + """ + A decorator that wraps a function with Langfuse context observation if credentials are available. + + If Langfuse credentials were found, the function will be wrapped with Langfuse's observe method. + Otherwise, the function will be returned as-is. + + Args: + *args: Positional arguments to pass to langfuse_context.observe. + **kwargs: Keyword arguments to pass to langfuse_context.observe. + + Returns: + Callable: The wrapped function if credentials are available, otherwise the original function. + """ + + def _wrapper(fn: Callable) -> Callable: + if HAS_LANGFUSE_CREDENTIALS: + + @wraps(fn) + def wrapped_fn(*fargs, **fkwargs): # noqa + return langfuse_context.observe(*args, **kwargs)(fn)(*fargs, **fkwargs) + + return wrapped_fn + else: + return fn + + return _wrapper diff --git a/packages/exchange/src/exchange/message.py b/packages/exchange/src/exchange/message.py index 035c60345..5edff692c 100644 --- a/packages/exchange/src/exchange/message.py +++ b/packages/exchange/src/exchange/message.py @@ -1,7 +1,7 @@ import inspect import time from pathlib import Path -from typing import Any, Dict, List, Literal, Type +from typing import Literal from attrs import define, field from jinja2 import Environment, FileSystemLoader @@ -12,7 +12,7 @@ Role = Literal["user", "assistant"] -def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401 +def validate_role_and_content(instance: "Message", *_: any) -> None: # noqa: ANN401 if instance.role == "user": if not (instance.text or instance.tool_result): raise ValueError("User message must include a Text or ToolResult") @@ -25,7 +25,7 @@ def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: AN raise ValueError("Assistant message does not support ToolResult") -def content_converter(contents: List[Dict[str, Any]]) -> List[Content]: +def content_converter(contents: list[dict[str, any]]) -> list[Content]: return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents] @@ -48,9 +48,9 @@ class Message: role: Role = field(default="user") id: str = field(factory=lambda: str(create_object_id(prefix="msg"))) created: int = field(factory=lambda: int(time.time())) - content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter) + content: list[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, any]: return { "role": self.role, "id": self.id, @@ -68,7 +68,7 @@ def text(self) -> str: return "\n".join(result) @property - def tool_use(self) -> List[ToolUse]: + def tool_use(self) -> list[ToolUse]: """All tool use content of this message.""" result = [] for content in self.content: @@ -77,7 +77,7 @@ def tool_use(self) -> List[ToolUse]: return result @property - def tool_result(self) -> List[ToolResult]: + def tool_result(self) -> list[ToolResult]: """All tool result content of this message.""" result = [] for content in self.content: @@ -87,10 +87,10 @@ def tool_result(self) -> List[ToolResult]: @classmethod def load( - cls: Type["Message"], + cls: type["Message"], filename: str, role: Role = "user", - **kwargs: Dict[str, Any], + **kwargs: dict[str, any], ) -> "Message": """Load the message from filename relative to where the load is called. @@ -113,9 +113,9 @@ def load( return cls(role=role, content=[Text(text=rendered_content)]) @classmethod - def user(cls: Type["Message"], text: str) -> "Message": + def user(cls: type["Message"], text: str) -> "Message": return cls(role="user", content=[Text(text)]) @classmethod - def assistant(cls: Type["Message"], text: str) -> "Message": + def assistant(cls: type["Message"], text: str) -> "Message": return cls(role="assistant", content=[Text(text)]) diff --git a/packages/exchange/src/exchange/moderators/__init__.py b/packages/exchange/src/exchange/moderators/__init__.py index 82d032e42..925473e98 100644 --- a/packages/exchange/src/exchange/moderators/__init__.py +++ b/packages/exchange/src/exchange/moderators/__init__.py @@ -1,5 +1,4 @@ from functools import cache -from typing import Type from exchange.invalid_choice_error import InvalidChoiceError from exchange.moderators.base import Moderator @@ -10,7 +9,7 @@ @cache -def get_moderator(name: str) -> Type[Moderator]: +def get_moderator(name: str) -> type[Moderator]: moderators = load_plugins(group="exchange.moderator") if name not in moderators: raise InvalidChoiceError("moderator", name, moderators.keys()) diff --git a/packages/exchange/src/exchange/moderators/base.py b/packages/exchange/src/exchange/moderators/base.py index d7c630c6a..98a6ad663 100644 --- a/packages/exchange/src/exchange/moderators/base.py +++ b/packages/exchange/src/exchange/moderators/base.py @@ -1,8 +1,7 @@ from abc import ABC, abstractmethod -from typing import Type class Moderator(ABC): @abstractmethod - def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821 pass diff --git a/packages/exchange/src/exchange/moderators/passive.py b/packages/exchange/src/exchange/moderators/passive.py index e3a24efbd..30e6f2c66 100644 --- a/packages/exchange/src/exchange/moderators/passive.py +++ b/packages/exchange/src/exchange/moderators/passive.py @@ -1,7 +1,6 @@ -from typing import Type from exchange.moderators.base import Moderator class PassiveModerator(Moderator): - def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + def rewrite(self, _: type["exchange.exchange.Exchange"]) -> None: # noqa: F821 pass diff --git a/packages/exchange/src/exchange/moderators/summarizer.py b/packages/exchange/src/exchange/moderators/summarizer.py index 7e2dd5588..a7bb1b0f5 100644 --- a/packages/exchange/src/exchange/moderators/summarizer.py +++ b/packages/exchange/src/exchange/moderators/summarizer.py @@ -1,12 +1,10 @@ -from typing import Type - from exchange import Message from exchange.checkpoint import CheckpointData from exchange.moderators import ContextTruncate, PassiveModerator class ContextSummarizer(ContextTruncate): - def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + def rewrite(self, exchange: type["exchange.exchange.Exchange"]) -> None: # noqa: F821 """Summarize the context history up to the last few messages in the exchange""" self._update_system_prompt_token_count(exchange) diff --git a/packages/exchange/src/exchange/moderators/truncate.py b/packages/exchange/src/exchange/moderators/truncate.py index 41115f663..a9c08b650 100644 --- a/packages/exchange/src/exchange/moderators/truncate.py +++ b/packages/exchange/src/exchange/moderators/truncate.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Optional from exchange.checkpoint import CheckpointData from exchange.message import Message @@ -62,7 +62,7 @@ def _update_system_prompt_token_count(self, exchange: Exchange) -> None: exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count exchange.checkpoint_data.total_token_count += self.system_prompt_token_count - def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]: + def _get_messages_to_remove(self, exchange: Exchange) -> list[Message]: # this keeps all the messages/checkpoints throwaway_exchange = exchange.replace( moderator=PassiveModerator(), diff --git a/packages/exchange/src/exchange/providers/__init__.py b/packages/exchange/src/exchange/providers/__init__.py index f92d4f769..56418df47 100644 --- a/packages/exchange/src/exchange/providers/__init__.py +++ b/packages/exchange/src/exchange/providers/__init__.py @@ -1,5 +1,4 @@ from functools import cache -from typing import Type from exchange.invalid_choice_error import InvalidChoiceError from exchange.providers.anthropic import AnthropicProvider # noqa @@ -7,6 +6,7 @@ from exchange.providers.databricks import DatabricksProvider # noqa from exchange.providers.openai import OpenAiProvider # noqa from exchange.providers.ollama import OllamaProvider # noqa +from exchange.providers.groq import GroqProvider # noqa from exchange.providers.azure import AzureProvider # noqa from exchange.providers.google import GoogleProvider # noqa @@ -14,7 +14,7 @@ @cache -def get_provider(name: str) -> Type[Provider]: +def get_provider(name: str) -> type[Provider]: providers = load_plugins(group="exchange.provider") if name not in providers: raise InvalidChoiceError("provider", name, providers.keys()) diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py index 84ecd12fb..c98c6d432 100644 --- a/packages/exchange/src/exchange/providers/anthropic.py +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -1,5 +1,4 @@ import os -from typing import Any, Dict, List, Tuple, Type import httpx @@ -8,6 +7,7 @@ from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import retry_if_status, raise_for_status +from exchange.langfuse_wrapper import observe_wrapper ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" @@ -29,7 +29,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": + def from_env(cls: type["AnthropicProvider"]) -> "AnthropicProvider": cls.check_env_vars() url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) key = os.environ.get("ANTHROPIC_API_KEY") @@ -45,7 +45,7 @@ def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": return cls(client) @staticmethod - def get_usage(data: Dict) -> Usage: # noqa: ANN401 + def get_usage(data: dict) -> Usage: # noqa: ANN401 usage = data.get("usage") input_tokens = usage.get("input_tokens") output_tokens = usage.get("output_tokens") @@ -61,7 +61,7 @@ def get_usage(data: Dict) -> Usage: # noqa: ANN401 ) @staticmethod - def anthropic_response_to_message(response: Dict) -> Message: + def anthropic_response_to_message(response: dict) -> Message: content_blocks = response.get("content", []) content = [] for block in content_blocks: @@ -78,7 +78,7 @@ def anthropic_response_to_message(response: Dict) -> Message: return Message(role="assistant", content=content) @staticmethod - def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]: + def tools_to_anthropic_spec(tools: tuple[Tool, ...]) -> list[dict[str, any]]: return [ { "name": tool.name, @@ -89,7 +89,7 @@ def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]: ] @staticmethod - def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]: + def messages_to_anthropic_spec(messages: list[Message]) -> list[dict[str, any]]: messages_spec = [] # if messages is empty - just make a default for message in messages: @@ -123,14 +123,17 @@ def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]: messages_spec.append(converted) return messages_spec + @observe_wrapper(as_type="generation") def complete( self, model: str, system: str, - messages: List[Message], - tools: List[Tool] = [], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: list[Tool] = None, + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: + if tools is None: + tools = [] tools_set = set() unique_tools = [] for tool in tools: diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py index 4d470f978..fa8814f39 100644 --- a/packages/exchange/src/exchange/providers/azure.py +++ b/packages/exchange/src/exchange/providers/azure.py @@ -1,5 +1,3 @@ -from typing import Type - import httpx import os @@ -21,7 +19,7 @@ def __init__(self, client: httpx.Client) -> None: super().__init__(client) @classmethod - def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": + def from_env(cls: type["AzureProvider"]) -> "AzureProvider": cls.check_env_vars() url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME") deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py index c8d860ecc..76b1c3391 100644 --- a/packages/exchange/src/exchange/providers/base.py +++ b/packages/exchange/src/exchange/providers/base.py @@ -1,7 +1,7 @@ import os from abc import ABC, abstractmethod from attrs import define, field -from typing import List, Optional, Tuple, Type +from typing import Optional from exchange.message import Message from exchange.tool import Tool @@ -19,11 +19,11 @@ class Provider(ABC): REQUIRED_ENV_VARS: list[str] = [] @classmethod - def from_env(cls: Type["Provider"]) -> "Provider": + def from_env(cls: type["Provider"]) -> "Provider": return cls() @classmethod - def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None: + def check_env_vars(cls: type["Provider"], instructions_url: Optional[str] = None) -> None: for env_var in cls.REQUIRED_ENV_VARS: if env_var not in os.environ: raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url) @@ -33,9 +33,10 @@ def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: """Generate the next message using the specified model""" pass diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py index 6c32d7cb3..cdc0c29c9 100644 --- a/packages/exchange/src/exchange/providers/bedrock.py +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -4,7 +4,7 @@ import logging import os from datetime import datetime, timezone -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Optional from urllib.parse import quote, urlparse import httpx @@ -15,6 +15,7 @@ from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import raise_for_status, retry_if_status from exchange.tool import Tool +from exchange.langfuse_wrapper import observe_wrapper SERVICE = "bedrock-runtime" UTC = timezone.utc @@ -36,7 +37,7 @@ def __init__( aws_access_key: str, aws_secret_key: str, aws_session_token: Optional[str] = None, - **kwargs: Dict[str, Any], + **kwargs: dict[str, any], ) -> None: self.region = aws_region self.host = f"https://{SERVICE}.{aws_region}.amazonaws.com/" @@ -45,7 +46,7 @@ def __init__( self.session_token = aws_session_token super().__init__(base_url=self.host, timeout=600, **kwargs) - def post(self, path: str, json: Dict, **kwargs: Dict[str, Any]) -> httpx.Response: + def post(self, path: str, json: dict, **kwargs: dict[str, any]) -> httpx.Response: signed_headers = self.sign_and_get_headers( method="POST", url=path, @@ -60,7 +61,7 @@ def sign_and_get_headers( url: str, payload: dict, service: str, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Sign the request and generate the necessary headers for AWS authentication. @@ -72,10 +73,10 @@ def sign_and_get_headers( region (str): The AWS region. access_key (str): The AWS access key. secret_key (str): The AWS secret key. - session_token (Optional[str]): The AWS session token, if any. + session_token (optional[str]): The AWS session token, if any. Returns: - Dict[str, str]: The headers required for the request. + dict[str, str]: The headers required for the request. """ def sign(key: bytes, msg: str) -> bytes: @@ -160,7 +161,7 @@ def __init__(self, client: AwsClient) -> None: self.client = client @classmethod - def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": + def from_env(cls: type["BedrockProvider"]) -> "BedrockProvider": cls.check_env_vars() aws_region = os.environ.get("AWS_REGION", "us-east-1") aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID") @@ -175,26 +176,27 @@ def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": ) return cls(client=client) + @observe_wrapper(as_type="generation") def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: """ Generate a completion response from the Bedrock gateway. Args: model (str): The model identifier. system (str): The system prompt or configuration. - messages (List[Message]): A list of messages to be processed by the model. - tools (Tuple[Tool]): A tuple of tools to be used in the completion process. + messages (list[Message]): A list of messages to be processed by the model. + tools (tuple[Tool]): A tuple of tools to be used in the completion process. **kwargs: Additional keyword arguments for inference configuration. Returns: - Tuple[Message, Usage]: A tuple containing the response message and usage data. + tuple[Message, Usage]: A tuple containing the response message and usage data. """ inference_config = dict( @@ -231,7 +233,7 @@ def complete( return self.response_to_message(response_message), usage @retry_procedure - def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401 + def _post(self, payload: any, path: str) -> dict: # noqa: ANN401 response = self.client.post(path, json=payload) return raise_for_status(response).json() @@ -311,7 +313,7 @@ def response_to_message(response_message: dict) -> Message: raise Exception("Invalid response") @staticmethod - def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: + def tools_to_bedrock_spec(tools: tuple[Tool, ...]) -> Optional[dict]: if len(tools) == 0: return None # API requires a non-empty tool config or None tools_added = set() diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py index 9bd582dc5..b8f92dca9 100644 --- a/packages/exchange/src/exchange/providers/databricks.py +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, List, Tuple, Type - import httpx import os @@ -13,7 +11,7 @@ tools_to_openai_spec, ) from exchange.tool import Tool - +from exchange.langfuse_wrapper import observe_wrapper retry_procedure = retry( wait=wait_fixed(2), @@ -43,7 +41,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": + def from_env(cls: type["DatabricksProvider"]) -> "DatabricksProvider": cls.check_env_vars(cls.instructions_url) url = os.environ.get("DATABRICKS_HOST") key = os.environ.get("DATABRICKS_TOKEN") @@ -69,14 +67,15 @@ def get_usage(data: dict) -> Usage: total_tokens=total_tokens, ) + @observe_wrapper(as_type="generation") def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: payload = dict( messages=[ {"role": "system", "content": system}, diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py index fe83cd605..76ccd7a9c 100644 --- a/packages/exchange/src/exchange/providers/google.py +++ b/packages/exchange/src/exchange/providers/google.py @@ -1,5 +1,4 @@ import os -from typing import Any, Dict, List, Tuple, Type import httpx @@ -7,7 +6,9 @@ from exchange.content import Text, ToolResult, ToolUse from exchange.providers.base import Provider, Usage from tenacity import retry, wait_fixed, stop_after_attempt -from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import raise_for_status, retry_if_status, encode_image +from exchange.langfuse_wrapper import observe_wrapper + GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" @@ -30,7 +31,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": + def from_env(cls: type["GoogleProvider"]) -> "GoogleProvider": cls.check_env_vars(cls.instructions_url) url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) key = os.environ.get("GOOGLE_API_KEY") @@ -45,7 +46,7 @@ def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": return cls(client) @staticmethod - def get_usage(data: Dict) -> Usage: # noqa: ANN401 + def get_usage(data: dict) -> Usage: # noqa: ANN401 usage = data.get("usageMetadata") input_tokens = usage.get("promptTokenCount") output_tokens = usage.get("candidatesTokenCount") @@ -61,7 +62,7 @@ def get_usage(data: Dict) -> Usage: # noqa: ANN401 ) @staticmethod - def google_response_to_message(response: Dict) -> Message: + def google_response_to_message(response: dict) -> Message: candidates = response.get("candidates", []) if candidates: # Only use first candidate for now @@ -85,12 +86,12 @@ def google_response_to_message(response: Dict) -> Message: return Message(role="assistant", content=[]) @staticmethod - def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: + def tools_to_google_spec(tools: tuple[Tool, ...]) -> dict[str, list[dict[str, any]]]: if not tools: return {} converted_tools = [] for tool in tools: - converted_tool: Dict[str, Any] = { + converted_tool: dict[str, any] = { "name": tool.name, "description": tool.description or "", } @@ -100,7 +101,7 @@ def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: return {"functionDeclarations": converted_tools} @staticmethod - def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: + def messages_to_google_spec(messages: list[Message]) -> list[dict[str, any]]: messages_spec = [] for message in messages: role = "user" if message.role == "user" else "model" @@ -111,9 +112,20 @@ def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: elif isinstance(content, ToolUse): converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}}) elif isinstance(content, ToolResult): - converted["parts"].append( - {"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} - ) + if content.output.startswith('"image:'): + image_path = content.output.replace('"image:', "").replace('"', "") + converted["parts"].append( + { + "inline_data": { + "mime_type": "image/png", + "data": f"{encode_image(image_path)}", + } + } + ) + else: + converted["parts"].append( + {"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} + ) messages_spec.append(converted) if not messages_spec: @@ -121,14 +133,15 @@ def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: return messages_spec + @observe_wrapper(as_type="generation") def complete( self, model: str, system: str, - messages: List[Message], - tools: List[Tool] = [], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: list[Tool] = None, + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: tools_set = set() unique_tools = [] for tool in tools: diff --git a/packages/exchange/src/exchange/providers/groq.py b/packages/exchange/src/exchange/providers/groq.py new file mode 100644 index 000000000..0f6472f88 --- /dev/null +++ b/packages/exchange/src/exchange/providers/groq.py @@ -0,0 +1,98 @@ +import os + +from exchange.langfuse_wrapper import observe_wrapper +import httpx + +from exchange.message import Message +from exchange.providers.base import Provider, Usage +from exchange.providers.utils import ( + messages_to_openai_spec, + openai_response_to_message, + openai_single_message_context_length_exceeded, + raise_for_status, + tools_to_openai_spec, +) +from exchange.tool import Tool +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status + +GROQ_HOST = "https://api.groq.com/openai/" + +retry_procedure = retry( + wait=wait_fixed(5), + stop=stop_after_attempt(5), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class GroqProvider(Provider): + """Provides chat completions for models hosted directly by OpenAI.""" + + PROVIDER_NAME = "groq" + REQUIRED_ENV_VARS = ["GROQ_API_KEY"] + instructions_url = "https://console.groq.com/docs/quickstart" + + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: type["GroqProvider"]) -> "GroqProvider": + cls.check_env_vars(cls.instructions_url) + url = os.environ.get("GROQ_HOST", GROQ_HOST) + key = os.environ.get("GROQ_API_KEY") + + client = httpx.Client( + base_url=url + "v1/", + headers={"Authorization": "Bearer " + key}, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: dict) -> Usage: + usage = data.pop("usage") + input_tokens = usage.get("prompt_tokens") + output_tokens = usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + @observe_wrapper(as_type="generation") + def complete( + self, + model: str, + system: str, + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: + system_message = [{"role": "system", "content": system}] + payload = dict( + messages=system_message + messages_to_openai_spec(messages), + model=model, + tools=tools_to_openai_spec(tools) if tools else [], + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(payload) + + # Check for context_length_exceeded error for single, long input message + if "error" in response and len(messages) == 1: + openai_single_message_context_length_exceeded(response["error"]) + + message = openai_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, payload: dict) -> dict: + response = self.client.post("chat/completions", json=payload) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/ollama.py b/packages/exchange/src/exchange/providers/ollama.py index 888564640..51fef5105 100644 --- a/packages/exchange/src/exchange/providers/ollama.py +++ b/packages/exchange/src/exchange/providers/ollama.py @@ -1,29 +1,28 @@ import os -from typing import Type import httpx from exchange.providers.openai import OpenAiProvider OLLAMA_HOST = "http://localhost:11434/" -OLLAMA_MODEL = "mistral-nemo" +OLLAMA_MODEL = "qwen2.5" class OllamaProvider(OpenAiProvider): """Provides chat completions for models hosted by Ollama.""" - __doc__ += """Here's an example profile configuration to try: + __doc__ += f"""Here's an example profile configuration to try: First run: ollama pull qwen2.5, then use this profile: ollama: provider: ollama - processor: qwen2.5 - accelerator: qwen2.5 + processor: {OLLAMA_MODEL} + accelerator: {OLLAMA_MODEL} moderator: truncate toolkits: - name: developer - requires: {} + requires: {{}} """ def __init__(self, client: httpx.Client) -> None: @@ -31,7 +30,7 @@ def __init__(self, client: httpx.Client) -> None: super().__init__(client) @classmethod - def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider": + def from_env(cls: type["OllamaProvider"]) -> "OllamaProvider": ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST) timeout = httpx.Timeout(60 * 10) diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py index b25c5a70a..8701e5429 100644 --- a/packages/exchange/src/exchange/providers/openai.py +++ b/packages/exchange/src/exchange/providers/openai.py @@ -1,5 +1,4 @@ import os -from typing import Any, Dict, List, Tuple, Type import httpx @@ -15,6 +14,7 @@ from exchange.tool import Tool from tenacity import retry, wait_fixed, stop_after_attempt from exchange.providers.utils import retry_if_status +from exchange.langfuse_wrapper import observe_wrapper OPENAI_HOST = "https://api.openai.com/" @@ -37,7 +37,7 @@ def __init__(self, client: httpx.Client) -> None: self.client = client @classmethod - def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": + def from_env(cls: type["OpenAiProvider"]) -> "OpenAiProvider": cls.check_env_vars(cls.instructions_url) url = os.environ.get("OPENAI_HOST", OPENAI_HOST) key = os.environ.get("OPENAI_API_KEY") @@ -65,14 +65,15 @@ def get_usage(data: dict) -> Usage: total_tokens=total_tokens, ) + @observe_wrapper(as_type="generation") def complete( self, model: str, system: str, - messages: List[Message], - tools: Tuple[Tool], - **kwargs: Dict[str, Any], - ) -> Tuple[Message, Usage]: + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}] payload = dict( messages=system_message + messages_to_openai_spec(messages), diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py index 4be7ac31e..9af7287ef 100644 --- a/packages/exchange/src/exchange/providers/utils.py +++ b/packages/exchange/src/exchange/providers/utils.py @@ -1,7 +1,7 @@ import base64 import json import re -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Optional import httpx from exchange.content import Text, ToolResult, ToolUse @@ -10,10 +10,10 @@ from tenacity import retry_if_exception -def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable: +def retry_if_status(codes: Optional[list[int]] = None, above: Optional[int] = None) -> callable: codes = codes or [] - def predicate(exc: Exception) -> bool: + def predicate(exc: BaseException) -> bool: if isinstance(exc, httpx.HTTPStatusError): if exc.response.status_code in codes: return True @@ -42,7 +42,7 @@ def encode_image(image_path: str) -> str: return base64.b64encode(image_file.read()).decode("utf-8") -def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]: +def messages_to_openai_spec(messages: list[Message]) -> list[dict[str, any]]: messages_spec = [] for message in messages: converted = {"role": message.role} @@ -106,7 +106,7 @@ def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]: return messages_spec -def tools_to_openai_spec(tools: Tuple[Tool]) -> Dict[str, Any]: +def tools_to_openai_spec(tools: tuple[Tool, ...]) -> dict[str, any]: tools_names = set() result = [] for tool in tools: diff --git a/packages/exchange/src/exchange/token_usage_collector.py b/packages/exchange/src/exchange/token_usage_collector.py index 8f0801062..c99110c29 100644 --- a/packages/exchange/src/exchange/token_usage_collector.py +++ b/packages/exchange/src/exchange/token_usage_collector.py @@ -1,5 +1,4 @@ from collections import defaultdict -from typing import Dict from exchange.providers.base import Usage @@ -11,7 +10,7 @@ def __init__(self) -> None: def collect(self, model: str, usage: Usage) -> None: self.usage_data.append((model, usage)) - def get_token_usage_group_by_model(self) -> Dict[str, Usage]: + def get_token_usage_group_by_model(self) -> dict[str, Usage]: usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0)) for model, usage in self.usage_data: usage_by_model = usage_group_by_model[model] diff --git a/packages/exchange/src/exchange/tool.py b/packages/exchange/src/exchange/tool.py index 4ce9e7c50..1ca1f4358 100644 --- a/packages/exchange/src/exchange/tool.py +++ b/packages/exchange/src/exchange/tool.py @@ -1,5 +1,4 @@ import inspect -from typing import Any, Callable, Type from attrs import define @@ -13,17 +12,17 @@ class Tool: Attributes: name (str): The name of the tool description (str): A description of what the tool does - parameters dict[str, Any]: A json schema of the function signature + parameters dict[str, any]: A json schema of the function signature function (Callable): The python function that powers the tool """ name: str description: str - parameters: dict[str, Any] - function: Callable + parameters: dict[str, any] + function: callable @classmethod - def from_function(cls: Type["Tool"], func: Any) -> "Tool": # noqa: ANN401 + def from_function(cls: type["Tool"], func: any) -> "Tool": # noqa: ANN401 """Create a tool instance from a function and its docstring The function must have a docstring - we require it to load the description diff --git a/packages/exchange/src/exchange/utils.py b/packages/exchange/src/exchange/utils.py index 04d5ffa18..b95f1c485 100644 --- a/packages/exchange/src/exchange/utils.py +++ b/packages/exchange/src/exchange/utils.py @@ -1,7 +1,7 @@ import inspect import uuid from importlib.metadata import entry_points -from typing import Any, Callable, Dict, List, Type, get_args, get_origin +from typing import get_args, get_origin from griffe import ( Docstring, @@ -20,7 +20,7 @@ def compact(content: str) -> str: return " ".join(content.split()) -def parse_docstring(func: Callable) -> tuple[str, List[Dict]]: +def parse_docstring(func: callable) -> tuple[str, list[dict]]: """Get description and parameters from function docstring""" function_args = list(inspect.signature(func).parameters.keys()) text = str(func.__doc__) @@ -71,7 +71,7 @@ def parse_docstring(func: Callable) -> tuple[str, List[Dict]]: def _check_section_is_present( - parsed_docstring: List[DocstringSection], section_type: Type[DocstringSectionText] + parsed_docstring: list[DocstringSection], section_type: type[DocstringSectionText] ) -> bool: for section in parsed_docstring: if isinstance(section, section_type): @@ -79,7 +79,7 @@ def _check_section_is_present( return False -def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401 +def json_schema(func: any) -> dict[str, any]: # noqa: ANN401 """Get the json schema for a function""" signature = inspect.signature(func) parameters = signature.parameters @@ -107,16 +107,16 @@ def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401 return schema -def _map_type_to_schema(py_type: Type) -> Dict[str, Any]: # noqa: ANN401 +def _map_type_to_schema(py_type: type) -> dict[str, any]: # noqa: ANN401 origin = get_origin(py_type) args = get_args(py_type) if origin is list or origin is tuple: - return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)} + return {"type": "array", "items": _map_type_to_schema(args[0] if args else any)} elif origin is dict: return { "type": "object", - "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any), + "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else any), } elif py_type is int: return {"type": "integer"} diff --git a/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml b/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml index 3ac8a4fc0..8d7a34239 100644 --- a/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml +++ b/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml @@ -1,7 +1,19 @@ interactions: - request: - body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}' + body: |- + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ], + "model": "gpt-4o-mini" + } headers: accept: - '*/*' @@ -23,35 +35,94 @@ interactions: uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview response: body: - string: '{"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":"stop","index":0,"logprobs":null,"message":{"content":"Hello! - How can I assist you today?","role":"assistant"}}],"created":1727230065,"id":"chatcmpl-ABBjN3AoYlxkP7Vg2lBvUhYeA6j5K","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":18,"total_tokens":27}} - - ' + string: |- + { + "choices": [ + { + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + }, + "finish_reason": "stop", + "index": 0, + "logprobs": null, + "message": { + "content": "Hello! How can I assist you today?\n", + "role": "assistant" + } + } + ], + "created": 1728788170, + "id": "chatcmpl-AHj469gGa9bQaCSikrZYDejfGDx2x", + "model": "gpt-4-32k", + "object": "chat.completion", + "prompt_filter_results": [ + { + "prompt_index": 0, + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + } + } + ], + "system_fingerprint": null, + "usage": { + "completion_tokens": 9, + "prompt_tokens": 18, + "total_tokens": 27 + } + } headers: Cache-Control: - no-cache, must-revalidate Content-Length: - - '825' + - '827' Content-Type: - application/json Date: - - Wed, 25 Sep 2024 02:07:45 GMT - Set-Cookie: test_set_cookie + - Sun, 13 Oct 2024 02:56:10 GMT Strict-Transport-Security: - max-age=31536000; includeSubDomains; preload access-control-allow-origin: - '*' apim-request-id: - - 82e66ef8-ac07-4a43-b60f-9aecec1d8c81 + - b203f657-e7aa-40ea-8969-ba0b83dae854 azureml-model-session: - - d145-20240919052126 - openai-organization: test_openai_org_key + - d156-20241010120317 x-accel-buffering: - 'no' x-content-type-options: - nosniff x-ms-client-request-id: - - 82e66ef8-ac07-4a43-b60f-9aecec1d8c81 + - b203f657-e7aa-40ea-8969-ba0b83dae854 x-ms-rai-invoked: - 'true' x-ms-region: @@ -61,7 +132,7 @@ interactions: x-ratelimit-remaining-tokens: - '79984' x-request-id: - - 38db9001-8b16-4efe-84c9-620e10f18c3c + - 0c6dc92f-a017-4879-b82e-be937533c76e status: code: 200 message: OK diff --git a/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml b/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml index 9da479790..240b83bf4 100644 --- a/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml +++ b/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml @@ -1,13 +1,40 @@ interactions: - request: - body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. - Expect to need to read a file using read_file."}, {"role": "user", "content": - "What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools": - [{"type": "function", "function": {"name": "read_file", "description": "Read - the contents of the file.", "parameters": {"type": "object", "properties": {"filename": - {"type": "string", "description": "The path to the file, which can be relative - or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent - working directory."}}, "required": ["filename"]}}}]}' + body: |- + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. Expect to need to read a file using read_file." + }, + { + "role": "user", + "content": "What are the contents of this file? test.txt" + } + ], + "model": "gpt-4o-mini", + "tools": [ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the contents of the file.", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "The path to the file, which can be relative or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent working directory." + } + }, + "required": [ + "filename" + ] + } + } + } + ] + } headers: accept: - '*/*' @@ -29,10 +56,64 @@ interactions: uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview response: body: - string: '{"choices":[{"content_filter_results":{},"finish_reason":"tool_calls","index":0,"logprobs":null,"message":{"content":null,"role":"assistant","tool_calls":[{"function":{"arguments":"{\n \"filename\": - \"test.txt\"\n}","name":"read_file"},"id":"call_a47abadDxlGKIWjvYYvGVAHa","type":"function"}]}}],"created":1727256650,"id":"chatcmpl-ABIeABbq5WVCq0e0AriGFaYDSih3P","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":16,"prompt_tokens":109,"total_tokens":125}} - - ' + string: |- + { + "choices": [ + { + "content_filter_results": {}, + "finish_reason": "tool_calls", + "index": 0, + "logprobs": null, + "message": { + "content": null, + "role": "assistant", + "tool_calls": [ + { + "function": { + "arguments": "{\n \"filename\": \"test.txt\"\n}", + "name": "read_file" + }, + "id": "call_Dn0idyNSdmHSYpsql3EbFH9L", + "type": "function" + } + ] + } + } + ], + "created": 1728788173, + "id": "chatcmpl-AHj498XEBkixukw2lwNReXCIBStp0", + "model": "gpt-4-32k", + "object": "chat.completion", + "prompt_filter_results": [ + { + "prompt_index": 0, + "content_filter_results": { + "hate": { + "filtered": false, + "severity": "safe" + }, + "self_harm": { + "filtered": false, + "severity": "safe" + }, + "sexual": { + "filtered": false, + "severity": "safe" + }, + "violence": { + "filtered": false, + "severity": "safe" + } + } + } + ], + "system_fingerprint": null, + "usage": { + "completion_tokens": 16, + "prompt_tokens": 109, + "total_tokens": 125 + } + } headers: Cache-Control: - no-cache, must-revalidate @@ -41,33 +122,31 @@ interactions: Content-Type: - application/json Date: - - Wed, 25 Sep 2024 09:30:50 GMT - Set-Cookie: test_set_cookie + - Sun, 13 Oct 2024 02:56:14 GMT Strict-Transport-Security: - max-age=31536000; includeSubDomains; preload access-control-allow-origin: - '*' apim-request-id: - - 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339 + - 0e0af575-5634-415c-88e3-a7bf68549ee5 azureml-model-session: - - d145-20240919052126 - openai-organization: test_openai_org_key + - d159-20241010142543 x-accel-buffering: - 'no' x-content-type-options: - nosniff x-ms-client-request-id: - - 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339 + - 0e0af575-5634-415c-88e3-a7bf68549ee5 x-ms-rai-invoked: - 'true' x-ms-region: - Switzerland North x-ratelimit-remaining-requests: - - '79' + - '77' x-ratelimit-remaining-tokens: - - '79824' + - '79952' x-request-id: - - 401bd803-b790-47b7-b098-98708d44f060 + - e5012889-ef86-449a-908c-2065dbf0954e status: code: 200 message: OK diff --git a/packages/exchange/tests/providers/cassettes/test_google_complete.yaml b/packages/exchange/tests/providers/cassettes/test_google_complete.yaml new file mode 100644 index 000000000..ec01cc8e0 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_google_complete.yaml @@ -0,0 +1,112 @@ +interactions: +- request: + body: |- + { + "system_instruction": { + "parts": [ + { + "text": "You are a helpful assistant." + } + ] + }, + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "Hello" + } + ] + } + ] + } + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '139' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key + response: + body: + string: |- + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "Hello! \ud83d\udc4b What can I do for you today? \ud83d\ude0a \n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 8, + "candidatesTokenCount": 13, + "totalTokenCount": 21 + } + } + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Cache-Control: + - private + Content-Type: + - application/json; charset=UTF-8 + Date: + - Sun, 13 Oct 2024 02:54:58 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=1201 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '858' + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_google_tools.yaml b/packages/exchange/tests/providers/cassettes/test_google_tools.yaml new file mode 100644 index 000000000..c50e86810 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_google_tools.yaml @@ -0,0 +1,137 @@ +interactions: +- request: + body: |- + { + "system_instruction": { + "parts": [ + { + "text": "You are a helpful assistant. Expect to need to read a file using read_file." + } + ] + }, + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "What are the contents of this file? test.txt" + } + ] + } + ], + "tools": { + "functionDeclarations": [ + { + "name": "read_file", + "description": "Read the contents of the file.", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "The path to the file, which can be relative or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent working directory." + } + }, + "required": [ + "filename" + ] + } + } + ] + } + } + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '600' + content-type: + - application/json + host: + - generativelanguage.googleapis.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key=test_google_api_key + response: + body: + string: |- + { + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "read_file", + "args": { + "filename": "test.txt" + } + } + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 101, + "candidatesTokenCount": 17, + "totalTokenCount": 118 + } + } + headers: + Alt-Svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + Cache-Control: + - private + Content-Type: + - application/json; charset=UTF-8 + Date: + - Sun, 13 Oct 2024 02:54:59 GMT + Server: + - scaffolding on HTTPServer2 + Server-Timing: + - gfet4t7; dur=449 + Transfer-Encoding: + - chunked + Vary: + - Origin + - X-Origin + - Referer + X-Content-Type-Options: + - nosniff + X-Frame-Options: + - SAMEORIGIN + X-XSS-Protection: + - '0' + content-length: + - '947' + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml b/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml index 88bc206ff..84afcbe99 100644 --- a/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml +++ b/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml @@ -23,15 +23,25 @@ interactions: Content-Type: - text/plain; charset=utf-8 Date: - - Sun, 22 Sep 2024 23:40:13 GMT - Set-Cookie: test_set_cookie - openai-organization: test_openai_org_key + - Sun, 13 Oct 2024 04:53:22 GMT status: code: 200 message: OK - request: - body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}], "model": "mistral-nemo"}' + body: |- + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ], + "model": "mistral-nemo" + } headers: accept: - '*/*' @@ -51,17 +61,36 @@ interactions: uri: http://localhost:11434/v1/chat/completions response: body: - string: "{\"id\":\"chatcmpl-429\",\"object\":\"chat.completion\",\"created\":1727048416,\"model\":\"mistral-nemo\",\"system_fingerprint\":\"fp_ollama\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Hello! - I'm here to help. How can I assist you today? Let's chat. \U0001F60A\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":23,\"total_tokens\":33}}\n" + string: |- + { + "id": "chatcmpl-565", + "object": "chat.completion", + "created": 1728795204, + "model": "mistral-nemo", + "system_fingerprint": "fp_ollama", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I'm here to help. How can I assist you today? \ud83d\ude0a" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 19, + "total_tokens": 29 + } + } headers: Content-Length: - - '356' + - '344' Content-Type: - application/json Date: - - Sun, 22 Sep 2024 23:40:16 GMT - Set-Cookie: test_set_cookie - openai-organization: test_openai_org_key + - Sun, 13 Oct 2024 04:53:24 GMT status: code: 200 message: OK diff --git a/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml b/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml index 7271bf227..d803153b8 100644 --- a/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml +++ b/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml @@ -23,21 +23,46 @@ interactions: Content-Type: - text/plain; charset=utf-8 Date: - - Wed, 25 Sep 2024 09:23:08 GMT - Set-Cookie: test_set_cookie - openai-organization: test_openai_org_key + - Sun, 13 Oct 2024 02:52:00 GMT status: code: 200 message: OK - request: - body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. - Expect to need to read a file using read_file."}, {"role": "user", "content": - "What are the contents of this file? test.txt"}], "model": "mistral-nemo", "tools": - [{"type": "function", "function": {"name": "read_file", "description": "Read - the contents of the file.", "parameters": {"type": "object", "properties": {"filename": - {"type": "string", "description": "The path to the file, which can be relative - or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent - working directory."}}, "required": ["filename"]}}}]}' + body: |- + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. Expect to need to read a file using read_file." + }, + { + "role": "user", + "content": "What are the contents of this file? test.txt" + } + ], + "model": "mistral-nemo", + "tools": [ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the contents of the file.", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "The path to the file, which can be relative or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent working directory." + } + }, + "required": [ + "filename" + ] + } + } + } + ] + } headers: accept: - '*/*' @@ -57,18 +82,46 @@ interactions: uri: http://localhost:11434/v1/chat/completions response: body: - string: '{"id":"chatcmpl-245","object":"chat.completion","created":1727256190,"model":"mistral-nemo","system_fingerprint":"fp_ollama","choices":[{"index":0,"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_z6fgu3z3","type":"function","function":{"name":"read_file","arguments":"{\"filename\":\"test.txt\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":112,"completion_tokens":21,"total_tokens":133}} - - ' + string: |- + { + "id": "chatcmpl-212", + "object": "chat.completion", + "created": 1728787922, + "model": "mistral-nemo", + "system_fingerprint": "fp_ollama", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_h5d3s25w", + "type": "function", + "function": { + "name": "read_file", + "arguments": "{\"filename\":\"test.txt\"}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 112, + "completion_tokens": 21, + "total_tokens": 133 + } + } headers: Content-Length: - '425' Content-Type: - application/json Date: - - Wed, 25 Sep 2024 09:23:10 GMT - Set-Cookie: test_set_cookie - openai-organization: test_openai_org_key + - Sun, 13 Oct 2024 02:52:02 GMT status: code: 200 message: OK diff --git a/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml b/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml index 1a92eb36b..6de10fb5d 100644 --- a/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml +++ b/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml @@ -1,7 +1,19 @@ interactions: - request: - body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}' + body: |- + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello" + } + ], + "model": "gpt-4o-mini" + } headers: accept: - '*/*' @@ -23,25 +35,48 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: "{\n \"id\": \"chatcmpl-AAQTYi3DXJnltAfd5sUH1Wnzh69t3\",\n \"object\": - \"chat.completion\",\n \"created\": 1727048416,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n - \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": - \"assistant\",\n \"content\": \"Hello! How can I assist you today?\",\n - \ \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\": - \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 18,\n \"completion_tokens\": - 9,\n \"total_tokens\": 27,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": - 0\n }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n" + string: |- + { + "id": "chatcmpl-AHj1Y1xN9345uFT3PVMInIYEQ8g4a", + "object": "chat.completion", + "created": 1728788012, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 18, + "completion_tokens": 9, + "total_tokens": 27, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_e2bde53e6e" + } headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8c762399feb55739-SYD + - 8d1c0a328dca3e44-SIN Connection: - keep-alive Content-Type: - application/json Date: - - Sun, 22 Sep 2024 23:40:17 GMT + - Sun, 13 Oct 2024 02:53:33 GMT Server: - cloudflare Set-Cookie: test_set_cookie @@ -51,15 +86,17 @@ interactions: - nosniff access-control-expose-headers: - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 content-length: - - '593' + - '656' openai-organization: test_openai_org_key openai-processing-ms: - - '560' + - '481' openai-version: - '2020-10-01' strict-transport-security: - - max-age=15552000; includeSubDomains; preload + - max-age=31536000; includeSubDomains; preload x-ratelimit-limit-requests: - '10000' x-ratelimit-limit-tokens: @@ -67,13 +104,13 @@ interactions: x-ratelimit-remaining-requests: - '9999' x-ratelimit-remaining-tokens: - - '199973' + - '199972' x-ratelimit-reset-requests: - 8.64s x-ratelimit-reset-tokens: - 8ms x-request-id: - - req_22e26c840219cde3152eaba1ce89483b + - req_85f532ac5fdad6a4af020cab55e2fd4d status: code: 200 message: OK diff --git a/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml b/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml index 30496fcb8..86c6962c3 100644 --- a/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml +++ b/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml @@ -1,13 +1,40 @@ interactions: - request: - body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. - Expect to need to read a file using read_file."}, {"role": "user", "content": - "What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools": - [{"type": "function", "function": {"name": "read_file", "description": "Read - the contents of the file.", "parameters": {"type": "object", "properties": {"filename": - {"type": "string", "description": "The path to the file, which can be relative - or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent - working directory."}}, "required": ["filename"]}}}]}' + body: |- + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant. Expect to need to read a file using read_file." + }, + { + "role": "user", + "content": "What are the contents of this file? test.txt" + } + ], + "model": "gpt-4o-mini", + "tools": [ + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the contents of the file.", + "parameters": { + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "The path to the file, which can be relative or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent working directory." + } + }, + "required": [ + "filename" + ] + } + } + } + ] + } headers: accept: - '*/*' @@ -29,29 +56,58 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: "{\n \"id\": \"chatcmpl-ABIV2aZWVKQ774RAQ8KHYdNwkI5N7\",\n \"object\": - \"chat.completion\",\n \"created\": 1727256084,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n - \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": - \"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n - \ \"id\": \"call_xXYlw4A7Ud1qtCopuK5gEJrP\",\n \"type\": - \"function\",\n \"function\": {\n \"name\": \"read_file\",\n - \ \"arguments\": \"{\\\"filename\\\":\\\"test.txt\\\"}\"\n }\n - \ }\n ],\n \"refusal\": null\n },\n \"logprobs\": - null,\n \"finish_reason\": \"tool_calls\"\n }\n ],\n \"usage\": - {\n \"prompt_tokens\": 107,\n \"completion_tokens\": 15,\n \"total_tokens\": - 122,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": 0\n - \ }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n" + string: |- + { + "id": "chatcmpl-AHj1axEdpe3coVDULrCjHmXql5euz", + "object": "chat.completion", + "created": 1728788014, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_Z43oz2RtLmNHw9xvFgxA1SC5", + "type": "function", + "function": { + "name": "read_file", + "arguments": "{\"filename\":\"test.txt\"}" + } + } + ], + "refusal": null + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 107, + "completion_tokens": 15, + "total_tokens": 122, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_e2bde53e6e" + } headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8c89f19fed997e43-SYD + - 8d1c0a419b009d0b-SIN Connection: - keep-alive Content-Type: - application/json Date: - - Wed, 25 Sep 2024 09:21:25 GMT + - Sun, 13 Oct 2024 02:53:35 GMT Server: - cloudflare Set-Cookie: test_set_cookie @@ -61,11 +117,13 @@ interactions: - nosniff access-control-expose-headers: - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 content-length: - - '844' + - '907' openai-organization: test_openai_org_key openai-processing-ms: - - '266' + - '442' openai-version: - '2020-10-01' strict-transport-security: @@ -75,15 +133,15 @@ interactions: x-ratelimit-limit-tokens: - '200000' x-ratelimit-remaining-requests: - - '9991' + - '9997' x-ratelimit-remaining-tokens: - - '199952' + - '199951' x-ratelimit-reset-requests: - - 1m9.486s + - 23.873s x-ratelimit-reset-tokens: - 14ms x-request-id: - - req_ff6b5d65c24f40e1faaf049c175e718d + - req_8ec455e318c9f2d6eecf82d1fdf124ab status: code: 200 message: OK diff --git a/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml b/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml index 1b9691d29..c649d6804 100644 --- a/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml +++ b/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml @@ -1,13 +1,53 @@ interactions: - request: - body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What does the first entry in the menu say?"}, {"role": - "assistant", "tool_calls": [{"id": "xyz", "type": "function", "function": {"name": - "screenshot", "arguments": "{}"}}]}, {"role": "tool", "content": [{"type": "text", - "text": "This tool result included an image that is uploaded in the next message."}], - "tool_call_id": "xyz"}, {"role": "user", "content": [{"type": "image_url", "image_url": - {"url": ""}}]}], - "model": "gpt-4o-mini"}' + body: |- + { + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What does the first entry in the menu say?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "xyz", + "type": "function", + "function": { + "name": "screenshot", + "arguments": "{}" + } + } + ] + }, + { + "role": "tool", + "content": [ + { + "type": "text", + "text": "This tool result included an image that is uploaded in the next message." + } + ], + "tool_call_id": "xyz" + }, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "" + } + } + ] + } + ], + "model": "gpt-4o-mini" + } headers: accept: - '*/*' @@ -29,25 +69,48 @@ interactions: uri: https://api.openai.com/v1/chat/completions response: body: - string: "{\n \"id\": \"chatcmpl-ABIA0YzOHlhqb02K8Ay4Jwsw6xOpk\",\n \"object\": - \"chat.completion\",\n \"created\": 1727254780,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n - \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": - \"assistant\",\n \"content\": \"The first entry in the menu says \\\"Ask - Goose.\\\"\",\n \"refusal\": null\n },\n \"logprobs\": null,\n - \ \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": - 14230,\n \"completion_tokens\": 11,\n \"total_tokens\": 14241,\n \"completion_tokens_details\": - {\n \"reasoning_tokens\": 0\n }\n },\n \"system_fingerprint\": \"fp_e9627b5346\"\n}\n" + string: |- + { + "id": "chatcmpl-AHj1wBv5ZKIB2nh2p8PWvLZ3QEXLH", + "object": "chat.completion", + "created": 1728788036, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "The first entry in the menu says \"Ask Goose.\"", + "refusal": null + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 14230, + "completion_tokens": 11, + "total_tokens": 14241, + "prompt_tokens_details": { + "cached_tokens": 0 + }, + "completion_tokens_details": { + "reasoning_tokens": 0 + } + }, + "system_fingerprint": "fp_8552ec53e1" + } headers: CF-Cache-Status: - DYNAMIC CF-RAY: - - 8c89d1c45d98a883-SYD + - 8d1c0abdef9701f2-SIN Connection: - keep-alive Content-Type: - application/json Date: - - Wed, 25 Sep 2024 08:59:41 GMT + - Sun, 13 Oct 2024 02:53:57 GMT Server: - cloudflare Set-Cookie: test_set_cookie @@ -57,11 +120,13 @@ interactions: - nosniff access-control-expose-headers: - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 content-length: - - '613' + - '676' openai-organization: test_openai_org_key openai-processing-ms: - - '1289' + - '1966' openai-version: - '2020-10-01' strict-transport-security: @@ -71,15 +136,15 @@ interactions: x-ratelimit-limit-tokens: - '200000' x-ratelimit-remaining-requests: - - '9999' + - '9995' x-ratelimit-remaining-tokens: - '199177' x-ratelimit-reset-requests: - - 8.64s + - 37.664s x-ratelimit-reset-tokens: - 246ms x-request-id: - - req_9503b21e31db78c4ebd2b71b304cea72 + - req_6c0595ef0498819df0c77a9ada75a8e5 status: code: 200 message: OK diff --git a/packages/exchange/tests/providers/conftest.py b/packages/exchange/tests/providers/conftest.py index 010504e84..a747e9ce6 100644 --- a/packages/exchange/tests/providers/conftest.py +++ b/packages/exchange/tests/providers/conftest.py @@ -1,11 +1,13 @@ +import json import os import re -from typing import Type, Tuple import pytest +import yaml from exchange import Message, ToolUse, ToolResult, Tool from exchange.providers import Usage, Provider + from tests.conftest import read_file OPENAI_API_KEY = "test_openai_api_key" @@ -53,6 +55,90 @@ def default_azure_env(monkeypatch): monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY) +GOOGLE_API_KEY = "test_google_api_key" + + +@pytest.fixture +def default_google_env(monkeypatch): + """ + This fixture prevents GoogleProvider.from_env() from erring on missing + environment variables. + + When running VCR tests for the first time or after deleting a cassette + recording, set required environment variables, so that real requests don't + fail. Subsequent runs use the recorded data, so don't need them. + """ + if "GOOGLE_API_KEY" not in os.environ: + monkeypatch.setenv("GOOGLE_API_KEY", GOOGLE_API_KEY) + + +class LiteralBlockScalar(str): + """Formats the string as a literal block scalar, preserving whitespace and + without interpreting escape characters""" + + pass + + +def literal_block_scalar_presenter(dumper, data): + """Represents a scalar string as a literal block, via '|' syntax""" + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + + +yaml.add_representer(LiteralBlockScalar, literal_block_scalar_presenter) + + +def process_string_value(string_value): + """Pretty-prints JSON or returns long strings as a LiteralString""" + try: + json_data = json.loads(string_value) + return LiteralBlockScalar(json.dumps(json_data, indent=2)) + except (ValueError, TypeError): + if len(string_value) > 80: + return LiteralBlockScalar(string_value) + return string_value + + +def convert_body_to_literal(data): + """Searches the data for body strings, attempting to pretty-print JSON""" + if isinstance(data, dict): + for key, value in data.items(): + # Handle response body case (e.g., response.body.string) + if key == "body" and isinstance(value, dict) and "string" in value: + value["string"] = process_string_value(value["string"]) + + # Handle request body case (e.g., request.body) + elif key == "body" and isinstance(value, str): + data[key] = process_string_value(value) + + else: + convert_body_to_literal(value) + + elif isinstance(data, list): + for i, item in enumerate(data): + data[i] = convert_body_to_literal(item) + + return data + + +class PrettyPrintJSONBody: + """This makes request and response body recordings more readable.""" + + @staticmethod + def serialize(cassette_dict): + cassette_dict = convert_body_to_literal(cassette_dict) + return yaml.dump(cassette_dict, default_flow_style=False, allow_unicode=True) + + @staticmethod + def deserialize(cassette_string): + return yaml.load(cassette_string, Loader=yaml.Loader) + + +@pytest.fixture(scope="module") +def vcr(vcr): + vcr.register_serializer("yaml", PrettyPrintJSONBody) + return vcr + + @pytest.fixture(scope="module") def vcr_config(): """ @@ -85,6 +171,8 @@ def scrub_request_url(request): request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri) request.headers["host"] = AZURE_ENDPOINT.replace("https://", "") request.headers["api-key"] = AZURE_API_KEY + elif "generativelanguage.googleapis.com" in request.uri: + request.uri = re.sub(r"([?&])key=[^&]+", r"\1key=" + GOOGLE_API_KEY, request.uri) return request @@ -93,19 +181,21 @@ def scrub_response_headers(response): """ This scrubs sensitive response headers. Note they are case-sensitive! """ - response["headers"]["openai-organization"] = OPENAI_ORG_ID - response["headers"]["Set-Cookie"] = "test_set_cookie" + if "openai-organization" in response["headers"]: + response["headers"]["openai-organization"] = OPENAI_ORG_ID + if "Set-Cookie" in response["headers"]: + response["headers"]["Set-Cookie"] = "test_set_cookie" return response -def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: +def complete(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]: provider = provider_cls.from_env() system = "You are a helpful assistant." messages = [Message.user("Hello")] - return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs) + return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs) -def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: +def tools(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]: provider = provider_cls.from_env() system = "You are a helpful assistant. Expect to need to read a file using read_file." messages = [Message.user("What are the contents of this file? test.txt")] @@ -114,7 +204,7 @@ def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, ) -def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: +def vision(provider_cls: type[Provider], model: str, **kwargs) -> tuple[Message, Usage]: provider = provider_cls.from_env() system = "You are a helpful assistant." messages = [ @@ -128,4 +218,4 @@ def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')], ), ] - return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs) + return provider.complete(model=model, system=system, messages=messages, tools=(), **kwargs) diff --git a/packages/exchange/tests/providers/test_anthropic.py b/packages/exchange/tests/providers/test_anthropic.py index 272ebcb0f..daf925270 100644 --- a/packages/exchange/tests/providers/test_anthropic.py +++ b/packages/exchange/tests/providers/test_anthropic.py @@ -91,7 +91,7 @@ def test_message_text_to_anthropic_spec() -> None: def test_messages_to_anthropic_spec() -> None: messages = [ - Message(role="user", content=[Text(text="Hello, Claude")]), + Message(role="user", content=[Text("Hello, Claude")]), Message( role="assistant", content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], @@ -148,7 +148,7 @@ def create_response(status_code, json_data=None): reply_message, reply_usage = anthropic_provider.complete(model=model, system=system, messages=messages) - assert reply_message.content == [Text(text="Hello from Claude!")] + assert reply_message.content == [Text("Hello from Claude!")] assert reply_usage.total_tokens == 35 assert mock_post.call_count == 2 mock_post.assert_any_call( diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py index b46be30b9..44b75d380 100644 --- a/packages/exchange/tests/providers/test_azure.py +++ b/packages/exchange/tests/providers/test_azure.py @@ -14,10 +14,10 @@ @pytest.mark.parametrize( "env_var_name", [ - ("AZURE_CHAT_COMPLETIONS_HOST_NAME"), - ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), - ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), - ("AZURE_CHAT_COMPLETIONS_KEY"), + "AZURE_CHAT_COMPLETIONS_HOST_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", + "AZURE_CHAT_COMPLETIONS_KEY", ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): @@ -43,7 +43,7 @@ def test_from_env_throw_error_when_missing_env_var(env_var_name): def test_azure_complete(default_azure_env): reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL) - assert reply_message.content == [Text(text="Hello! How can I assist you today?")] + assert reply_message.content == [Text("Hello! How can I assist you today?\n")] assert reply_usage.total_tokens == 27 @@ -61,7 +61,7 @@ def test_azure_tools(default_azure_env): tool_use = reply_message.content[0] assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" - assert tool_use.id == "call_a47abadDxlGKIWjvYYvGVAHa" + assert tool_use.id == "call_Dn0idyNSdmHSYpsql3EbFH9L" assert tool_use.name == "read_file" assert tool_use.parameters == {"filename": "test.txt"} assert reply_usage.total_tokens == 125 diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py index f8fcaa4b8..f7b68c034 100644 --- a/packages/exchange/tests/providers/test_bedrock.py +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -15,9 +15,9 @@ @pytest.mark.parametrize( "env_var_name", [ - ("AWS_ACCESS_KEY_ID"), - ("AWS_SECRET_ACCESS_KEY"), - ("AWS_SESSION_TOKEN"), + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py index 4b6793abc..cd01335a7 100644 --- a/packages/exchange/tests/providers/test_databricks.py +++ b/packages/exchange/tests/providers/test_databricks.py @@ -10,8 +10,8 @@ @pytest.mark.parametrize( "env_var_name", [ - ("DATABRICKS_HOST"), - ("DATABRICKS_TOKEN"), + "DATABRICKS_HOST", + "DATABRICKS_TOKEN", ], ) def test_from_env_throw_error_when_missing_env_var(env_var_name): @@ -61,7 +61,7 @@ def test_databricks_completion(mock_error, mock_warning, mock_sleep, mock_post, model=model, system=system, messages=messages, tools=tools ) - assert reply_message.content == [Text(text="Hello!")] + assert reply_message.content == [Text("Hello!")] assert reply_usage.total_tokens == 35 assert mock_post.call_count == 1 mock_post.assert_called_once_with( diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py index 76ae4c8d7..a8db06475 100644 --- a/packages/exchange/tests/providers/test_google.py +++ b/packages/exchange/tests/providers/test_google.py @@ -1,13 +1,15 @@ import os from unittest.mock import patch -import httpx import pytest from exchange import Message, Text from exchange.content import ToolResult, ToolUse from exchange.providers.base import MissingProviderEnvVariableError from exchange.providers.google import GoogleProvider from exchange.tool import Tool +from .conftest import complete, tools + +GOOGLE_MODEL = os.getenv("GOOGLE_MODEL", "gemini-1.5-flash") def example_fn(param: str) -> None: @@ -30,12 +32,6 @@ def test_from_env_throw_error_when_missing_api_key(): assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message -@pytest.fixture -@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) -def google_provider(): - return GoogleProvider.from_env() - - def test_google_response_to_text_message() -> None: response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]} message = GoogleProvider.google_response_to_message(response) @@ -87,7 +83,7 @@ def test_message_text_to_google_spec() -> None: def test_messages_to_google_spec() -> None: messages = [ - Message(role="user", content=[Text(text="Hello, Gemini")]), + Message(role="user", content=[Text("Hello, Gemini")]), Message( role="assistant", content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], @@ -105,54 +101,40 @@ def test_messages_to_google_spec() -> None: assert actual_spec == expected_spec -@patch("httpx.Client.post") -@patch("logging.warning") -@patch("logging.error") -def test_google_completion(mock_error, mock_warning, mock_post, google_provider): - mock_response = { - "candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}], - "usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13}, - } +@pytest.mark.vcr() +def test_google_complete(default_google_env): + reply_message, reply_usage = complete(GoogleProvider, GOOGLE_MODEL) - # First attempts fail with status code 429, 2nd succeeds - def create_response(status_code, json_data=None): - response = httpx.Response(status_code) - response._content = httpx._content.json_dumps(json_data or {}).encode() - response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/") - return response + assert reply_message.content == [Text("Hello! 👋 What can I do for you today? 😊 \n")] + assert reply_usage.total_tokens == 21 - mock_post.side_effect = [ - create_response(429), # 1st attempt - create_response(200, mock_response), # Final success - ] - model = "gemini-1.5-flash" - system = "You are a helpful assistant." - messages = [Message.user("Hello, Gemini")] +@pytest.mark.integration +def test_google_complete_integration(): + reply = complete(GoogleProvider, GOOGLE_MODEL) - reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages) + assert reply[0].content is not None + print("Completion content from Google:", reply[0].content) - assert reply_message.content == [Text(text="Hello from Gemini!")] - assert reply_usage.total_tokens == 13 - assert mock_post.call_count == 2 - mock_post.assert_any_call( - "models/gemini-1.5-flash:generateContent", - json={ - "system_instruction": {"parts": [{"text": "You are a helpful assistant."}]}, - "contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}], - }, - ) +@pytest.mark.vcr() +def test_google_tools(default_google_env): + reply_message, reply_usage = tools(GoogleProvider, GOOGLE_MODEL) -@pytest.mark.integration -def test_google_integration(): - provider = GoogleProvider.from_env() - model = "gemini-1.5-flash" # updated model to a known valid model - system = "You are a helpful assistant." - messages = [Message.user("Hello, Gemini")] + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "read_file" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 118 - # Run the completion - reply = provider.complete(model=model, system=system, messages=messages) - assert reply[0].content is not None - print("Completion content from Google:", reply[0].content) +@pytest.mark.integration +def test_google_tools_integration(): + reply = tools(GoogleProvider, GOOGLE_MODEL) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} diff --git a/packages/exchange/tests/providers/test_ollama.py b/packages/exchange/tests/providers/test_ollama.py index 3ce870d36..b1df70d5c 100644 --- a/packages/exchange/tests/providers/test_ollama.py +++ b/packages/exchange/tests/providers/test_ollama.py @@ -13,8 +13,8 @@ def test_ollama_complete(): reply_message, reply_usage = complete(OllamaProvider, OLLAMA_MODEL) - assert reply_message.content == [Text(text="Hello! I'm here to help. How can I assist you today? Let's chat. 😊")] - assert reply_usage.total_tokens == 33 + assert reply_message.content == [Text("Hello! I'm here to help. How can I assist you today? 😊")] + assert reply_usage.total_tokens == 29 @pytest.mark.integration @@ -31,7 +31,7 @@ def test_ollama_tools(): tool_use = reply_message.content[0] assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" - assert tool_use.id == "call_z6fgu3z3" + assert tool_use.id == "call_h5d3s25w" assert tool_use.name == "read_file" assert tool_use.parameters == {"filename": "test.txt"} assert reply_usage.total_tokens == 133 diff --git a/packages/exchange/tests/providers/test_openai.py b/packages/exchange/tests/providers/test_openai.py index ea979abeb..db0a5261a 100644 --- a/packages/exchange/tests/providers/test_openai.py +++ b/packages/exchange/tests/providers/test_openai.py @@ -25,7 +25,7 @@ def test_from_env_throw_error_when_missing_api_key(): def test_openai_complete(default_openai_env): reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) - assert reply_message.content == [Text(text="Hello! How can I assist you today?")] + assert reply_message.content == [Text("Hello! How can I assist you today?")] assert reply_usage.total_tokens == 27 @@ -43,7 +43,7 @@ def test_openai_tools(default_openai_env): tool_use = reply_message.content[0] assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" - assert tool_use.id == "call_xXYlw4A7Ud1qtCopuK5gEJrP" + assert tool_use.id == "call_Z43oz2RtLmNHw9xvFgxA1SC5" assert tool_use.name == "read_file" assert tool_use.parameters == {"filename": "test.txt"} assert reply_usage.total_tokens == 122 @@ -64,7 +64,7 @@ def test_openai_tools_integration(): def test_openai_vision(default_openai_env): reply_message, reply_usage = vision(OpenAiProvider, OPENAI_MODEL) - assert reply_message.content == [Text(text='The first entry in the menu says "Ask Goose."')] + assert reply_message.content == [Text('The first entry in the menu says "Ask Goose."')] assert reply_usage.total_tokens == 14241 diff --git a/packages/exchange/tests/providers/test_provider_utils.py b/packages/exchange/tests/providers/test_provider_utils.py index 5ad0135ea..2a6ab729b 100644 --- a/packages/exchange/tests/providers/test_provider_utils.py +++ b/packages/exchange/tests/providers/test_provider_utils.py @@ -107,9 +107,9 @@ def test_messages_to_openai_spec() -> None: Message(role="user", content=[Text("How are you?")]), Message( role="assistant", - content=[ToolUse(id=1, name="tool1", parameters={"param1": "value1"})], + content=[ToolUse(id="1", name="tool1", parameters={"param1": "value1"})], ), - Message(role="user", content=[ToolResult(tool_use_id=1, output="Result")]), + Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]), ] spec = messages_to_openai_spec(messages) @@ -121,7 +121,7 @@ def test_messages_to_openai_spec() -> None: "role": "assistant", "tool_calls": [ { - "id": 1, + "id": "1", "type": "function", "function": { "name": "tool1", @@ -133,7 +133,7 @@ def test_messages_to_openai_spec() -> None: { "role": "tool", "content": "Result", - "tool_call_id": 1, + "tool_call_id": "1", }, ] @@ -216,7 +216,7 @@ def test_openai_response_to_message_valid_tooluse() -> None: expect = asdict( Message( role="assistant", - content=[ToolUse(id=1, name="example_fn", parameters={"param": "value"})], + content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], ) ) actual.pop("id") diff --git a/packages/exchange/tests/test_exchange.py b/packages/exchange/tests/test_exchange.py index f01ef4694..34937630c 100644 --- a/packages/exchange/tests/test_exchange.py +++ b/packages/exchange/tests/test_exchange.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import pytest from exchange.checkpoint import Checkpoint, CheckpointData @@ -29,12 +27,12 @@ def no_overlapping_checkpoints(exchange: Exchange) -> bool: return True -def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]: +def checkpoint_to_index_pairs(checkpoints: list[Checkpoint]) -> list[tuple[int, int]]: return [(checkpoint.start_index, checkpoint.end_index) for checkpoint in checkpoints] class MockProvider(Provider): - def __init__(self, sequence: List[Message], usage_dicts: List[dict]): + def __init__(self, sequence: list[Message], usage_dicts: list[dict]): # We'll use init to provide a preplanned reply sequence self.sequence = sequence self.call_count = 0 @@ -56,11 +54,18 @@ def get_usage(data: dict) -> Usage: total_tokens=total_tokens, ) - def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Message: + def complete( + self, + model: str, + system: str, + messages: list[Message], + tools: tuple[Tool, ...], + **kwargs: dict[str, any], + ) -> tuple[Message, Usage]: output = self.sequence[self.call_count] usage = self.get_usage(self.usage_dicts[self.call_count]) self.call_count += 1 - return (output, usage) + return output, usage def test_reply_with_unsupported_tool(): @@ -116,7 +121,7 @@ def test_invalid_tool_parameters(): ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) @@ -154,7 +159,7 @@ def test_max_tool_use_when_limit_reached(): ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) @@ -195,7 +200,7 @@ def long_output_tool_char() -> str: ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(long_output_tool_char)], + tools=(Tool.from_function(long_output_tool_char),), moderator=PassiveModerator(), ) @@ -236,7 +241,7 @@ def long_output_tool_token() -> str: ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(long_output_tool_token)], + tools=(Tool.from_function(long_output_tool_token),), moderator=PassiveModerator(), ) @@ -301,7 +306,7 @@ def resumed_exchange() -> Exchange: ex = Exchange( provider=provider, messages=messages, - tools=[], + tools=(), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", checkpoint_data=CheckpointData(), @@ -399,7 +404,7 @@ def test_pop_first_message_no_messages(): provider=MockProvider(sequence=[], usage_dicts=[]), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) @@ -741,7 +746,7 @@ def test_rewind_with_tool_usage(): ), model="gpt-4o-2024-05-13", system="You are a helpful assistant.", - tools=[Tool.from_function(dummy_tool)], + tools=(Tool.from_function(dummy_tool),), moderator=PassiveModerator(), ) ex.add(Message(role="user", content=[Text(text="test")])) diff --git a/packages/exchange/tests/test_exchange_frozen.py b/packages/exchange/tests/test_exchange_frozen.py index a3095b3a3..9d227afab 100644 --- a/packages/exchange/tests/test_exchange_frozen.py +++ b/packages/exchange/tests/test_exchange_frozen.py @@ -9,7 +9,7 @@ class MockProvider(Provider): - def complete(self, model, system, messages, tools=None): + def complete(self, model, system, messages, tools, **kwargs): return Message(role="assistant", content=[Text(text="This is a mock response.")]), Usage.from_dict( {"total_tokens": 35} ) diff --git a/packages/exchange/tests/test_integration_vision.py b/packages/exchange/tests/test_integration_vision.py index 20f165ade..6adf3f041 100644 --- a/packages/exchange/tests/test_integration_vision.py +++ b/packages/exchange/tests/test_integration_vision.py @@ -9,6 +9,7 @@ cases = [ (get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini")), + (get_provider("google"), os.getenv("GOOGLE_MODEL", "gemini-1.5-flash")), ] diff --git a/packages/exchange/tests/test_langfuse_wrapper.py b/packages/exchange/tests/test_langfuse_wrapper.py new file mode 100644 index 000000000..321850978 --- /dev/null +++ b/packages/exchange/tests/test_langfuse_wrapper.py @@ -0,0 +1,46 @@ +import pytest +from unittest.mock import patch, MagicMock +from exchange.langfuse_wrapper import observe_wrapper + + +@pytest.fixture +def mock_langfuse_context(): + with patch("exchange.langfuse_wrapper.langfuse_context") as mock: + yield mock + + +@patch("exchange.langfuse_wrapper.HAS_LANGFUSE_CREDENTIALS", True) +def test_function_is_wrapped(mock_langfuse_context): + mock_observe = MagicMock(side_effect=lambda *args, **kwargs: lambda fn: fn) + mock_langfuse_context.observe = mock_observe + + def original_function(x: int, y: int) -> int: + return x + y + + # test function before we decorate it with + # @observe_wrapper("arg1", kwarg1="kwarg1") + assert not hasattr(original_function, "__wrapped__") + + # ensure we args get passed along (e.g. @observe(capture_input=False, capture_output=False)) + decorated_function = observe_wrapper("arg1", kwarg1="kwarg1")(original_function) + assert hasattr(decorated_function, "__wrapped__") + assert decorated_function.__wrapped__ is original_function, "Function is not properly wrapped" + + assert decorated_function(2, 3) == 5 + mock_observe.assert_called_once() + mock_observe.assert_called_with("arg1", kwarg1="kwarg1") + + +@patch("exchange.langfuse_wrapper.HAS_LANGFUSE_CREDENTIALS", False) +def test_function_is_not_wrapped(mock_langfuse_context): + mock_observe = MagicMock(return_value=lambda f: f) + mock_langfuse_context.observe = mock_observe + + @observe_wrapper("arg1", kwarg1="kwarg1") + def hello() -> str: + return "Hello" + + assert not hasattr(hello, "__wrapped__") + assert hello() == "Hello" + + mock_observe.assert_not_called() diff --git a/packages/exchange/tests/test_summarizer.py b/packages/exchange/tests/test_summarizer.py index fa7281920..7920fe317 100644 --- a/packages/exchange/tests/test_summarizer.py +++ b/packages/exchange/tests/test_summarizer.py @@ -3,11 +3,11 @@ from exchange.content import ToolResult, ToolUse from exchange.moderators.passive import PassiveModerator from exchange.moderators.summarizer import ContextSummarizer -from exchange.providers import Usage +from exchange.providers import Usage, Provider -class MockProvider: - def complete(self, model, system, messages, tools): +class MockProvider(Provider): + def complete(self, model, system, messages, tools, **kwargs): assistant_message_text = "Summarized content here." output_tokens = len(assistant_message_text) total_input_tokens = sum(len(msg.text) for msg in messages) @@ -138,14 +138,14 @@ def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_inst ] -class AnotherMockProvider: +class AnotherMockProvider(Provider): def __init__(self): self.sequence = MESSAGE_SEQUENCE self.current_index = 1 self.summarize_next = False self.summarized_count = 0 - def complete(self, model, system, messages, tools): + def complete(self, model, system, messages, tools, **kwargs): system_prompt_tokens = 100 input_token_count = system_prompt_tokens diff --git a/packages/exchange/tests/test_truncate.py b/packages/exchange/tests/test_truncate.py index 3875303e7..eeb993ff1 100644 --- a/packages/exchange/tests/test_truncate.py +++ b/packages/exchange/tests/test_truncate.py @@ -73,7 +73,7 @@ def __init__(self): self.summarize_next = False self.summarized_count = 0 - def complete(self, model, system, messages, tools): + def complete(self, model, system, messages, tools, **kwargs): input_token_count = SYSTEM_PROMPT_TOKENS message = self.sequence[self.current_index] diff --git a/pyproject.toml b/pyproject.toml index 07a019266..2971b64a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "goose-ai" description = "a programming agent that runs on your machine" -version = "0.9.3" +version = "0.9.5" readme = "README.md" requires-python = ">=3.10" dependencies = [ @@ -12,6 +12,7 @@ dependencies = [ "click>=8.1.7", "prompt-toolkit>=3.0.47", "keyring>=25.4.1", + "langfuse>=2.38.2", ] author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] packages = [{ include = "goose", from = "src" }] @@ -28,6 +29,7 @@ github = "goose.toolkit.github:Github" jira = "goose.toolkit.jira:Jira" screen = "goose.toolkit.screen:Screen" language_server = "goose.toolkit.language_server:LanguageServerCoordinator" +reasoner = "goose.toolkit.reasoner:Reasoner" repo_context = "goose.toolkit.repo_context.repo_context:RepoContext" [project.entry-points."goose.profile"] diff --git a/scripts/langfuse-docker-compose.yaml b/scripts/langfuse-docker-compose.yaml new file mode 100644 index 000000000..153932120 --- /dev/null +++ b/scripts/langfuse-docker-compose.yaml @@ -0,0 +1,46 @@ +services: + langfuse-server: + image: langfuse/langfuse:2 + depends_on: + db: + condition: service_healthy + ports: + - "3000:3000" + environment: + - DATABASE_URL=postgresql://postgres:postgres@db:5432/postgres + - NEXTAUTH_SECRET=mysecret + - SALT=mysalt + - ENCRYPTION_KEY=0000000000000000000000000000000000000000000000000000000000000000 # generate via `openssl rand -hex 32` + - NEXTAUTH_URL=http://localhost:3000 + - TELEMETRY_ENABLED=${TELEMETRY_ENABLED:-true} + - LANGFUSE_ENABLE_EXPERIMENTAL_FEATURES=${LANGFUSE_ENABLE_EXPERIMENTAL_FEATURES:-false} + - LANGFUSE_INIT_ORG_ID=${LANGFUSE_INIT_ORG_ID:-} + - LANGFUSE_INIT_ORG_NAME=${LANGFUSE_INIT_ORG_NAME:-} + - LANGFUSE_INIT_PROJECT_ID=${LANGFUSE_INIT_PROJECT_ID:-} + - LANGFUSE_INIT_PROJECT_NAME=${LANGFUSE_INIT_PROJECT_NAME:-} + - LANGFUSE_INIT_PROJECT_PUBLIC_KEY=${LANGFUSE_INIT_PROJECT_PUBLIC_KEY:-} + - LANGFUSE_INIT_PROJECT_SECRET_KEY=${LANGFUSE_INIT_PROJECT_SECRET_KEY:-} + - LANGFUSE_INIT_USER_EMAIL=${LANGFUSE_INIT_USER_EMAIL:-} + - LANGFUSE_INIT_USER_NAME=${LANGFUSE_INIT_USER_NAME:-} + - LANGFUSE_INIT_USER_PASSWORD=${LANGFUSE_INIT_USER_PASSWORD:-} + + db: + image: postgres + restart: always + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 3s + timeout: 3s + retries: 10 + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_DB=postgres + ports: + - 5432:5432 + volumes: + - database_data:/var/lib/postgresql/data + +volumes: + database_data: + driver: local diff --git a/scripts/setup_langfuse.sh b/scripts/setup_langfuse.sh new file mode 100755 index 000000000..480747959 --- /dev/null +++ b/scripts/setup_langfuse.sh @@ -0,0 +1,99 @@ +#!/bin/bash + +# setup_langfuse.sh +# +# This script sets up and runs Langfuse locally for development and testing purposes. +# +# Key functionalities: +# 1. Downloads the latest docker-compose.yaml from the Langfuse repository +# 2. Starts Langfuse using Docker Compose with default initialization variables +# 3. Waits for the service to be available +# 4. Launches a browser to open the local Langfuse UI +# 5. Prints login credentials from the environment file +# +# Usage: +# ./setup_langfuse.sh +# +# Requirements: +# - Docker +# - curl +# - A .env.langfuse.local file in the env directory +# +# Note: This script is intended for local development use only. + +set -e + +SCRIPT_DIR=$(realpath "$(dirname "${BASH_SOURCE[0]}")") +LANGFUSE_DOCKER_COMPOSE_URL="https://raw.githubusercontent.com/langfuse/langfuse/main/docker-compose.yml" +LANGFUSE_DOCKER_COMPOSE_FILE="langfuse-docker-compose.yaml" +LANGFUSE_ENV_FILE="$SCRIPT_DIR/../packages/exchange/.env.langfuse.local" + +check_dependencies() { + local dependencies=("curl" "docker") + local missing_dependencies=() + + for cmd in "${dependencies[@]}"; do + if ! command -v "$cmd" &> /dev/null; then + missing_dependencies+=("$cmd") + fi + done + + if [ ${#missing_dependencies[@]} -ne 0 ]; then + echo "Missing dependencies: ${missing_dependencies[*]}" + exit 1 + fi +} + +download_docker_compose() { + if ! curl --fail --location --output "$SCRIPT_DIR/langfuse-docker-compose.yaml" "$LANGFUSE_DOCKER_COMPOSE_URL"; then + echo "Failed to download docker-compose file from $LANGFUSE_DOCKER_COMPOSE_URL" + exit 1 + fi +} + +start_docker_compose() { + docker compose --env-file "$LANGFUSE_ENV_FILE" -f "$LANGFUSE_DOCKER_COMPOSE_FILE" up --detach +} + +wait_for_service() { + echo "Waiting for Langfuse to start..." + local retries=10 + local count=0 + until curl --silent http://localhost:3000 > /dev/null; do + ((count++)) + if [ "$count" -ge "$retries" ]; then + echo "Max retries reached. Langfuse did not start in time." + exit 1 + fi + sleep 1 + done + echo "Langfuse is now available!" +} + +launch_browser() { + if [[ "$OSTYPE" == "linux-gnu"* ]]; then + xdg-open "http://localhost:3000" + elif [[ "$OSTYPE" == "darwin"* ]]; then + open "http://localhost:3000" + else + echo "Please open http://localhost:3000 to view Langfuse traces." + fi +} + +print_login_variables() { + if [ -f "$LANGFUSE_ENV_FILE" ]; then + echo "If not already logged in use the following credentials to log in:" + grep -E "LANGFUSE_INIT_USER_EMAIL|LANGFUSE_INIT_USER_PASSWORD" "$LANGFUSE_ENV_FILE" + else + echo "Langfuse environment file with local credentials not found." + fi +} + +check_dependencies +pushd "$SCRIPT_DIR" > /dev/null +download_docker_compose +start_docker_compose +wait_for_service +print_login_variables +launch_browser +popd > /dev/null diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index 7bede0be5..9b613ee1c 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -1,6 +1,6 @@ from functools import cache from pathlib import Path -from typing import Callable, Dict, Mapping, Optional, Tuple +from typing import Mapping, Optional from rich import print from rich.panel import Panel @@ -20,7 +20,7 @@ @cache -def default_profiles() -> Mapping[str, Callable]: +def default_profiles() -> Mapping[str, callable]: return load_plugins(group="goose.profile") @@ -29,7 +29,7 @@ def session_path(name: str) -> Path: return SESSIONS_PATH.joinpath(f"{name}{SESSION_FILE_SUFFIX}") -def write_config(profiles: Dict[str, Profile]) -> None: +def write_config(profiles: dict[str, Profile]) -> None: """Overwrite the config with the passed profiles""" PROFILES_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True) converted = {name: profile.to_dict() for name, profile in profiles.items()} @@ -38,7 +38,7 @@ def write_config(profiles: Dict[str, Profile]) -> None: yaml.dump(converted, f) -def ensure_config(name: Optional[str]) -> Tuple[str, Profile]: +def ensure_config(name: Optional[str]) -> tuple[str, Profile]: """Ensure that the config exists and has the default section""" # TODO we should copy a templated default config in to better document # but this is complicated a bit by autodetecting the provider @@ -70,7 +70,7 @@ def ensure_config(name: Optional[str]) -> Tuple[str, Profile]: return (name, default_profile) -def read_config() -> Dict[str, Profile]: +def read_config() -> dict[str, Profile]: """Return config from the configuration file and validates its contents""" yaml = YAML() @@ -80,7 +80,7 @@ def read_config() -> Dict[str, Profile]: return {name: Profile(**profile) for name, profile in data.items()} -def default_model_configuration() -> Tuple[str, str, str]: +def default_model_configuration() -> tuple[str, str, str]: providers = load_plugins(group="exchange.provider") for provider, cls in providers.items(): try: @@ -102,6 +102,7 @@ def default_model_configuration() -> Tuple[str, str, str]: "databricks-meta-llama-3-1-70b-instruct", "databricks-meta-llama-3-1-70b-instruct", ), + "google": ("gemini-1.5-flash", "gemini-1.5-flash"), } processor, accelerator = recommended.get(provider, ("gpt-4o", "gpt-4o-mini")) return provider, processor, accelerator diff --git a/src/goose/cli/main.py b/src/goose/cli/main.py index 7d1359889..a7f484d90 100644 --- a/src/goose/cli/main.py +++ b/src/goose/cli/main.py @@ -14,6 +14,9 @@ from goose.utils.autocomplete import SUPPORTED_SHELLS, setup_autocomplete from goose.utils.session_file import list_sorted_session_files +LOG_LEVELS = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] +LOG_CHOICE = click.Choice(LOG_LEVELS) + @click.group() def goose_cli() -> None: @@ -136,7 +139,10 @@ def get_session_files() -> dict[str, Path]: @click.option("--profile") @click.option("--plan", type=click.Path(exists=True)) @click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") -def session_start(name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None) -> None: +@click.option("--tracing", is_flag=True, required=False) +def session_start( + name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None, tracing: bool = False +) -> None: """Start a new goose session""" if plan: yaml = YAML() @@ -144,8 +150,12 @@ def session_start(name: Optional[str], profile: str, log_level: str, plan: Optio _plan = yaml.load(f) else: _plan = None - session = Session(name=name, profile=profile, plan=_plan, log_level=log_level) - session.run() + + try: + session = Session(name=name, profile=profile, plan=_plan, log_level=log_level, tracing=tracing) + session.run() + except RuntimeError as e: + print(f"[red]Error: {e}") def parse_args(ctx: click.Context, param: click.Parameter, value: str) -> dict[str, str]: @@ -161,7 +171,7 @@ def parse_args(ctx: click.Context, param: click.Parameter, value: str) -> dict[s @session.command(name="planned") @click.option("--plan", type=click.Path(exists=True)) -@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +@click.option("--log-level", type=LOG_CHOICE, default="INFO") @click.option("-a", "--args", callback=parse_args, help="Args in the format arg1:value1,arg2:value2") def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -> None: plan_templated = render_template(Path(plan), context=args) @@ -173,7 +183,7 @@ def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) - @session.command(name="resume") @click.argument("name", required=False, shell_complete=autocomplete_session_files) @click.option("--profile") -@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +@click.option("--log-level", type=LOG_CHOICE, default="INFO") def session_resume(name: Optional[str], profile: str, log_level: str) -> None: """Resume an existing goose session""" session_files = get_session_files() @@ -190,14 +200,15 @@ def session_resume(name: Optional[str], profile: str, log_level: str) -> None: else: print(f"Creating new session: {name}") session = Session(name=name, profile=profile, log_level=log_level) - session.run() + session.run(new_session=False) @goose_cli.command(name="run") @click.argument("message_file", required=False, type=click.Path(exists=True)) @click.option("--profile") -@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") -def run(message_file: Optional[str], profile: str, log_level: str) -> None: +@click.option("--log-level", type=LOG_CHOICE, default="INFO") +@click.option("--resume-session", is_flag=True, help="Resume the last session if available") +def run(message_file: Optional[str], profile: str, log_level: str, resume_session: bool = False) -> None: """Run a single-pass session with a message from a markdown input file""" if message_file: with open(message_file, "r") as f: @@ -205,7 +216,13 @@ def run(message_file: Optional[str], profile: str, log_level: str) -> None: else: initial_message = click.get_text_stream("stdin").read() - session = Session(profile=profile, log_level=log_level) + if resume_session: + session_files = get_session_files() + if session_files: + name = list(session_files.keys())[0] + session = Session(name=name, profile=profile, log_level=log_level) + else: + session = Session(profile=profile, log_level=log_level) session.single_pass(initial_message=initial_message) diff --git a/src/goose/cli/prompt/completer.py b/src/goose/cli/prompt/completer.py index 6739d1530..fb453fd38 100644 --- a/src/goose/cli/prompt/completer.py +++ b/src/goose/cli/prompt/completer.py @@ -1,5 +1,4 @@ import re -from typing import List from prompt_toolkit.completion import CompleteEvent, Completer, Completion from prompt_toolkit.document import Document @@ -8,10 +7,10 @@ class GoosePromptCompleter(Completer): - def __init__(self, commands: List[Command]) -> None: + def __init__(self, commands: list[Command]) -> None: self.commands = commands - def get_command_completions(self, document: Document) -> List[Completion]: + def get_command_completions(self, document: Document) -> list[Completion]: all_completions = [] for command_name, command_instance in self.commands.items(): pattern = rf"(? List[Completion]: all_completions.extend(completions) return all_completions - def get_command_name_completions(self, document: Document) -> List[Completion]: + def get_command_name_completions(self, document: Document) -> list[Completion]: pattern = r"(? List[Completion]: completions.append(Completion(command_name, start_position=-len(query), display=command_name)) return completions - def get_completions(self, document: Document, _: CompleteEvent) -> List[Completion]: + def get_completions(self, document: Document, _: CompleteEvent) -> list[Completion]: command_completions = self.get_command_completions(document) command_name_completions = self.get_command_name_completions(document) return command_name_completions + command_completions diff --git a/src/goose/cli/prompt/lexer.py b/src/goose/cli/prompt/lexer.py index 0e2bb0c91..e00fd207a 100644 --- a/src/goose/cli/prompt/lexer.py +++ b/src/goose/cli/prompt/lexer.py @@ -1,5 +1,5 @@ import re -from typing import Callable, List, Tuple +from typing import Callable from prompt_toolkit.document import Document from prompt_toolkit.lexers import Lexer @@ -27,7 +27,7 @@ def value_for_command(command_string: str) -> re.Pattern[str]: class PromptLexer(Lexer): - def __init__(self, command_names: List[str]) -> None: + def __init__(self, command_names: list[str]) -> None: self.patterns = [] for command_name in command_names: self.patterns.append((completion_for_command(command_name), "class:command")) @@ -35,7 +35,7 @@ def __init__(self, command_names: List[str]) -> None: self.patterns.append((command_itself(command_name), "class:command")) def lex_document(self, document: Document) -> Callable[[int], list]: - def get_line_tokens(line_number: int) -> Tuple[str, str]: + def get_line_tokens(line_number: int) -> tuple[str, str]: line = document.lines[line_number] tokens = [] diff --git a/src/goose/cli/prompt/overwrite_session_prompt.py b/src/goose/cli/prompt/overwrite_session_prompt.py new file mode 100644 index 000000000..64bbeed61 --- /dev/null +++ b/src/goose/cli/prompt/overwrite_session_prompt.py @@ -0,0 +1,30 @@ +from rich.prompt import Prompt + + +class OverwriteSessionPrompt(Prompt): + def __init__(self, *args: tuple[any], **kwargs: dict[str, any]) -> None: + super().__init__(*args, **kwargs) + self.choices = { + "yes": "Overwrite the existing session", + "no": "Pick a new session name", + "resume": "Resume the existing session", + } + self.default = "resume" + + def check_choice(self, choice: str) -> bool: + normalized_choice = choice.lower() + for key in self.choices: + is_key = normalized_choice == key + is_first_letter = normalized_choice and normalized_choice[0] == key[0] + if is_key or is_first_letter: + return True + return False + + def pre_prompt(self) -> str: + print("Would you like to overwrite it?") + print() + for key, value in self.choices.items(): + first_letter, remaining = key[0], key[1:] + rendered_key = rf"[{first_letter}]{remaining}" + print(f" {rendered_key:10} {value}") + print() diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 83332f113..b8182f58c 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -1,24 +1,29 @@ from contextlib import nullcontext +import logging import traceback from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Optional -from exchange import Message, ToolResult, ToolUse, Text +from langfuse.decorators import langfuse_context +from exchange import Message, Text, ToolResult, ToolUse +from exchange.langfuse_wrapper import observe_wrapper, auth_check from rich import print from rich.markdown import Markdown from rich.panel import Panel +from rich.prompt import Prompt from rich.status import Status -from goose.cli.config import ensure_config, session_path, LOG_PATH from goose._logger import get_logger, setup_logging +from goose.cli.config import LOG_PATH, ensure_config, session_path from goose.cli.prompt.goose_prompt_session import GoosePromptSession +from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt from goose.cli.session_notifier import SessionNotifier from goose.profile import Profile from goose.toolkit.language_server import LanguageServerCoordinator from goose.utils import droid, load_plugins from goose.utils._cost_calculator import get_total_cost_message from goose.utils._create_exchange import create_exchange -from goose.utils.session_file import read_or_create_file, save_latest_session +from goose.utils.session_file import is_empty_session, is_existing_session, read_or_create_file, save_latest_session RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -62,7 +67,8 @@ def __init__( profile: Optional[str] = None, plan: Optional[dict] = None, log_level: Optional[str] = "INFO", - **kwargs: Dict[str, Any], + tracing: bool = False, + **kwargs: dict[str, any], ) -> None: if name is None: self.name = droid() @@ -73,6 +79,18 @@ def __init__( self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) self.profile = load_profile(profile) + if not tracing: + logging.getLogger("langfuse").setLevel(logging.ERROR) + else: + langfuse_auth = auth_check() + if langfuse_auth: + print("Local Langfuse initialized. View your traces at http://localhost:3000") + else: + raise RuntimeError( + "You passed --tracing, but a Langfuse object was not found in the current context. " + "Please initialize the local Langfuse server and restart Goose." + ) + langfuse_context.configure(enabled=tracing) self.exchange = create_exchange(profile=self.profile, notifier=self.notifier) setup_logging(log_file_directory=LOG_PATH, log_level=log_level) @@ -84,7 +102,7 @@ def __init__( self.prompt_session = GoosePromptSession() - def _get_initial_messages(self) -> List[Message]: + def _get_initial_messages(self) -> list[Message]: messages = self.load_session() if messages and messages[-1].role == "user": @@ -155,13 +173,19 @@ def single_pass(self, initial_message: str) -> None: print(f"[dim]ended run | name:[cyan]{self.name}[/] profile:[cyan]{profile}[/]") print(f"[dim]to resume: [magenta]goose session resume {self.name} --profile {profile}[/][/]") - def run(self) -> None: + def run(self, new_session: bool = True) -> None: """ Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. + + Args: + new_session (bool): True when starting a new session, False when resuming. """ - print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]") - print(f"[dim]saving to {self.session_file_path}") + if is_existing_session(self.session_file_path) and new_session: + self._prompt_overwrite_session() + + profile_name = self.profile_name or "default" + print(f"[dim]starting session | name: [cyan]{self.name}[/cyan] profile: [cyan]{profile_name}[/cyan][/dim]") print() message = self.process_first_message() with self.setup_language_server()() as _: @@ -191,8 +215,10 @@ def run(self) -> None: user_input = self.prompt_session.get_user_input() message = Message.user(text=user_input.text) if user_input.to_continue() else None + self._remove_empty_session() self._log_cost() + @observe_wrapper() def reply(self) -> None: """Reply to the last user message, calling tools as needed""" self.status_indicator.update("responding") @@ -247,12 +273,53 @@ def interrupt_reply(self) -> None: def session_file_path(self) -> Path: return session_path(self.name) - def load_session(self) -> List[Message]: + def load_session(self) -> list[Message]: return read_or_create_file(self.session_file_path) def _log_cost(self) -> None: get_logger().info(get_total_cost_message(self.exchange.get_token_usage())) - print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}") + print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}[/]") + + def _prompt_overwrite_session(self) -> None: + print(f"[yellow]Session already exists at {self.session_file_path}.[/]") + + choice = OverwriteSessionPrompt.ask("Enter your choice", show_choices=False) + match choice: + case "y" | "yes": + print("Overwriting existing session") + + case "n" | "no": + while True: + new_session_name = Prompt.ask("Enter a new session name") + if not is_existing_session(session_path(new_session_name)): + self.name = new_session_name + break + print(f"[yellow]Session '{new_session_name}' already exists[/]") + + case "r" | "resume": + self.exchange.messages.extend(self.load_session()) + + def _remove_empty_session(self) -> bool: + """ + Removes the session file only when it's empty. + + Note: This is because a session file is created at the start of the run + loop. When a user aborts before their first message empty session files + will be created, causing confusion when resuming sessions (which + depends on most recent mtime and is non-empty). + + Returns: + bool: True if the session file was removed, False otherwise. + """ + logger = get_logger() + try: + if is_empty_session(self.session_file_path): + logger.debug(f"deleting empty session file: {self.session_file_path}") + self.session_file_path.unlink() + return True + except Exception as e: + logger.error(f"error deleting empty session file: {e}") + return False if __name__ == "__main__": diff --git a/src/goose/command/__init__.py b/src/goose/command/__init__.py index d9fd674a4..cef47fec9 100644 --- a/src/goose/command/__init__.py +++ b/src/goose/command/__init__.py @@ -1,5 +1,4 @@ from functools import cache -from typing import Dict from goose.command.base import Command from goose.utils import load_plugins @@ -11,5 +10,5 @@ def get_command(name: str) -> type[Command]: @cache -def get_commands() -> Dict[str, type[Command]]: +def get_commands() -> dict[str, type[Command]]: return load_plugins(group="goose.command") diff --git a/src/goose/command/base.py b/src/goose/command/base.py index 5a8c346ff..081453de6 100644 --- a/src/goose/command/base.py +++ b/src/goose/command/base.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List, Optional +from typing import Optional from prompt_toolkit.completion import Completion @@ -7,7 +7,7 @@ class Command(ABC): """A command that can be executed by the CLI.""" - def get_completions(self, query: str) -> List[Completion]: + def get_completions(self, query: str) -> list[Completion]: """ Get completions for the command. diff --git a/src/goose/command/file.py b/src/goose/command/file.py index cb8bdfd67..786785cf8 100644 --- a/src/goose/command/file.py +++ b/src/goose/command/file.py @@ -1,5 +1,4 @@ import os -from typing import List from prompt_toolkit.completion import Completion @@ -7,7 +6,7 @@ class FileCommand(Command): - def get_completions(self, query: str) -> List[Completion]: + def get_completions(self, query: str) -> list[Completion]: if query.startswith("/"): directory = os.path.dirname(query) search_term = os.path.basename(query) diff --git a/src/goose/profile.py b/src/goose/profile.py index cdc34fb85..2999a470a 100644 --- a/src/goose/profile.py +++ b/src/goose/profile.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Mapping, Type +from typing import Mapping from attrs import asdict, define, field @@ -21,10 +21,10 @@ class Profile: processor: str accelerator: str moderator: str - toolkits: List[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec)) + toolkits: list[ToolkitSpec] = field(factory=list, converter=ensure_list(ToolkitSpec)) @toolkits.validator - def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[ToolkitSpec]) -> None: + def check_toolkit_requirements(self, _: type["ToolkitSpec"], toolkits: list[ToolkitSpec]) -> None: # checks that the list of toolkits in the profile have their requirements installed_toolkits = set([toolkit.name for toolkit in toolkits]) @@ -36,7 +36,7 @@ def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[Tool msg = f"Toolkit {toolkit_name} requires {req} but it is not present" raise ValueError(msg) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, any]: return asdict(self) def profile_info(self) -> str: @@ -44,7 +44,7 @@ def profile_info(self) -> str: return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}" -def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile: +def default_profile(provider: str, processor: str, accelerator: str, **kwargs: dict[str, any]) -> Profile: """Get the default profile""" # TODO consider if the providers should have recommended models diff --git a/src/goose/toolkit/base.py b/src/goose/toolkit/base.py index 850f40263..faa102d3d 100644 --- a/src/goose/toolkit/base.py +++ b/src/goose/toolkit/base.py @@ -1,6 +1,6 @@ import inspect from abc import ABC -from typing import Callable, Mapping, Optional, Tuple, TypeVar +from typing import Mapping, Optional, TypeVar from attrs import define, field from exchange import Tool @@ -8,7 +8,7 @@ from goose.notifier import Notifier # Create a type variable that can represent any function signature -F = TypeVar("F", bound=Callable) +F = TypeVar("F", bound=callable) def tool(func: F) -> F: @@ -55,7 +55,7 @@ def system(self) -> str: """Get the addition to the system prompt for this toolkit.""" return "" - def tools(self) -> Tuple[Tool, ...]: + def tools(self) -> tuple[Tool, ...]: """Get the tools for this toolkit This default method looks for functions on the toolkit annotated diff --git a/src/goose/toolkit/developer.py b/src/goose/toolkit/developer.py index ba600d921..b48a18069 100644 --- a/src/goose/toolkit/developer.py +++ b/src/goose/toolkit/developer.py @@ -3,7 +3,6 @@ import subprocess import time from pathlib import Path -from typing import Dict, List from exchange import Message from goose.toolkit.base import Toolkit, tool @@ -35,9 +34,9 @@ class Developer(Toolkit): We also include some default shell strategies in the prompt, such as using ripgrep """ - def __init__(self, *args: object, **kwargs: Dict[str, object]) -> None: + def __init__(self, *args: object, **kwargs: dict[str, object]) -> None: super().__init__(*args, **kwargs) - self.timestamps: Dict[str, float] = {} + self.timestamps: dict[str, float] = {} def system(self) -> str: """Retrieve system configuration details for developer""" @@ -55,7 +54,7 @@ def system(self) -> str: return system_prompt @tool - def update_plan(self, tasks: List[dict]) -> List[dict]: + def update_plan(self, tasks: list[dict]) -> list[dict]: """ Update the plan by overwriting all current tasks @@ -63,7 +62,7 @@ def update_plan(self, tasks: List[dict]) -> List[dict]: shown to the user directly, you do not need to reiterate it Args: - tasks (List(dict)): The list of tasks, where each task is a dictionary + tasks (list(dict)): The list of tasks, where each task is a dictionary with a key for the task "description" and the task "status". The status MUST be one of "planned", "complete", "failed", "in-progress". diff --git a/src/goose/toolkit/prompts/reasoner.jinja b/src/goose/toolkit/prompts/reasoner.jinja new file mode 100644 index 000000000..0129eb6ba --- /dev/null +++ b/src/goose/toolkit/prompts/reasoner.jinja @@ -0,0 +1,5 @@ +It is important to use deep reasoning and thinking tools when working with code, especially solving problems or new issues. +Writing code requires deep thinking and reasoning at times, which can be used to provide ideas, solutions, and to check other solutions. +Always use your generate_code tool when writing code especially on a new problem, +and use deep_reason to check solutions or when there have been errors or solutions are not clear. +Consider these tools as expert consultants that can provide advice and code that you may use. \ No newline at end of file diff --git a/src/goose/toolkit/reasoner.py b/src/goose/toolkit/reasoner.py new file mode 100644 index 000000000..fe8f3aa23 --- /dev/null +++ b/src/goose/toolkit/reasoner.py @@ -0,0 +1,77 @@ +from exchange import Exchange, Message, Text +from exchange.content import Content +from exchange.providers import OpenAiProvider +from goose.toolkit.base import Toolkit, tool +from goose.utils.ask import ask_an_ai + + +class Reasoner(Toolkit): + """Deep thinking toolkit for reasoning through problems and solutions""" + + def message_content(self, content: Content) -> Text: + if isinstance(content, Text): + return content + else: + return Text(str(content)) + + @tool + def deep_reason(self, problem: str) -> str: + """ + Debug or reason about challenges or problems. + It will take a minute to think about it and consider solutions. + + Args: + problem (str): description of problem or errors seen. + + Returns: + response (str): A solution, which may include a suggestion or code snippet. + """ + # Create an instance of Exchange with the inlined OpenAI provider + self.notifier.status("thinking...") + provider = OpenAiProvider.from_env() + + # Create messages list + existing_messages_copy = [ + Message(role=msg.role, content=[self.message_content(content) for content in msg.content]) + for msg in self.exchange_view.processor.messages + ] + exchange = Exchange(provider=provider, model="o1-preview", messages=existing_messages_copy, system=None) + + response = ask_an_ai(input="please help reason about this: " + problem, exchange=exchange, no_history=False) + return response.content[0].text + + @tool + def generate_code(self, instructions: str) -> str: + """ + reason about and write code based on instructions given. + this will consider and reason about the instructions and come up with code to solve it. + + Args: + instructions (str): instructions of what code to write or how to modify it. + + Returns: + response (str): generated code to be tested or applied. Not it will not write directly to files so you have to take it and process it if it is suitable. + """ # noqa: E501 + # Create an instance of Exchange with the inlined OpenAI provider + provider = OpenAiProvider.from_env() + + # clone messages, converting to text for context + existing_messages_copy = [ + Message(role=msg.role, content=[self.message_content(content) for content in msg.content]) + for msg in self.exchange_view.processor.messages + ] + exchange = Exchange(provider=provider, model="o1-mini", messages=existing_messages_copy, system=None) + + self.notifier.status("generating code...") + response = ask_an_ai( + input="Please follow the instructions, " + + "and ideally return relevant code and little commentary:" + + instructions, + exchange=exchange, + no_history=False, + ) + return response.content[0].text + + def system(self) -> str: + """Retrieve instructions on how to use this reasoning and code generation tool""" + return Message.load("prompts/reasoner.jinja").text diff --git a/src/goose/toolkit/repo_context/repo_context.py b/src/goose/toolkit/repo_context/repo_context.py index 8be8794f6..b8d0f1aae 100644 --- a/src/goose/toolkit/repo_context/repo_context.py +++ b/src/goose/toolkit/repo_context/repo_context.py @@ -1,7 +1,6 @@ import os from functools import cache from subprocess import CompletedProcess, run -from typing import Dict, Tuple from exchange import Message @@ -21,7 +20,7 @@ def __init__(self, notifier: Notifier, requires: Requirements) -> None: self.repo_project_root, self.is_git_repo, self.goose_session_root = self.determine_git_proj() - def determine_git_proj(self) -> Tuple[str, bool, str]: + def determine_git_proj(self) -> tuple[str, bool, str]: """Determines the root as well as where Goose is currently running If the project is not part of a Github repo, the root of the project will be defined as the current working @@ -72,11 +71,11 @@ def is_mono_repo(self) -> bool: return self.repo_size > 2000 @tool - def summarize_current_project(self) -> Dict[str, str]: + def summarize_current_project(self) -> dict[str, str]: """Summarizes the current project based on repo root (if git repo) or current project_directory (if not) Returns: - summary (Dict[str, str]): Keys are file paths and values are the summaries + summary (dict[str, str]): Keys are file paths and values are the summaries """ self.notifier.log("Summarizing the most relevant files in the current project. This may take a while...") @@ -101,8 +100,11 @@ def summarize_current_project(self) -> Dict[str, str]: file_select_exchange = replace_prompt(exchange=file_select_exchange, prompt=system) files = goose_picks_files(root=project_directory, exchange=file_select_exchange) + # summarize the selected files using a blank exchange with no tools summary = summarize_files_concurrent( - exchange=self.exchange_view.accelerator, file_list=files, project_name=project_directory.split("/")[-1] + exchange=clear_exchange(self.exchange_view.accelerator, clear_tools=True), + file_list=files, + project_name=project_directory.split("/")[-1], ) return summary diff --git a/src/goose/toolkit/repo_context/utils.py b/src/goose/toolkit/repo_context/utils.py index dca7f04b0..e69cea936 100644 --- a/src/goose/toolkit/repo_context/utils.py +++ b/src/goose/toolkit/repo_context/utils.py @@ -2,7 +2,6 @@ import concurrent.futures import os from collections import deque -from typing import Dict, List, Tuple from exchange import Exchange @@ -26,7 +25,7 @@ def get_repo_size(repo_path: str) -> int: return get_directory_size(git_dir) / (1024**2) -def get_files_and_directories(root_dir: str) -> Dict[str, list]: +def get_files_and_directories(root_dir: str) -> dict[str, list]: """Gets file names and directory names. Checks that goose has correctly typed the file and directory names and that the files actually exist (to avoid downstream file read errors). @@ -61,7 +60,7 @@ def get_files_and_directories(root_dir: str) -> Dict[str, list]: return {"files": files, "directories": dirs} -def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> List[str]: +def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> list[str]: """Lets goose pick files in a BFS manner""" queue = deque([root]) @@ -80,7 +79,7 @@ def goose_picks_files(root: str, exchange: Exchange, max_workers: int = 4) -> Li return all_files -def process_directory(current_dir: str, exchange: Exchange) -> Tuple[List[str], List[str]]: +def process_directory(current_dir: str, exchange: Exchange) -> tuple[list[str], list[str]]: """Allows goose to pick files and subdirectories contained in a given directory (current_dir). Get the list of file and directory names in the current folder, then ask Goose to pick which ones to keep. diff --git a/src/goose/toolkit/summarization/summarize_project.py b/src/goose/toolkit/summarization/summarize_project.py index d910fbc47..f5e22562c 100644 --- a/src/goose/toolkit/summarization/summarize_project.py +++ b/src/goose/toolkit/summarization/summarize_project.py @@ -1,5 +1,5 @@ import os -from typing import List, Optional +from typing import Optional from goose.toolkit import Toolkit from goose.toolkit.base import tool @@ -11,7 +11,7 @@ class SummarizeProject(Toolkit): def get_project_summary( self, project_dir_path: Optional[str] = os.getcwd(), - extensions: Optional[List[str]] = None, + extensions: Optional[list[str]] = None, summary_instructions_prompt: Optional[str] = None, ) -> dict: """Generates or retrieves a project summary based on specified file extensions. @@ -19,7 +19,7 @@ def get_project_summary( Args: project_dir_path (Optional[Path]): Path to the project directory. Defaults to the current working directory if None - extensions (Optional[List[str]]): Specific file extensions to summarize. + extensions (Optional[list[str]]): Specific file extensions to summarize. summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g. "Summarize the file in two sentences.". The default instruction is "Please summarize this file." diff --git a/src/goose/toolkit/summarization/summarize_repo.py b/src/goose/toolkit/summarization/summarize_repo.py index 18c7da428..58765dd9a 100644 --- a/src/goose/toolkit/summarization/summarize_repo.py +++ b/src/goose/toolkit/summarization/summarize_repo.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from goose.toolkit import Toolkit from goose.toolkit.base import tool @@ -10,7 +10,7 @@ class SummarizeRepo(Toolkit): def summarize_repo( self, repo_url: str, - specified_extensions: Optional[List[str]] = None, + specified_extensions: Optional[list[str]] = None, summary_instructions_prompt: Optional[str] = None, ) -> dict: """ @@ -19,7 +19,7 @@ def summarize_repo( Args: repo_url (str): The URL of the repository to summarize. - specified_extensions (Optional[List[str]]): List of file extensions to summarize, e.g., ["tf", "md"]. If + specified_extensions (Optional[list[str]]): list of file extensions to summarize, e.g., ["tf", "md"]. If this list is empty, then all files in the repo are summarized summary_instructions_prompt (Optional[str]): Instructions to give to the LLM about how to summarize each file. E.g. "Summarize the file in two sentences.". The default instruction is "Please summarize this file." diff --git a/src/goose/toolkit/summarization/utils.py b/src/goose/toolkit/summarization/utils.py index d398713cc..96e5d363d 100644 --- a/src/goose/toolkit/summarization/utils.py +++ b/src/goose/toolkit/summarization/utils.py @@ -2,7 +2,7 @@ import subprocess from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Optional from exchange import Exchange from exchange.providers.utils import InitialMessageTooLargeError @@ -15,7 +15,7 @@ # TODO: move git stuff -def run_git_command(command: List[str]) -> subprocess.CompletedProcess[str]: +def run_git_command(command: list[str]) -> subprocess.CompletedProcess[str]: result = subprocess.run(["git"] + command, capture_output=True, text=True, check=False) if result.returncode != 0: @@ -28,7 +28,7 @@ def clone_repo(repo_url: str, target_directory: str) -> None: run_git_command(["clone", repo_url, target_directory]) -def load_summary_file_if_exists(project_name: str) -> Optional[Dict]: +def load_summary_file_if_exists(project_name: str) -> Optional[dict]: """Checks if a summary file exists at '.goose/summaries/projectname-summary.json. Returns contents of the file if it exists, otherwise returns None @@ -36,7 +36,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]: project_name (str): name of the project or repo Returns: - Optional[Dict]: File contents, else None + Optional[dict]: File contents, else None """ summary_file_path = f"{SUMMARIES_FOLDER}/{project_name}-summary.json" if Path(summary_file_path).exists(): @@ -44,7 +44,7 @@ def load_summary_file_if_exists(project_name: str) -> Optional[Dict]: return json.load(f) -def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> Tuple[str, str]: +def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = None) -> tuple[str, str]: """Summarizes a single file Args: @@ -74,15 +74,15 @@ def summarize_file(filepath: str, exchange: Exchange, prompt: Optional[str] = No def summarize_repo( repo_url: str, exchange: Exchange, - extensions: List[str], + extensions: list[str], summary_instructions_prompt: Optional[str] = None, -) -> Dict[str, str]: +) -> dict[str, str]: """Clones (if needed) and summarizes a repo Args: repo_url (str): Repository url exchange (Exchange): Exchange for summarizing the repo. - extensions (List[str]): List of file-types to summarize. + extensions (list[str]): list of file-types to summarize. summary_instructions_prompt (Optional[str]): Optional parameter to customize summarization results. Defaults to "Please summarize this file" """ @@ -110,15 +110,15 @@ def summarize_repo( def summarize_directory( - directory: str, exchange: Exchange, extensions: List[str], summary_instructions_prompt: Optional[str] = None -) -> Dict[str, str]: + directory: str, exchange: Exchange, extensions: list[str], summary_instructions_prompt: Optional[str] = None +) -> dict[str, str]: """Summarize files in a given directory based on extensions. Will also recursively find files in subdirectories and summarize them. Args: directory (str): path to the top-level directory to summarize exchange (Exchange): Exchange to use to summarize - extensions (List[str]): List of file-type extensions to summarize (and ignore all other extensions). + extensions (list[str]): list of file-type extensions to summarize (and ignore all other extensions). summary_instructions_prompt (Optional[str]): Optional instructions to give to the exchange regarding summarization. Returns: @@ -158,19 +158,19 @@ def summarize_directory( def summarize_files_concurrent( - exchange: Exchange, file_list: List[str], project_name: str, summary_instructions_prompt: Optional[str] = None -) -> Dict[str, str]: + exchange: Exchange, file_list: list[str], project_name: str, summary_instructions_prompt: Optional[str] = None +) -> dict[str, str]: """Takes in a list of files and summarizes them. Exchange does not keep history of the summarized files. Args: exchange (Exchange): Underlying exchange - file_list (List[str]): List of paths to files to summarize + file_list (list[str]): list of paths to files to summarize project_name (str): Used to save the summary of the files to .goose/summaries/-summary.json summary_instructions_prompt (Optional[str]): Summary instructions for the LLM. Defaults to "Please summarize this file." Returns: - file_summaries (Dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange + file_summaries (dict[str, str]): Keys are file paths and values are the summaries returned by the Exchange """ summary_file = load_summary_file_if_exists(project_name) if summary_file: diff --git a/src/goose/toolkit/utils.py b/src/goose/toolkit/utils.py index ad97360f2..d6f0335b1 100644 --- a/src/goose/toolkit/utils.py +++ b/src/goose/toolkit/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Dict +from typing import Optional from pygments.lexers import get_lexer_for_filename from pygments.util import ClassNotFound @@ -67,7 +67,7 @@ def find_last_task_group_index(input_str: str) -> int: return last_group_start_index -def parse_plan(input_plan_str: str) -> Dict: +def parse_plan(input_plan_str: str) -> dict: last_group_start_index = find_last_task_group_index(input_plan_str) if last_group_start_index == -1: return {"kickoff_message": input_plan_str, "tasks": []} diff --git a/src/goose/utils/__init__.py b/src/goose/utils/__init__.py index 69887e7f0..d9535b187 100644 --- a/src/goose/utils/__init__.py +++ b/src/goose/utils/__init__.py @@ -1,7 +1,7 @@ import random import string from importlib.metadata import entry_points -from typing import Any, Callable, Dict, List, Type, TypeVar +from typing import TypeVar, Callable T = TypeVar("T") @@ -31,10 +31,10 @@ def load_plugins(group: str) -> dict: return plugins -def ensure(cls: Type[T]) -> Callable[[Any], T]: +def ensure(cls: type[T]) -> Callable[[any], T]: """Convert dictionary to a class instance""" - def converter(val: Any) -> T: # noqa: ANN401 + def converter(val: any) -> T: # noqa: ANN401 if isinstance(val, cls): return val elif isinstance(val, dict): @@ -47,10 +47,10 @@ def converter(val: Any) -> T: # noqa: ANN401 return converter -def ensure_list(cls: Type[T]) -> Callable[[List[Dict[str, Any]]], Type[T]]: +def ensure_list(cls: type[T]) -> Callable[[list[dict[str, any]]], type[T]]: """Convert a list of dictionaries to class instances""" - def converter(val: List[Dict[str, Any]]) -> List[T]: + def converter(val: list[dict[str, any]]) -> list[T]: output = [] for entry in val: output.append(ensure(cls)(entry)) diff --git a/src/goose/utils/file_utils.py b/src/goose/utils/file_utils.py index 370ce6565..d3f9e8e1b 100644 --- a/src/goose/utils/file_utils.py +++ b/src/goose/utils/file_utils.py @@ -2,7 +2,7 @@ import os from collections import Counter from pathlib import Path -from typing import Dict, List, Optional +from typing import Optional from goose.utils.language import Language @@ -13,7 +13,7 @@ def create_extensions_list(project_root: str, max_n: int) -> list: project_root (str): Root of the project to analyze max_n (int): The number of file extensions to return Returns: - extensions (List[str]): A list of the top N file extensions + extensions (list[str]): A list of the top N file extensions """ if max_n == 0: raise (ValueError("Number of file extensions must be greater than 0")) @@ -33,14 +33,14 @@ def create_extensions_list(project_root: str, max_n: int) -> list: return extensions -def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float]: +def create_language_weighting(files_in_directory: list[str]) -> dict[str, float]: """Calculate language weighting by file size to match GitHub's methodology. Args: - files_in_directory (List[str]): Paths to files in the project directory + files_in_directory (list[str]): Paths to files in the project directory Returns: - Dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values + dict[str, float]: A dictionary with languages as keys and their percentage of the total codebase as values """ # Initialize counters for sizes @@ -61,7 +61,7 @@ def create_language_weighting(files_in_directory: List[str]) -> Dict[str, float] return dict(sorted(language_percentages.items(), key=lambda item: item[1], reverse=True)) -def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> List[str]: +def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> list[str]: """List all files in a directory with a given extension. Set extension to '' to return all files. Args: @@ -69,7 +69,7 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L extension (Optional[str]): extension to lookup. Defaults to '' which will return all files. Returns: - files (List[str]): List of file paths + files (list[str]): list of file paths """ # add a leading '.' to extension if needed if extension and not extension.startswith("."): @@ -79,15 +79,15 @@ def list_files_with_extension(dir_path: str, extension: Optional[str] = "") -> L return files -def create_file_list(dir_path: str, extensions: List[str]) -> List[str]: +def create_file_list(dir_path: str, extensions: list[str]) -> list[str]: """Creates a list of files with certain extensions Args: dir_path (str): Directory to list files of. Will include files recursively in sub-directories. - extensions (List[str]): List of file extensions to select for. If empty list, return all files + extensions (list[str]): list of file extensions to select for. If empty list, return all files Returns: - final_file_list (List[str]): List of file paths with specified extensions. + final_file_list (list[str]): list of file paths with specified extensions. """ # if extensions is empty list, return all files if not extensions: diff --git a/src/goose/utils/session_file.py b/src/goose/utils/session_file.py index 435186ce5..510b4aa19 100644 --- a/src/goose/utils/session_file.py +++ b/src/goose/utils/session_file.py @@ -1,20 +1,28 @@ import json import os -from pathlib import Path import tempfile -from typing import Dict, Iterator, List +from pathlib import Path +from typing import Iterator from exchange import Message from goose.cli.config import SESSION_FILE_SUFFIX -def write_to_file(file_path: Path, messages: List[Message]) -> None: +def is_existing_session(path: Path) -> bool: + return path.is_file() and path.stat().st_size > 0 + + +def is_empty_session(path: Path) -> bool: + return path.is_file() and path.stat().st_size == 0 + + +def write_to_file(file_path: Path, messages: list[Message]) -> None: with open(file_path, "w") as f: _write_messages_to_file(f, messages) -def read_or_create_file(file_path: Path) -> List[Message]: +def read_or_create_file(file_path: Path) -> list[Message]: if file_path.exists(): return read_from_file(file_path) with open(file_path, "w"): @@ -22,7 +30,7 @@ def read_or_create_file(file_path: Path) -> List[Message]: return [] -def read_from_file(file_path: Path) -> List[Message]: +def read_from_file(file_path: Path) -> list[Message]: try: with open(file_path, "r") as f: messages = [json.loads(m) for m in list(f) if m.strip()] @@ -32,7 +40,7 @@ def read_from_file(file_path: Path) -> List[Message]: return [Message(**m) for m in messages] -def list_sorted_session_files(session_files_directory: Path) -> Dict[str, Path]: +def list_sorted_session_files(session_files_directory: Path) -> dict[str, Path]: logs = list_session_files(session_files_directory) return {log.stem: log for log in sorted(logs, key=lambda x: x.stat().st_mtime, reverse=True)} @@ -47,7 +55,7 @@ def session_file_exists(session_files_directory: Path) -> bool: return any(list_session_files(session_files_directory)) -def save_latest_session(file_path: Path, messages: List[Message]) -> None: +def save_latest_session(file_path: Path, messages: list[Message]) -> None: with tempfile.NamedTemporaryFile("w", delete=False) as temp_file: _write_messages_to_file(temp_file, messages) temp_file_path = temp_file.name @@ -55,7 +63,7 @@ def save_latest_session(file_path: Path, messages: List[Message]) -> None: os.replace(temp_file_path, file_path) -def _write_messages_to_file(file: any, messages: List[Message]) -> None: +def _write_messages_to_file(file: any, messages: list[Message]) -> None: for m in messages: json.dump(m.to_dict(), file) file.write("\n") diff --git a/tests/cli/prompt/test_overwrite_session_prompt.py b/tests/cli/prompt/test_overwrite_session_prompt.py new file mode 100644 index 000000000..95cf825b4 --- /dev/null +++ b/tests/cli/prompt/test_overwrite_session_prompt.py @@ -0,0 +1,49 @@ +import pytest +from goose.cli.prompt.overwrite_session_prompt import OverwriteSessionPrompt + + +@pytest.fixture +def prompt(): + return OverwriteSessionPrompt() + + +def test_init(prompt): + assert prompt.choices == { + "yes": "Overwrite the existing session", + "no": "Pick a new session name", + "resume": "Resume the existing session", + } + assert prompt.default == "resume" + + +@pytest.mark.parametrize( + "choice, expected", + [ + ("", False), + ("invalid", False), + ("n", True), + ("N", True), + ("no", True), + ("NO", True), + ("r", True), + ("R", True), + ("resume", True), + ("RESUME", True), + ("y", True), + ("Y", True), + ("yes", True), + ("YES", True), + ], +) +def test_check_choice(prompt, choice, expected): + assert prompt.check_choice(choice) == expected + + +def test_instantiation(): + prompt = OverwriteSessionPrompt() + assert prompt.choices == { + "yes": "Overwrite the existing session", + "no": "Pick a new session name", + "resume": "Resume the existing session", + } + assert prompt.default == "resume" diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py index 9f6f2f66f..38b38920f 100644 --- a/tests/cli/test_main.py +++ b/tests/cli/test_main.py @@ -34,7 +34,9 @@ def test_session_start_command_with_session_name(mock_session): mock_session_class, mock_session_instance = mock_session runner = CliRunner() runner.invoke(goose_cli, ["session", "start", "session1", "--profile", "default"]) - mock_session_class.assert_called_once_with(name="session1", profile="default", plan=None, log_level="INFO") + mock_session_class.assert_called_once_with( + name="session1", profile="default", plan=None, log_level="INFO", tracing=False + ) mock_session_instance.run.assert_called_once() diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index f2437462c..b2eafea6c 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,7 +1,8 @@ +import json from unittest.mock import MagicMock, patch import pytest -from exchange import Exchange, Message, ToolUse, ToolResult +from exchange import Exchange, Message, ToolResult, ToolUse from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -123,36 +124,88 @@ def test_log_log_cost(create_session_with_mock_configs): mock_logger.info.assert_called_once_with(cost_message) -def test_run_should_auto_save_session(create_session_with_mock_configs, mock_sessions_path): +@patch.object(GoosePromptSession, "get_user_input", name="get_user_input") +@patch.object(Exchange, "generate", name="mock_generate") +@patch("goose.cli.session.save_latest_session", name="mock_save_latest_session") +def test_run_should_auto_save_session( + mock_save_latest_session, + mock_generate, + mock_get_user_input, + create_session_with_mock_configs, + mock_sessions_path, +): def custom_exchange_generate(self, *args, **kwargs): message = Message.assistant("Response") self.add(message) return message + def mock_generate_side_effect(*args, **kwargs): + return custom_exchange_generate(session.exchange, *args, **kwargs) + + def save_latest_session(file, messages): + file.write_text("\n".join(json.dumps(m.to_dict()) for m in messages)) + user_inputs = [ UserInput(action=PromptAction.CONTINUE, text="Question1"), UserInput(action=PromptAction.CONTINUE, text="Question2"), UserInput(action=PromptAction.EXIT), ] - session = create_session_with_mock_configs({"name": SESSION_NAME}) - with ( - patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), - patch.object(Exchange, "generate") as mock_generate, - patch("goose.cli.session.save_latest_session") as mock_save_latest_session, - ): - mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs) - session.run() - session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" - assert session.exchange.generate.call_count == 2 - assert mock_save_latest_session.call_count == 2 - assert mock_save_latest_session.call_args_list[0][0][0] == session_file - assert session_file.exists() + mock_get_user_input.side_effect = user_inputs + mock_generate.side_effect = mock_generate_side_effect + mock_save_latest_session.side_effect = save_latest_session + + session.run() + + session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" + + assert mock_generate.call_count == 2 + assert mock_save_latest_session.call_count == 2 + assert mock_save_latest_session.call_args_list[0][0][0] == session_file + assert session_file.exists() + with open(session_file, "r") as f: + saved_messages = [json.loads(line) for line in f] + + expected_messages = [ + Message.user("Question1"), + Message.assistant("Response"), + Message.user("Question2"), + Message.assistant("Response"), + ] + + assert len(saved_messages) == len(expected_messages) + for saved, expected in zip(saved_messages, expected_messages): + assert saved["role"] == expected.role + assert saved["content"][0]["text"] == expected.text + + +@patch("goose.cli.session.droid", return_value="generated_session_name", name="mock_droid") +def test_set_generated_session_name(mock_droid, create_session_with_mock_configs, mock_sessions_path): + session = create_session_with_mock_configs({"name": None}) + assert session.name == "generated_session_name" + + +@patch("goose.cli.session.is_existing_session", name="mock_is_existing") +@patch("goose.cli.session.Session._prompt_overwrite_session", name="mock_prompt") +def test_existing_session_prompt(mock_prompt, mock_is_existing, create_session_with_mock_configs): + session = create_session_with_mock_configs({"name": SESSION_NAME}) -def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path): - generated_session_name = "generated_session_name" - with patch("goose.cli.session.droid", return_value=generated_session_name): - session = create_session_with_mock_configs({"name": None}) - assert session.name == generated_session_name + def check_prompt_behavior(is_existing, new_session, should_prompt): + mock_is_existing.return_value = is_existing + if new_session is None: + session.run() + else: + session.run(new_session=new_session) + + if should_prompt: + mock_prompt.assert_called_once() + else: + mock_prompt.assert_not_called() + mock_prompt.reset_mock() + + check_prompt_behavior(is_existing=True, new_session=None, should_prompt=True) + check_prompt_behavior(is_existing=False, new_session=None, should_prompt=False) + check_prompt_behavior(is_existing=True, new_session=True, should_prompt=True) + check_prompt_behavior(is_existing=False, new_session=False, should_prompt=False) diff --git a/tests/utils/test_session_file.py b/tests/utils/test_session_file.py index 6a2a64981..290564566 100644 --- a/tests/utils/test_session_file.py +++ b/tests/utils/test_session_file.py @@ -1,9 +1,11 @@ import os from pathlib import Path +from unittest.mock import patch import pytest from exchange import Message from goose.utils.session_file import ( + is_empty_session, list_sorted_session_files, read_from_file, read_or_create_file, @@ -115,3 +117,22 @@ def create_session_file(file_path, file_name) -> Path: file = file_path / f"{file_name}.jsonl" file.touch() return file + + +@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file") +@patch("pathlib.Path.stat", name="mock_stat") +def test_is_empty_session(mock_stat, mock_is_file): + mock_stat.return_value.st_size = 0 + assert is_empty_session(Path("empty_file.json")) + + +@patch("pathlib.Path.is_file", return_value=True, name="mock_is_file") +@patch("pathlib.Path.stat", name="mock_stat") +def test_is_not_empty_session(mock_stat, mock_is_file): + mock_stat.return_value.st_size = 100 + assert not is_empty_session(Path("non_empty_file.json")) + + +@patch("pathlib.Path.is_file", return_value=False, name="mock_is_file") +def test_is_not_empty_session_file_not_found(mock_is_file): + assert not is_empty_session(Path("file_not_found.json"))