diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0fb1701e0..dbe5ab6dd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -5,7 +5,7 @@ on: branches: [main] jobs: - build: + exchange: runs-on: ubuntu-latest steps: @@ -19,9 +19,82 @@ jobs: - name: Ruff run: | - uvx ruff check - uvx ruff format --check + uvx ruff check packages/exchange + uvx ruff format packages/exchange --check - name: Run tests + working-directory: ./packages/exchange run: | uv run pytest tests -m 'not integration' + + goose: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install UV + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Source Cargo Environment + run: source $HOME/.cargo/env + + - name: Ruff + run: | + uvx ruff check src tests + uvx ruff format src tests --check + + - name: Run tests + run: | + uv run pytest tests -m 'not integration' + + + # This runs integration tests of the OpenAI API, using Ollama to host models. + # This lets us test PRs from forks which can't access secrets like API keys. + ollama: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: + # Only test the lastest python version. + - "3.12" + ollama-model: + # For quicker CI, use a smaller, tool-capable model than the default. + - "qwen2.5:0.5b" + + steps: + - uses: actions/checkout@v4 + + - name: Install UV + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Source Cargo Environment + run: source $HOME/.cargo/env + + - name: Set up Python + run: uv python install ${{ matrix.python-version }} + + - name: Install Ollama + run: curl -fsSL https://ollama.com/install.sh | sh + + - name: Start Ollama + run: | + # Run the background, in a way that survives to the next step + nohup ollama serve > ollama.log 2>&1 & + + # Block using the ready endpoint + time curl --retry 5 --retry-connrefused --retry-delay 1 -sf http://localhost:11434 + + # Tests use OpenAI which does not have a mechanism to pull models. Run a + # simple prompt to (pull and) test the model first. + - name: Test Ollama model + run: ollama run $OLLAMA_MODEL hello || cat ollama.log + env: + OLLAMA_MODEL: ${{ matrix.ollama-model }} + + - name: Run Ollama tests + run: uv run pytest tests -m integration -k ollama + working-directory: ./packages/exchange + env: + OLLAMA_MODEL: ${{ matrix.ollama-model }} diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 000000000..969ebb7e7 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,50 @@ +name: Publish + +# A release on goose will also publish exchange, if it has updated +# This means in some cases we may need to make a bump in goose without other changes to release exchange +on: + release: + types: [published] + +jobs: + publish: + permissions: + id-token: write + contents: read + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Get current version from pyproject.toml + id: get_version + run: | + echo "VERSION=$(grep -m 1 'version =' "pyproject.toml" | awk -F'"' '{print $2}')" >> $GITHUB_ENV + + - name: Extract tag version + id: extract_tag + run: | + TAG_VERSION=$(echo "${{ github.event.release.tag_name }}" | sed -E 's/v(.*)/\1/') + echo "TAG_VERSION=$TAG_VERSION" >> $GITHUB_ENV + + - name: Check if tag matches version from pyproject.toml + id: check_tag + run: | + if [ "${{ env.TAG_VERSION }}" != "${{ env.VERSION }}" ]; then + echo "::error::Tag version (${{ env.TAG_VERSION }}) does not match version in pyproject.toml (${{ env.VERSION }})." + exit 1 + fi + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Build Package + run: | + uv build -o dist --package goose-ai + uv build -o dist --package ai-exchange + + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + skip-existing: true diff --git a/.github/workflows/pypi_release.yaml b/.github/workflows/pypi_release.yaml deleted file mode 100644 index 98758fb96..000000000 --- a/.github/workflows/pypi_release.yaml +++ /dev/null @@ -1,47 +0,0 @@ -name: PYPI Release - -on: - push: - tags: - - 'v*' - -jobs: - pypi_release: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Install UV - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - - name: Source Cargo Environment - run: source $HOME/.cargo/env - - - name: Build with UV - run: uvx --from build pyproject-build --installer uv - - - name: Check version - id: check_version - run: | - PACKAGE_NAME=$(grep '^name =' pyproject.toml | sed -E 's/name = "(.*)"/\1/') - TAG_VERSION=$(echo "$GITHUB_REF" | sed -E 's/refs\/tags\/v(.+)/\1/') - CURRENT_VERSION=$(curl -s https://pypi.org/pypi/$PACKAGE_NAME/json | jq -r .info.version) - PROJECT_VERSION=$(grep '^version =' pyproject.toml | sed -E 's/version = "(.*)"/\1/') - if [ "$TAG_VERSION" != "$PROJECT_VERSION" ]; then - echo "Tag version does not match version in pyproject.toml" - exit 1 - fi - if python -c "from packaging.version import parse as parse_version; exit(0 if parse_version('$TAG_VERSION') > parse_version('$CURRENT_VERSION') else 1)"; then - echo "new_version=true" >> $GITHUB_OUTPUT - else - exit 1 - fi - - - name: Publish - uses: pypa/gh-action-pypi-publish@v1.4.2 - if: steps.check_version.outputs.new_version == 'true' - with: - user: __token__ - password: ${{ secrets.PYPI_TOKEN_TEMP }} - packages_dir: ./dist/ diff --git a/packages/exchange/README.md b/packages/exchange/README.md new file mode 100644 index 000000000..207030844 --- /dev/null +++ b/packages/exchange/README.md @@ -0,0 +1,95 @@ +
+ + + +Exchange - a uniform python SDK for message generation with LLMs
+ +- Provides a flexible layer for message handling and generation +- Directly integrates python functions into tool calling +- Persistently surfaces errors to the underlying models to support reflection + +## Example + +> [!NOTE] +> Before you can run this example, you need to setup an API key with +> `export OPENAI_API_KEY=your-key-here` + +``` python +from exchange import Exchange, Message, Tool +from exchange.providers import OpenAiProvider + +def word_count(text: str): + """Get the count of words in text + + Args: + text (str): The text with words to count + """ + return len(text.split(" ")) + +ex = Exchange( + provider=OpenAiProvider.from_env(), + model="gpt-4o", + system="You are a helpful assistant.", + tools=[Tool.from_function(word_count)], +) +ex.add(Message.user("Count the number of words in this current message")) + +# The model sees it has a word count tool, and should use it along the way to answer +# This will call all the tools as needed until the model replies with the final result +reply = ex.reply() +print(reply.text) + +# you can see all the tool calls in the message history +print(ex.messages) +``` + +## Plugins + +*exchange* has a plugin mechanism to add support for additional providers and moderators. If you need a +provider not supported here, we'd be happy to review [contributions][CONTRIBUTING]. But you +can also consider building and using your own plugin. + +To create a `Provider` plugin, subclass `exchange.provider.Provider`. You will need to +implement the `complete` method. For example this is what we use as a mock in our tests. +You can see a full implementation example of the [OpenAiProvider][openaiprovider]. We +also generally recommend implementing a `from_env` classmethod to instantiate the provider. + +``` python +class MockProvider(Provider): + def __init__(self, sequence: List[Message]): + # We'll use init to provide a preplanned reply sequence + self.sequence = sequence + self.call_count = 0 + + def complete( + self, model: str, system: str, messages: List[Message], tools: List[Tool] + ) -> Message: + output = self.sequence[self.call_count] + self.call_count += 1 + return output +``` + +Then use [python packaging's entrypoints][plugins] to register your plugin. + +``` toml +[project.entry-points.'exchange.provider'] +example = 'path.to.plugin:ExampleProvider' +``` + +Your plugin will then be available in your application or other applications built on *exchange* +through: + +``` python +from exchange.providers import get_provider + +provider = get_provider('example').from_env() +``` + +[CONTRIBUTING]: CONTRIBUTING.md +[openaiprovider]: src/exchange/providers/openai.py +[plugins]: https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ diff --git a/packages/exchange/pyproject.toml b/packages/exchange/pyproject.toml new file mode 100644 index 000000000..83a9e3c25 --- /dev/null +++ b/packages/exchange/pyproject.toml @@ -0,0 +1,48 @@ +[project] +name = "ai-exchange" +version = "0.9.3" +description = "a uniform python SDK for message generation with LLMs" +readme = "README.md" +requires-python = ">=3.10" +author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] +packages = [{ include = "exchange", from = "src" }] +dependencies = [ + "griffe>=1.1.1", + "attrs>=24.2.0", + "jinja2>=3.1.4", + "tiktoken>=0.7.0", + "httpx>=0.27.0", + "tenacity>=9.0.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/exchange"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.uv] +dev-dependencies = ["pytest>=8.3.2", "pytest-vcr>=1.0.2", "codecov>=2.1.13"] + +[project.entry-points."exchange.provider"] +openai = "exchange.providers.openai:OpenAiProvider" +azure = "exchange.providers.azure:AzureProvider" +databricks = "exchange.providers.databricks:DatabricksProvider" +anthropic = "exchange.providers.anthropic:AnthropicProvider" +bedrock = "exchange.providers.bedrock:BedrockProvider" +ollama = "exchange.providers.ollama:OllamaProvider" +google = "exchange.providers.google:GoogleProvider" + +[project.entry-points."exchange.moderator"] +passive = "exchange.moderators.passive:PassiveModerator" +truncate = "exchange.moderators.truncate:ContextTruncate" +summarize = "exchange.moderators.summarizer:ContextSummarizer" + +[project.entry-points."metadata.plugins"] +ai-exchange = "exchange:module_name" + +[tool.pytest.ini_options] +markers = [ + "integration: marks tests that need to authenticate (deselect with '-m \"not integration\"')", +] diff --git a/packages/exchange/src/exchange/__init__.py b/packages/exchange/src/exchange/__init__.py new file mode 100644 index 000000000..41adfcf3a --- /dev/null +++ b/packages/exchange/src/exchange/__init__.py @@ -0,0 +1,9 @@ +"""Classes for interacting with the exchange API.""" + +from exchange.tool import Tool # noqa +from exchange.content import Text, ToolResult, ToolUse # noqa +from exchange.message import Message # noqa +from exchange.exchange import Exchange # noqa +from exchange.checkpoint import CheckpointData, Checkpoint # noqa + +module_name = "ai-exchange" diff --git a/packages/exchange/src/exchange/checkpoint.py b/packages/exchange/src/exchange/checkpoint.py new file mode 100644 index 000000000..f355dd0a2 --- /dev/null +++ b/packages/exchange/src/exchange/checkpoint.py @@ -0,0 +1,67 @@ +from copy import deepcopy +from typing import List +from attrs import define, field + + +@define +class Checkpoint: + """Checkpoint that counts the tokens in messages between the start and end index""" + + start_index: int = field(default=0) # inclusive + end_index: int = field(default=0) # inclusive + token_count: int = field(default=0) + + def __deepcopy__(self, _) -> "Checkpoint": # noqa: ANN001 + """ + Returns a deep copy of the Checkpoint object. + """ + return Checkpoint( + start_index=self.start_index, + end_index=self.end_index, + token_count=self.token_count, + ) + + +@define +class CheckpointData: + """Aggregates all information about checkpoints""" + + # the total number of tokens in the exchange. this is updated every time a checkpoint is + # added or removed + total_token_count: int = field(default=0) + + # in order list of individual checkpoints in the exchange + checkpoints: List[Checkpoint] = field(factory=list) + + # the offset to apply to the message index when calculating the last message index + # this is useful because messages on the exchange behave like a queue, where you can only + # pop from the left or right sides. This offset allows us to map the checkpoint indices + # to the correct message index, even if we have popped messages from the left side of + # the exchange in the past. we reset this offset to 0 when we empty the checkpoint data. + message_index_offset: int = field(default=0) + + def __deepcopy__(self, memo: dict) -> "CheckpointData": + """Returns a deep copy of the CheckpointData object.""" + return CheckpointData( + total_token_count=self.total_token_count, + checkpoints=deepcopy(self.checkpoints, memo), + message_index_offset=self.message_index_offset, + ) + + @property + def last_message_index(self) -> int: + if not self.checkpoints: + return -1 # we don't have enough information to know + return self.checkpoints[-1].end_index - self.message_index_offset + + def reset(self) -> None: + """Resets the checkpoint data to its initial state.""" + self.checkpoints = [] + self.message_index_offset = 0 + self.total_token_count = 0 + + def pop(self, index: int = -1) -> Checkpoint: + """Removes and returns the checkpoint at the given index.""" + popped_checkpoint = self.checkpoints.pop(index) + self.total_token_count = self.total_token_count - popped_checkpoint.token_count + return popped_checkpoint diff --git a/packages/exchange/src/exchange/content.py b/packages/exchange/src/exchange/content.py new file mode 100644 index 000000000..b9cc986fc --- /dev/null +++ b/packages/exchange/src/exchange/content.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, Optional + +from attrs import define, asdict + + +CONTENT_TYPES = {} + + +class Content: + def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None: + super().__init_subclass__(**kwargs) + CONTENT_TYPES[cls.__name__] = cls + + def to_dict(self) -> Dict[str, Any]: + data = asdict(self, recurse=True) + data["type"] = self.__class__.__name__ + return data + + +@define +class Text(Content): + text: str + + +@define +class ToolUse(Content): + id: str + name: str + parameters: Any + is_error: bool = False + error_message: Optional[str] = None + + +@define +class ToolResult(Content): + tool_use_id: str + output: str + is_error: bool = False diff --git a/packages/exchange/src/exchange/exchange.py b/packages/exchange/src/exchange/exchange.py new file mode 100644 index 000000000..b2fdbc5ec --- /dev/null +++ b/packages/exchange/src/exchange/exchange.py @@ -0,0 +1,336 @@ +import json +import traceback +from copy import deepcopy +from typing import Any, Dict, List, Mapping, Tuple + +from attrs import define, evolve, field, Factory +from tiktoken import get_encoding + +from exchange.checkpoint import Checkpoint, CheckpointData +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.moderators import Moderator +from exchange.moderators.truncate import ContextTruncate +from exchange.providers import Provider, Usage +from exchange.tool import Tool +from exchange.token_usage_collector import _token_usage_collector + + +def validate_tool_output(output: str) -> None: + """Validate tool output for the given model""" + max_output_chars = 2**20 + max_output_tokens = 16000 + encoder = get_encoding("cl100k_base") + if len(output) > max_output_chars or len(encoder.encode(output)) > max_output_tokens: + raise ValueError("This tool call created an output that was too long to handle!") + + +@define(frozen=True) +class Exchange: + """An exchange of messages with an LLM + + The exchange class is meant to be largely immutable, with only the message list + growing once constructed. Use .replace to alter the model, tools, etc. + + The exchange supports tool usage, calling tools and letting the model respond when + using the .reply method. It handles most forms of errors and sends those errors back + to the model, to let it attempt to recover. + """ + + provider: Provider + model: str + system: str + moderator: Moderator = field(default=ContextTruncate()) + tools: Tuple[Tool] = field(factory=tuple, converter=tuple) + messages: List[Message] = field(factory=list) + checkpoint_data: CheckpointData = field(factory=CheckpointData) + generation_args: dict = field(default=Factory(dict)) + + @property + def _toolmap(self) -> Mapping[str, Tool]: + return {tool.name: tool for tool in self.tools} + + def replace(self, **kwargs: Dict[str, Any]) -> "Exchange": + """Make a copy of the exchange, replacing any passed arguments""" + # TODO: ensure that the checkpoint data is updated correctly. aka, + # if we replace the messages, we need to update the checkpoint data + # if we change the model, we need to update the checkpoint data (?) + + if kwargs.get("messages") is None: + kwargs["messages"] = deepcopy(self.messages) + if kwargs.get("checkpoint_data") is None: + kwargs["checkpoint_data"] = deepcopy( + self.checkpoint_data, + ) + return evolve(self, **kwargs) + + def add(self, message: Message) -> None: + """Add a message to the history.""" + if self.messages and message.role == self.messages[-1].role: + raise ValueError("Messages in the exchange must alternate between user and assistant") + self.messages.append(message) + + def generate(self) -> Message: + """Generate the next message.""" + self.moderator.rewrite(self) + message, usage = self.provider.complete( + self.model, + self.system, + messages=self.messages, + tools=self.tools, + **self.generation_args, + ) + self.add(message) + self.add_checkpoints_from_usage(usage) # this has to come after adding the response + + # TODO: also call `rewrite` here, as this will make our + # messages *consistently* below the token limit. this currently + # is not the case because we could append a large message after calling + # `rewrite` above. + # self.moderator.rewrite(self) + + _token_usage_collector.collect(self.model, usage) + return message + + def reply(self, max_tool_use: int = 128) -> Message: + """Get the reply from the underlying model. + + This will process any requests for tool calls, calling them immediately, and + storing the intermediate tool messages in the queue. It will return after the + first response that does not request a tool use + + Args: + max_tool_use: The maximum number of tool calls to make before returning. Defaults to 128. + """ + if max_tool_use <= 0: + raise ValueError("max_tool_use must be greater than 0") + response = self.generate() + curr_iter = 1 # generate() already called once + while response.tool_use: + content = [] + for tool_use in response.tool_use: + tool_result = self.call_function(tool_use) + content.append(tool_result) + self.add(Message(role="user", content=content)) + + # We've reached the limit of tool calls - break out of the loop + if curr_iter >= max_tool_use: + # At this point, the most recent message is `Message(role='user', content=ToolResult(...))` + response = Message.assistant( + f"We've stopped executing additional tool cause because we reached the limit of {max_tool_use}", + ) + self.add(response) + break + else: + response = self.generate() + curr_iter += 1 + + return response + + def call_function(self, tool_use: ToolUse) -> ToolResult: + """Call the function indicated by the tool use""" + tool = self._toolmap.get(tool_use.name) + + if tool is None or tool_use.is_error: + output = f"ERROR: Failed to use tool {tool_use.id}.\nDo NOT use the same tool name and parameters again - that will lead to the same error." # noqa: E501 + + if tool_use.is_error: + output += f"\n{tool_use.error_message}" + elif tool is None: + valid_tool_names = ", ".join(self._toolmap.keys()) + output += f"\nNo tool exists with the name '{tool_use.name}'. Valid tool names are: {valid_tool_names}" + + return ToolResult(tool_use_id=tool_use.id, output=output, is_error=True) + + try: + if isinstance(tool_use.parameters, dict): + output = json.dumps(tool.function(**tool_use.parameters)) + elif isinstance(tool_use.parameters, list): + output = json.dumps(tool.function(*tool_use.parameters)) + else: + raise ValueError( + f"The provided tool parameters, {tool_use.parameters} could not be interpreted as a mapping of arguments." # noqa: E501 + ) + + validate_tool_output(output) + + is_error = False + except Exception as e: + tb = traceback.format_exc() + output = str(tb) + "\n" + str(e) + is_error = True + + return ToolResult(tool_use_id=tool_use.id, output=output, is_error=is_error) + + def add_tool_use(self, tool_use: ToolUse) -> None: + """Manually add a tool use and corresponding result + + This will call the implied function and add an assistant + message requesting the ToolUse and a user message with the ToolResult + """ + tool_result = self.call_function(tool_use) + self.add(Message(role="assistant", content=[tool_use])) + self.add(Message(role="user", content=[tool_result])) + + def add_checkpoints_from_usage(self, usage: Usage) -> None: + """ + Add checkpoints to the exchange based on the token counts of the last two + groups of messages, as well as the current token total count of the exchange + """ + # we know we just appended one message as the response from the LLM + # so we need to create two checkpoints as we know the token counts + # of the last two groups of messages: + # 1. from the last checkpoint to the most recent user message + # 2. the most recent assistant message + last_checkpoint_end_index = ( + self.checkpoint_data.checkpoints[-1].end_index - self.checkpoint_data.message_index_offset + if len(self.checkpoint_data.checkpoints) > 0 + else -1 + ) + new_start_index = last_checkpoint_end_index + 1 + + # here, our self.checkpoint_data.total_token_count is the previous total token count from the last time + # that we performed a request. if we subtract this value from the input_tokens from our + # latest response, we know how many tokens our **1** from above is. + first_block_token_count = usage.input_tokens - self.checkpoint_data.total_token_count + second_block_token_count = usage.output_tokens + + if len(self.messages) - new_start_index > 1: + # this will occur most of the time, as we will have one new user message and one + # new assistant message. + + self.checkpoint_data.checkpoints.append( + Checkpoint( + start_index=new_start_index + self.checkpoint_data.message_index_offset, + # end index below is equivalent to the second last message. why? becuase + # the last message is the assistant message that we add below. we need to also + # track the token count of the user message sent. + end_index=len(self.messages) - 2 + self.checkpoint_data.message_index_offset, + token_count=first_block_token_count, + ) + ) + self.checkpoint_data.checkpoints.append( + Checkpoint( + start_index=len(self.messages) - 1 + self.checkpoint_data.message_index_offset, + end_index=len(self.messages) - 1 + self.checkpoint_data.message_index_offset, + token_count=second_block_token_count, + ) + ) + + # TODO: check if the front of the checkpoints doesn't overlap with + # the first message. if so, we are missing checkpoint data from + # message[0] to message[checkpoint_data.checkpoints[0].start_index] + # we can fill in this data by performing an extra request and doing some math + self.checkpoint_data.total_token_count = usage.total_tokens + + def pop_last_message(self) -> Message: + """Pop the last message from the exchange, handling checkpoints correctly""" + if ( + len(self.checkpoint_data.checkpoints) > 0 + and self.checkpoint_data.last_message_index > len(self.messages) - 1 + ): + raise ValueError("Our checkpoint data is out of sync with our message data") + if ( + len(self.checkpoint_data.checkpoints) > 0 + and self.checkpoint_data.last_message_index == len(self.messages) - 1 + ): + # remove the last checkpoint, because we no longer know the token count of it's contents. + # note that this is not the same as reverting to the last checkpoint, as we want to + # keep the messages from the last checkpoint. they will have a new checkpoint created for + # them when we call generate() again + self.checkpoint_data.pop() + self.messages.pop() + + def pop_first_message(self) -> Message: + """Pop the first message from the exchange, handling checkpoints correctly""" + if len(self.messages) == 0: + raise ValueError("There are no messages to pop") + if len(self.checkpoint_data.checkpoints) == 0: + raise ValueError("There must be at least one checkpoint to pop the first message") + + # get the start and end indexes of the first checkpoint, use these to remove message + first_checkpoint = self.checkpoint_data.checkpoints[0] + first_checkpoint_start_index = first_checkpoint.start_index - self.checkpoint_data.message_index_offset + + # check if the first message is part of the first checkpoint + if first_checkpoint_start_index == 0: + # remove this checkpoint, as it no longer has any messages + self.checkpoint_data.pop(0) + + self.messages.pop(0) + self.checkpoint_data.message_index_offset += 1 + + if len(self.checkpoint_data.checkpoints) == 0: + # we've removed all the checkpoints, so we need to reset the message index offset + self.checkpoint_data.message_index_offset = 0 + + def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + """ + Reverts the exchange back to the last checkpoint, removing associated messages + """ + removed_checkpoint = self.checkpoint_data.checkpoints.pop() + # pop messages until we reach the start of the next checkpoint + messages = [] + while len(self.messages) > removed_checkpoint.start_index - self.checkpoint_data.message_index_offset: + messages.append(self.messages.pop()) + return removed_checkpoint, messages + + def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + """ + Pop the first checkpoint from the exchange, removing associated messages + """ + if len(self.checkpoint_data.checkpoints) == 0: + raise ValueError("There are no checkpoints to pop") + first_checkpoint = self.checkpoint_data.pop(0) + + # remove messages until we reach the start of the next checkpoint + messages = [] + stop_at_index = first_checkpoint.end_index - self.checkpoint_data.message_index_offset + for _ in range(stop_at_index + 1): # +1 because it's inclusive + messages.append(self.messages.pop(0)) + self.checkpoint_data.message_index_offset += 1 + + if len(self.checkpoint_data.checkpoints) == 0: + # we've removed all the checkpoints, so we need to reset the message index offset + self.checkpoint_data.message_index_offset = 0 + return first_checkpoint, messages + + def prepend_checkpointed_message(self, message: Message, token_count: int) -> None: + """Prepend a message to the exchange, updating the checkpoint data""" + self.messages.insert(0, message) + new_index = max(0, self.checkpoint_data.message_index_offset - 1) + self.checkpoint_data.checkpoints.insert( + 0, + Checkpoint( + start_index=new_index, + end_index=new_index, + token_count=token_count, + ), + ) + self.checkpoint_data.message_index_offset = new_index + + def rewind(self) -> None: + if not self.messages: + return + + # we remove messages until we find the last user text message + while not (self.messages[-1].role == "user" and type(self.messages[-1].content[-1]) is Text): + self.pop_last_message() + + # now we remove that last user text message, putting us at a good point + # to ask the user for their input again + if self.messages: + self.pop_last_message() + + @property + def is_allowed_to_call_llm(self) -> bool: + """ + Returns True if the exchange is allowed to call the LLM, False otherwise + """ + # TODO: reconsider whether this function belongs here and whether it is necessary + # Some models will have different requirements than others, so it may be better for + # this to be a required method of the provider instead. + return len(self.messages) > 0 and self.messages[-1].role == "user" + + def get_token_usage(self) -> Dict[str, Usage]: + return _token_usage_collector.get_token_usage_group_by_model() diff --git a/packages/exchange/src/exchange/message.py b/packages/exchange/src/exchange/message.py new file mode 100644 index 000000000..035c60345 --- /dev/null +++ b/packages/exchange/src/exchange/message.py @@ -0,0 +1,121 @@ +import inspect +import time +from pathlib import Path +from typing import Any, Dict, List, Literal, Type + +from attrs import define, field +from jinja2 import Environment, FileSystemLoader + +from exchange.content import CONTENT_TYPES, Content, Text, ToolResult, ToolUse +from exchange.utils import create_object_id + +Role = Literal["user", "assistant"] + + +def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401 + if instance.role == "user": + if not (instance.text or instance.tool_result): + raise ValueError("User message must include a Text or ToolResult") + if instance.tool_use: + raise ValueError("User message does not support ToolUse") + elif instance.role == "assistant": + if not (instance.text or instance.tool_use): + raise ValueError("Assistant message must include a Text or ToolUsage") + if instance.tool_result: + raise ValueError("Assistant message does not support ToolResult") + + +def content_converter(contents: List[Dict[str, Any]]) -> List[Content]: + return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents] + + +@define +class Message: + """A message to or from a language model. + + This supports several content types to extend to tool usage and (tbi) images. + + We also provide shortcuts for simplified text usage; these two are identical: + ``` + m = Message(role='user', content=[Text(text='abcd')]) + assert m.content[0].text == 'abcd' + + m = Message.user('abcd') + assert m.text == 'abcd' + ``` + """ + + role: Role = field(default="user") + id: str = field(factory=lambda: str(create_object_id(prefix="msg"))) + created: int = field(factory=lambda: int(time.time())) + content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter) + + def to_dict(self) -> Dict[str, Any]: + return { + "role": self.role, + "id": self.id, + "created": self.created, + "content": [item.to_dict() for item in self.content], + } + + @property + def text(self) -> str: + """The text content of this message.""" + result = [] + for content in self.content: + if isinstance(content, Text): + result.append(content.text) + return "\n".join(result) + + @property + def tool_use(self) -> List[ToolUse]: + """All tool use content of this message.""" + result = [] + for content in self.content: + if isinstance(content, ToolUse): + result.append(content) + return result + + @property + def tool_result(self) -> List[ToolResult]: + """All tool result content of this message.""" + result = [] + for content in self.content: + if isinstance(content, ToolResult): + result.append(content) + return result + + @classmethod + def load( + cls: Type["Message"], + filename: str, + role: Role = "user", + **kwargs: Dict[str, Any], + ) -> "Message": + """Load the message from filename relative to where the load is called. + + This only supports simplified content, with a single text entry + + This is meant to emulate importing code rather than a runtime filesystem. So + if you have a directory of code that contains example.py, and example.py has + a function that calls User.load('example.jinja'), it will look in the same + directory as example.py for the jinja file. + """ + frm = inspect.stack()[1] + mod = inspect.getmodule(frm[0]) + + base_path = Path(mod.__file__).parent + + env = Environment(loader=FileSystemLoader(base_path)) + template = env.get_template(filename) + rendered_content = template.render(**kwargs) + + return cls(role=role, content=[Text(text=rendered_content)]) + + @classmethod + def user(cls: Type["Message"], text: str) -> "Message": + return cls(role="user", content=[Text(text)]) + + @classmethod + def assistant(cls: Type["Message"], text: str) -> "Message": + return cls(role="assistant", content=[Text(text)]) diff --git a/packages/exchange/src/exchange/moderators/__init__.py b/packages/exchange/src/exchange/moderators/__init__.py new file mode 100644 index 000000000..56b198a75 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/__init__.py @@ -0,0 +1,13 @@ +from functools import cache +from typing import Type + +from exchange.moderators.base import Moderator +from exchange.utils import load_plugins +from exchange.moderators.passive import PassiveModerator # noqa +from exchange.moderators.truncate import ContextTruncate # noqa +from exchange.moderators.summarizer import ContextSummarizer # noqa + + +@cache +def get_moderator(name: str) -> Type[Moderator]: + return load_plugins(group="exchange.moderator")[name] diff --git a/packages/exchange/src/exchange/moderators/base.py b/packages/exchange/src/exchange/moderators/base.py new file mode 100644 index 000000000..d7c630c6a --- /dev/null +++ b/packages/exchange/src/exchange/moderators/base.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Type + + +class Moderator(ABC): + @abstractmethod + def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + pass diff --git a/packages/exchange/src/exchange/moderators/passive.py b/packages/exchange/src/exchange/moderators/passive.py new file mode 100644 index 000000000..e3a24efbd --- /dev/null +++ b/packages/exchange/src/exchange/moderators/passive.py @@ -0,0 +1,7 @@ +from typing import Type +from exchange.moderators.base import Moderator + + +class PassiveModerator(Moderator): + def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + pass diff --git a/packages/exchange/src/exchange/moderators/summarizer.jinja b/packages/exchange/src/exchange/moderators/summarizer.jinja new file mode 100644 index 000000000..00c29ed82 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/summarizer.jinja @@ -0,0 +1,9 @@ +You are an expert technical summarizer. + +During your conversation with the user, you may be asked to summarize the content in you conversational history. +When asked to summarize, you should concisely summarize the conversation giving emphasis to newer content. Newer content will be towards the end of the conversation. +Preferentially keep user supplied content in the summary. + +The summary *MUST* include filenames that were touched and/or modified. If the updates occurred more recently, keep the latest modifications made to the files in the summary. If the changes occurred earlier in the chat, briefly summarize the changes and don't include the changes in the summary. + +There will likely be json formatted blocks referencing ToolUse and ToolResults. You can ignore ToolUse references, but keep the ToolResult outputs, summarizing as needed and with the same guidelines as above. diff --git a/packages/exchange/src/exchange/moderators/summarizer.py b/packages/exchange/src/exchange/moderators/summarizer.py new file mode 100644 index 000000000..7e2dd5588 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/summarizer.py @@ -0,0 +1,46 @@ +from typing import Type + +from exchange import Message +from exchange.checkpoint import CheckpointData +from exchange.moderators import ContextTruncate, PassiveModerator + + +class ContextSummarizer(ContextTruncate): + def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + """Summarize the context history up to the last few messages in the exchange""" + + self._update_system_prompt_token_count(exchange) + + # TODO: use an offset for summarization + if exchange.checkpoint_data.total_token_count < self.max_tokens: + return + + messages_to_summarize = self._get_messages_to_remove(exchange) + num_messages_to_remove = len(messages_to_summarize) + + # the llm will throw an error if the last message isn't a user message + if messages_to_summarize[-1].role == "assistant" and (not messages_to_summarize[-1].tool_use): + messages_to_summarize.append(Message.user("Summarize our the above conversation")) + + summarizer_exchange = exchange.replace( + system=Message.load("summarizer.jinja").text, + moderator=PassiveModerator(), + model=self.model, + messages=messages_to_summarize, + checkpoint_data=CheckpointData(), + ) + + # get the summarized content and the tokens associated with this content + summary = summarizer_exchange.reply() + summary_checkpoint = summarizer_exchange.checkpoint_data.checkpoints[-1] + + # remove the checkpoints that were summarized from the original exchange + for _ in range(num_messages_to_remove): + exchange.pop_first_message() + + # insert summary as first message/checkpoint + if len(exchange.messages) == 0 or exchange.messages[0].role == "assistant": + summary_message = Message.user(summary.text) + else: + summary_message = Message.assistant(summary.text) + exchange.prepend_checkpointed_message(summary_message, summary_checkpoint.token_count) diff --git a/packages/exchange/src/exchange/moderators/truncate.py b/packages/exchange/src/exchange/moderators/truncate.py new file mode 100644 index 000000000..41115f663 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/truncate.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +from exchange.checkpoint import CheckpointData +from exchange.message import Message +from exchange.moderators import PassiveModerator +from exchange.moderators.base import Moderator + +if TYPE_CHECKING: + from exchange.exchange import Exchange + +# currently this is the point at which we start to truncate, so +# so once we get to this token size the token count will exceed this +# by a little bit. +# TODO: make this configurable for each provider +MAX_TOKENS = 100000 + + +class ContextTruncate(Moderator): + def __init__( + self, + model: Optional[str] = None, + max_tokens: int = MAX_TOKENS, + ) -> None: + self.model = model + self.system_prompt_token_count = 0 + self.max_tokens = max_tokens + self.last_system_prompt = None + + def rewrite(self, exchange: Exchange) -> None: + """Truncate the exchange messages with a FIFO strategy.""" + self._update_system_prompt_token_count(exchange) + + if exchange.checkpoint_data.total_token_count < self.max_tokens: + return + + messages_to_remove = self._get_messages_to_remove(exchange) + for _ in range(len(messages_to_remove)): + exchange.pop_first_message() + + def _update_system_prompt_token_count(self, exchange: Exchange) -> None: + is_different_system_prompt = False + if self.last_system_prompt != exchange.system: + is_different_system_prompt = True + self.last_system_prompt = exchange.system + + if not self.system_prompt_token_count or is_different_system_prompt: + # calculate the system prompt tokens (includes functions etc...) + # we use a placeholder message with one token, which we subtract later + # this ensures compatibility with providers that require a user message + _system_token_exchange = exchange.replace( + messages=[Message.user("a")], + checkpoint_data=CheckpointData(), + moderator=PassiveModerator(), + model=self.model if self.model else exchange.model, + ) + _system_token_exchange.generate() + last_system_prompt_token_count = self.system_prompt_token_count + self.system_prompt_token_count = _system_token_exchange.checkpoint_data.total_token_count - 1 + + exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count + exchange.checkpoint_data.total_token_count += self.system_prompt_token_count + + def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]: + # this keeps all the messages/checkpoints + throwaway_exchange = exchange.replace( + moderator=PassiveModerator(), + ) + + # get the messages that we want to remove + messages_to_remove = [] + while throwaway_exchange.checkpoint_data.total_token_count > self.max_tokens: + _, messages = throwaway_exchange.pop_first_checkpoint() + messages_to_remove.extend(messages) + + while len(throwaway_exchange.messages) > 0 and throwaway_exchange.messages[0].tool_result: + # we would need a corresponding tool use once we resume, so we pop this one off too + # and summarize it as well + _, messages = throwaway_exchange.pop_first_checkpoint() + messages_to_remove.extend(messages) + return messages_to_remove diff --git a/packages/exchange/src/exchange/providers/__init__.py b/packages/exchange/src/exchange/providers/__init__.py new file mode 100644 index 000000000..ac7ed07a0 --- /dev/null +++ b/packages/exchange/src/exchange/providers/__init__.py @@ -0,0 +1,17 @@ +from functools import cache +from typing import Type + +from exchange.providers.anthropic import AnthropicProvider # noqa +from exchange.providers.base import Provider, Usage # noqa +from exchange.providers.databricks import DatabricksProvider # noqa +from exchange.providers.openai import OpenAiProvider # noqa +from exchange.providers.ollama import OllamaProvider # noqa +from exchange.providers.azure import AzureProvider # noqa +from exchange.providers.google import GoogleProvider # noqa + +from exchange.utils import load_plugins + + +@cache +def get_provider(name: str) -> Type[Provider]: + return load_plugins(group="exchange.provider")[name] diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py new file mode 100644 index 000000000..154ec5f79 --- /dev/null +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -0,0 +1,158 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange import Message, Tool +from exchange.content import Text, ToolResult, ToolUse +from exchange.providers.base import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status +from exchange.providers.utils import raise_for_status + +ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class AnthropicProvider(Provider): + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": + url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) + try: + key = os.environ["ANTHROPIC_API_KEY"] + except KeyError: + raise RuntimeError("Failed to get ANTHROPIC_API_KEY from the environment") + client = httpx.Client( + base_url=url, + headers={ + "x-api-key": key, + "content-type": "application/json", + "anthropic-version": "2023-06-01", + }, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: Dict) -> Usage: # noqa: ANN401 + usage = data.get("usage") + input_tokens = usage.get("input_tokens") + output_tokens = usage.get("output_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + @staticmethod + def anthropic_response_to_message(response: Dict) -> Message: + content_blocks = response.get("content", []) + content = [] + for block in content_blocks: + if block["type"] == "text": + content.append(Text(text=block["text"])) + elif block["type"] == "tool_use": + content.append( + ToolUse( + id=block["id"], + name=block["name"], + parameters=block["input"], + ) + ) + return Message(role="assistant", content=content) + + @staticmethod + def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]: + return [ + { + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.parameters, + } + for tool in tools + ] + + @staticmethod + def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]: + messages_spec = [] + # if messages is empty - just make a default + for message in messages: + converted = {"role": message.role} + for content in message.content: + if isinstance(content, Text): + converted["content"] = [{"type": "text", "text": content.text}] + elif isinstance(content, ToolUse): + converted.setdefault("content", []).append( + { + "type": "tool_use", + "id": content.id, + "name": content.name, + "input": content.parameters, + } + ) + elif isinstance(content, ToolResult): + converted.setdefault("content", []).append( + { + "type": "tool_result", + "tool_use_id": content.tool_use_id, + "content": content.output, + } + ) + messages_spec.append(converted) + if len(messages_spec) == 0: + converted = { + "role": "user", + "content": [{"type": "text", "text": "Ignore"}], + } + messages_spec.append(converted) + return messages_spec + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: List[Tool] = [], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + tools_set = set() + unique_tools = [] + for tool in tools: + if tool.name not in tools_set: + unique_tools.append(tool) + tools_set.add(tool.name) + + payload = dict( + system=system, + model=model, + max_tokens=4096, + messages=self.messages_to_anthropic_spec(messages), + tools=self.tools_to_anthropic_spec(tuple(unique_tools)), + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + + response = self._post(payload) + message = self.anthropic_response_to_message(response) + usage = self.get_usage(response) + + return message, usage + + @retry_procedure + def _post(self, payload: dict) -> httpx.Response: + response = self.client.post(ANTHROPIC_HOST, json=payload) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py new file mode 100644 index 000000000..7bacb9ddc --- /dev/null +++ b/packages/exchange/src/exchange/providers/azure.py @@ -0,0 +1,45 @@ +import os +from typing import Type + +import httpx + +from exchange.providers import OpenAiProvider + + +class AzureProvider(OpenAiProvider): + """Provides chat completions for models hosted by the Azure OpenAI Service""" + + def __init__(self, client: httpx.Client) -> None: + super().__init__(client) + + @classmethod + def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": + try: + url = os.environ["AZURE_CHAT_COMPLETIONS_HOST_NAME"] + except KeyError: + raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_HOST_NAME from the environment.") + + try: + deployment_name = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"] + except KeyError: + raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME from the environment.") + + try: + api_version = os.environ["AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"] + except KeyError: + raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION from the environment.") + + try: + key = os.environ["AZURE_CHAT_COMPLETIONS_KEY"] + except KeyError: + raise RuntimeError("Failed to get AZURE_CHAT_COMPLETIONS_KEY from the environment.") + + # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version + url = f"{url}/openai/deployments/{deployment_name}/" + client = httpx.Client( + base_url=url, + headers={"api-key": key, "Content-Type": "application/json"}, + params={"api-version": api_version}, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py new file mode 100644 index 000000000..7b7ff88bc --- /dev/null +++ b/packages/exchange/src/exchange/providers/base.py @@ -0,0 +1,30 @@ +from abc import ABC, abstractmethod +from attrs import define, field +from typing import List, Tuple, Type + +from exchange.message import Message +from exchange.tool import Tool + + +@define(hash=True) +class Usage: + input_tokens: int = field(factory=None) + output_tokens: int = field(default=None) + total_tokens: int = field(default=None) + + +class Provider(ABC): + @classmethod + def from_env(cls: Type["Provider"]) -> "Provider": + return cls() + + @abstractmethod + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + ) -> Tuple[Message, Usage]: + """Generate the next message using the specified model""" + pass diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py new file mode 100644 index 000000000..2a5f53dc8 --- /dev/null +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -0,0 +1,328 @@ +import hashlib +import hmac +import json +import logging +import os +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple, Type +from urllib.parse import quote, urlparse + +import httpx + +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.providers import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status +from exchange.providers.utils import raise_for_status +from exchange.tool import Tool + +SERVICE = "bedrock-runtime" +UTC = timezone.utc + +logger = logging.getLogger(__name__) + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class AwsClient(httpx.Client): + def __init__( + self, + aws_region: str, + aws_access_key: str, + aws_secret_key: str, + aws_session_token: Optional[str] = None, + **kwargs: Dict[str, Any], + ) -> None: + self.region = aws_region + self.host = f"https://{SERVICE}.{aws_region}.amazonaws.com/" + self.access_key = aws_access_key + self.secret_key = aws_secret_key + self.session_token = aws_session_token + super().__init__(base_url=self.host, timeout=600, **kwargs) + + def post(self, path: str, json: Dict, **kwargs: Dict[str, Any]) -> httpx.Response: + signed_headers = self.sign_and_get_headers( + method="POST", + url=path, + payload=json, + service="bedrock", + ) + return super().post(url=path, json=json, headers=signed_headers, **kwargs) + + def sign_and_get_headers( + self, + method: str, + url: str, + payload: dict, + service: str, + ) -> Dict[str, str]: + """ + Sign the request and generate the necessary headers for AWS authentication. + + Args: + method (str): HTTP method (e.g., 'GET', 'POST'). + url (str): The request URL. + payload (dict): The request payload. + service (str): The AWS service name. + region (str): The AWS region. + access_key (str): The AWS access key. + secret_key (str): The AWS secret key. + session_token (Optional[str]): The AWS session token, if any. + + Returns: + Dict[str, str]: The headers required for the request. + """ + + def sign(key: bytes, msg: str) -> bytes: + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: str) -> bytes: + k_date = sign(("AWS4" + key).encode("utf-8"), date_stamp) + k_region = sign(k_date, region_name) + k_service = sign(k_region, service_name) + k_signing = sign(k_service, "aws4_request") + return k_signing + + # Convert payload to JSON string + request_parameters = json.dumps(payload) + + # Create a date for headers and the credential string + t = datetime.now(UTC) + amz_date = t.strftime("%Y%m%dT%H%M%SZ") + date_stamp = t.strftime("%Y%m%d") # Date w/o time, used in credential scope + + # Create canonical URI and headers + parsedurl = urlparse(url) + canonical_uri = quote(parsedurl.path if parsedurl.path else "/", safe="/-_.~") + canonical_headers = f"host:{parsedurl.netloc}\nx-amz-date:{amz_date}\n" + + # Create the list of signed headers. + signed_headers = "host;x-amz-date" + if self.session_token: + canonical_headers += "x-amz-security-token:" + self.session_token + "\n" + signed_headers += ";x-amz-security-token" + + # Create payload hash + payload_hash = hashlib.sha256(request_parameters.encode("utf-8")).hexdigest() + + # Canonical request + canonical_request = f"{method}\n{canonical_uri}\n\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + + # Create the string to sign + algorithm = "AWS4-HMAC-SHA256" + credential_scope = f"{date_stamp}/{self.region}/{service}/aws4_request" + string_to_sign = ( + f"{algorithm}\n{amz_date}\n{credential_scope}\n" + f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}' + ) + + # Create the signing key + signing_key = get_signature_key(self.secret_key, date_stamp, self.region, service) + + # Sign the string_to_sign using the signing key + signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() + + # Add signing information to the request + authorization_header = ( + f"{algorithm} Credential={self.access_key}/{credential_scope}, SignedHeaders={signed_headers}, " + f"Signature={signature}" + ) + + # Headers + headers = { + "Content-Type": "application/json", + "Authorization": authorization_header, + "X-Amz-date": amz_date.encode(), + "x-amz-content-sha256": payload_hash, + } + if self.session_token: + headers["X-Amz-Security-Token"] = self.session_token + + return headers + + +class BedrockProvider(Provider): + def __init__(self, client: AwsClient) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": + aws_region = os.environ.get("AWS_REGION", "us-east-1") + try: + aws_access_key = os.environ["AWS_ACCESS_KEY_ID"] + aws_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"] + aws_session_token = os.environ.get("AWS_SESSION_TOKEN") + except KeyError: + raise RuntimeError("Failed to get AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY from the environment") + + client = AwsClient( + aws_region=aws_region, + aws_access_key=aws_access_key, + aws_secret_key=aws_secret_key, + aws_session_token=aws_session_token, + ) + return cls(client=client) + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + """ + Generate a completion response from the Bedrock gateway. + + Args: + model (str): The model identifier. + system (str): The system prompt or configuration. + messages (List[Message]): A list of messages to be processed by the model. + tools (Tuple[Tool]): A tuple of tools to be used in the completion process. + **kwargs: Additional keyword arguments for inference configuration. + + Returns: + Tuple[Message, Usage]: A tuple containing the response message and usage data. + """ + + inference_config = dict( + temperature=kwargs.pop("temperature", None), + maxTokens=kwargs.pop("max_tokens", None), + stopSequences=kwargs.pop("stop", None), + topP=kwargs.pop("topP", None), + ) + inference_config = {k: v for k, v in inference_config.items() if v is not None} or None + + converted_messages = [self.message_to_bedrock_spec(message) for message in messages] + converted_system = [dict(text=system)] + tool_config = self.tools_to_bedrock_spec(tools) + payload = dict( + system=converted_system, + inferenceConfig=inference_config, + messages=converted_messages, + toolConfig=tool_config, + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + + path = f"{self.client.host}model/{model}/converse" + response = self._post(payload, path) + response_message = response["output"]["message"] + + usage_data = response["usage"] + usage = Usage( + input_tokens=usage_data.get("inputTokens"), + output_tokens=usage_data.get("outputTokens"), + total_tokens=usage_data.get("totalTokens"), + ) + + return self.response_to_message(response_message), usage + + @retry_procedure + def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401 + response = self.client.post(path, json=payload) + return raise_for_status(response).json() + + @staticmethod + def message_to_bedrock_spec(message: Message) -> dict: + bedrock_content = [] + try: + for content in message.content: + if isinstance(content, Text): + bedrock_content.append({"text": content.text}) + elif isinstance(content, ToolUse): + for tool_use in message.tool_use: + bedrock_content.append( + { + "toolUse": { + "toolUseId": tool_use.id, + "name": tool_use.name, + "input": tool_use.parameters, + } + } + ) + elif isinstance(content, ToolResult): + for tool_result in message.tool_result: + # try to parse the output as json + try: + output = json.loads(tool_result.output) + if isinstance(output, dict): + content = [{"json": output}] + else: + content = [{"text": str(output)}] + except json.JSONDecodeError: + content = [{"text": tool_result.output}] + + bedrock_content.append( + { + "toolResult": { + "toolUseId": tool_result.tool_use_id, + "content": content, + **({"status": "error"} if tool_result.is_error else {}), + } + } + ) + return {"role": message.role, "content": bedrock_content} + + except AttributeError: + raise Exception("Invalid message") + + @staticmethod + def response_to_message(response_message: dict) -> Message: + content = [] + if response_message["role"] == "user": + for block in response_message["content"]: + if "text" in block: + content.append(Text(block["text"])) + if "toolResult" in block: + content.append( + ToolResult( + tool_use_id=block["toolResult"]["toolResultId"], + output=block["toolResult"]["content"][0]["json"], + is_error=block["toolResult"].get("status") == "error", + ) + ) + return Message(role="user", content=content) + elif response_message["role"] == "assistant": + for block in response_message["content"]: + if "text" in block: + content.append(Text(block["text"])) + if "toolUse" in block: + content.append( + ToolUse( + id=block["toolUse"]["toolUseId"], + name=block["toolUse"]["name"], + parameters=block["toolUse"]["input"], + ) + ) + return Message(role="assistant", content=content) + raise Exception("Invalid response") + + @staticmethod + def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: + if len(tools) == 0: + return None # API requires a non-empty tool config or None + tools_added = set() + tool_config_list = [] + for tool in tools: + if tool.name in tools_added: + logging.warning(f"Tool {tool.name} already added to tool config. Skipping.") + continue + tool_config_list.append( + { + "toolSpec": { + "name": tool.name, + "description": tool.description, + "inputSchema": {"json": tool.parameters}, + } + } + ) + tools_added.add(tool.name) + tool_config = {"tools": tool_config_list} + return tool_config diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py new file mode 100644 index 000000000..84dc7515c --- /dev/null +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -0,0 +1,102 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange.message import Message +from exchange.providers.base import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import ( + messages_to_openai_spec, + openai_response_to_message, + tools_to_openai_spec, +) +from exchange.tool import Tool + + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class DatabricksProvider(Provider): + """Provides chat completions for models on Databricks serving endpoints + + Models are expected to follow the llm/v1/chat "task". This includes support + for foundation and external model endpoints + https://docs.databricks.com/en/machine-learning/model-serving/create-foundation-model-endpoints.html#create-generative-ai-model-serving-endpoints + """ + + def __init__(self, client: httpx.Client) -> None: + super().__init__() + self.client = client + + @classmethod + def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": + try: + url = os.environ["DATABRICKS_HOST"] + except KeyError: + raise RuntimeError( + "Failed to get DATABRICKS_HOST from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + ) + try: + key = os.environ["DATABRICKS_TOKEN"] + except KeyError: + raise RuntimeError( + "Failed to get DATABRICKS_TOKEN from the environment. See https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + ) + client = httpx.Client( + base_url=url, + auth=("token", key), + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: dict) -> Usage: + usage = data.pop("usage") + input_tokens = usage.get("prompt_tokens") + output_tokens = usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + payload = dict( + messages=[ + {"role": "system", "content": system}, + *messages_to_openai_spec(messages), + ], + tools=tools_to_openai_spec(tools) if tools else [], + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(model, payload) + message = openai_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, model: str, payload: dict) -> httpx.Response: + response = self.client.post( + f"serving-endpoints/{model}/invocations", + json=payload, + ) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py new file mode 100644 index 000000000..426aa79d5 --- /dev/null +++ b/packages/exchange/src/exchange/providers/google.py @@ -0,0 +1,154 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange import Message, Tool +from exchange.content import Text, ToolResult, ToolUse +from exchange.providers.base import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status +from exchange.providers.utils import raise_for_status + +GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class GoogleProvider(Provider): + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": + url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) + try: + key = os.environ["GOOGLE_API_KEY"] + except KeyError: + raise RuntimeError( + "Failed to get GOOGLE_API_KEY from the environment, see https://ai.google.dev/gemini-api/docs/api-key" + ) + + client = httpx.Client( + base_url=url, + headers={ + "Content-Type": "application/json", + }, + params={"key": key}, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: Dict) -> Usage: # noqa: ANN401 + usage = data.get("usageMetadata") + input_tokens = usage.get("promptTokenCount") + output_tokens = usage.get("candidatesTokenCount") + total_tokens = usage.get("totalTokenCount") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + @staticmethod + def google_response_to_message(response: Dict) -> Message: + candidates = response.get("candidates", []) + if candidates: + # Only use first candidate for now + candidate = candidates[0] + content_parts = candidate.get("content", {}).get("parts", []) + content = [] + for part in content_parts: + if "text" in part: + content.append(Text(text=part["text"])) + elif "functionCall" in part: + content.append( + ToolUse( + id=part["functionCall"].get("name", ""), + name=part["functionCall"].get("name", ""), + parameters=part["functionCall"].get("args", {}), + ) + ) + return Message(role="assistant", content=content) + + # If no valid candidates were found, return an empty message + return Message(role="assistant", content=[]) + + @staticmethod + def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: + if not tools: + return {} + converted_tools = [] + for tool in tools: + converted_tool: Dict[str, Any] = { + "name": tool.name, + "description": tool.description or "", + } + if tool.parameters["properties"]: + converted_tool["parameters"] = tool.parameters + converted_tools.append(converted_tool) + return {"functionDeclarations": converted_tools} + + @staticmethod + def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: + messages_spec = [] + for message in messages: + role = "user" if message.role == "user" else "model" + converted = {"role": role, "parts": []} + for content in message.content: + if isinstance(content, Text): + converted["parts"].append({"text": content.text}) + elif isinstance(content, ToolUse): + converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}}) + elif isinstance(content, ToolResult): + converted["parts"].append( + {"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} + ) + messages_spec.append(converted) + + if not messages_spec: + messages_spec.append({"role": "user", "parts": [{"text": "Ignore"}]}) + + return messages_spec + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: List[Tool] = [], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + tools_set = set() + unique_tools = [] + for tool in tools: + if tool.name not in tools_set: + unique_tools.append(tool) + tools_set.add(tool.name) + + payload = dict( + system_instruction={"parts": [{"text": system}]}, + contents=self.messages_to_google_spec(messages), + tools=self.tools_to_google_spec(tuple(unique_tools)), + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(payload, model) + message = self.google_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, payload: dict, model: str) -> httpx.Response: + response = self.client.post("models/" + model + ":generateContent", json=payload) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/ollama.py b/packages/exchange/src/exchange/providers/ollama.py new file mode 100644 index 000000000..acad89d9f --- /dev/null +++ b/packages/exchange/src/exchange/providers/ollama.py @@ -0,0 +1,45 @@ +import os +from typing import Type + +import httpx + +from exchange.providers.openai import OpenAiProvider + +OLLAMA_HOST = "http://localhost:11434/" +OLLAMA_MODEL = "mistral-nemo" + + +class OllamaProvider(OpenAiProvider): + """Provides chat completions for models hosted by Ollama""" + + __doc__ += f""" + +Here's an example profile configuration to try: + + ollama: + provider: ollama + processor: {OLLAMA_MODEL} + accelerator: {OLLAMA_MODEL} + moderator: passive + toolkits: + - name: developer + requires: {{}} +""" + + def __init__(self, client: httpx.Client) -> None: + print("PLEASE NOTE: the ollama provider is experimental, use with care") + super().__init__(client) + + @classmethod + def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider": + ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST) + timeout = httpx.Timeout(60 * 10) + + # from_env is expected to fail if required ENV variables are not + # available. Since this provider can run with defaults, we substitute + # an Ollama health check (GET /) to determine if the service is ok. + httpx.get(ollama_url, timeout=timeout) + + # When served by Ollama, the OpenAI API is available at the path "v1/". + client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout) + return cls(client) diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py new file mode 100644 index 000000000..dbd293b47 --- /dev/null +++ b/packages/exchange/src/exchange/providers/openai.py @@ -0,0 +1,101 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange.message import Message +from exchange.providers.base import Provider, Usage +from exchange.providers.utils import ( + messages_to_openai_spec, + openai_response_to_message, + openai_single_message_context_length_exceeded, + raise_for_status, + tools_to_openai_spec, +) +from exchange.tool import Tool +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status + +OPENAI_HOST = "https://api.openai.com/" + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class OpenAiProvider(Provider): + """Provides chat completions for models hosted directly by OpenAI""" + + def __init__(self, client: httpx.Client) -> None: + super().__init__() + self.client = client + + @classmethod + def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": + url = os.environ.get("OPENAI_HOST", OPENAI_HOST) + try: + key = os.environ["OPENAI_API_KEY"] + except KeyError: + raise RuntimeError( + "Failed to get OPENAI_API_KEY from the environment, see https://platform.openai.com/docs/api-reference/api-keys" + ) + client = httpx.Client( + base_url=url + "v1/", + auth=("Bearer", key), + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: dict) -> Usage: + usage = data.pop("usage") + input_tokens = usage.get("prompt_tokens") + output_tokens = usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}] + payload = dict( + messages=system_message + messages_to_openai_spec(messages), + model=model, + tools=tools_to_openai_spec(tools) if tools else [], + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(payload) + + # Check for context_length_exceeded error for single, long input message + if "error" in response and len(messages) == 1: + openai_single_message_context_length_exceeded(response["error"]) + + message = openai_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, payload: dict) -> dict: + # Note: While OpenAI and Ollama mount the API under "v1", this is + # conventional and not a strict requirement. For example, Azure OpenAI + # mounts the API under the deployment name, and "v1" is not in the URL. + # See https://github.com/openai/openai-openapi/blob/master/openapi.yaml + response = self.client.post("chat/completions", json=payload) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py new file mode 100644 index 000000000..4be7ac31e --- /dev/null +++ b/packages/exchange/src/exchange/providers/utils.py @@ -0,0 +1,185 @@ +import base64 +import json +import re +from typing import Any, Callable, Dict, List, Optional, Tuple + +import httpx +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.tool import Tool +from tenacity import retry_if_exception + + +def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable: + codes = codes or [] + + def predicate(exc: Exception) -> bool: + if isinstance(exc, httpx.HTTPStatusError): + if exc.response.status_code in codes: + return True + if above and exc.response.status_code >= above: + return True + return False + + return retry_if_exception(predicate) + + +def raise_for_status(response: httpx.Response) -> httpx.Response: + """Raise with reason text.""" + try: + response.raise_for_status() + return response + except httpx.HTTPStatusError as e: + response.read() + if response.text: + raise httpx.HTTPStatusError(f"{e}\n{response.text}", request=e.request, response=e.response) + else: + raise e + + +def encode_image(image_path: str) -> str: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]: + messages_spec = [] + for message in messages: + converted = {"role": message.role} + output = [] + for content in message.content: + if isinstance(content, Text): + converted["content"] = content.text + elif isinstance(content, ToolUse): + sanitized_name = re.sub(r"[^a-zA-Z0-9_-]", "_", content.name) + converted.setdefault("tool_calls", []).append( + { + "id": content.id, + "type": "function", + "function": { + "name": sanitized_name, + "arguments": json.dumps(content.parameters), + }, + } + ) + elif isinstance(content, ToolResult): + if content.output.startswith('"image:'): + image_path = content.output.replace('"image:', "").replace('"', "") + output.append( + { + "role": "tool", + "content": [ + { + "type": "text", + "text": "This tool result included an image that is uploaded in the next message.", + }, + ], + "tool_call_id": content.tool_use_id, + } + ) + # Note: it is possible to only do this when message == messages[-1] + # but it doesn't seem to hurt too much with tokens to keep this. + output.append( + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"}, + } + ], + } + ) + + else: + output.append( + { + "role": "tool", + "content": content.output, + "tool_call_id": content.tool_use_id, + } + ) + + if "content" in converted or "tool_calls" in converted: + output = [converted] + output + messages_spec.extend(output) + return messages_spec + + +def tools_to_openai_spec(tools: Tuple[Tool]) -> Dict[str, Any]: + tools_names = set() + result = [] + for tool in tools: + if tool.name in tools_names: + # we should never allow duplicate tools + raise ValueError(f"Duplicate tool name: {tool.name}") + result.append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + ) + tools_names.add(tool.name) + return result + + +def openai_response_to_message(response: dict) -> Message: + original = response["choices"][0]["message"] + content = [] + text = original.get("content") + if text: + content.append(Text(text=text)) + + tool_calls = original.get("tool_calls") + if tool_calls: + for tool_call in tool_calls: + try: + function_name = tool_call["function"]["name"] + # We occasionally see the model generate an invalid function name + # sending this back to openai raises a validation error + if not re.match(r"^[a-zA-Z0-9_-]+$", function_name): + content.append( + ToolUse( + id=tool_call["id"], + name=function_name, + parameters=tool_call["function"]["arguments"], + is_error=True, + error_message=f"The provided function name '{function_name}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", # noqa: E501 + ) + ) + else: + content.append( + ToolUse( + id=tool_call["id"], + name=function_name, + parameters=json.loads(tool_call["function"]["arguments"]), + ) + ) + except json.JSONDecodeError: + content.append( + ToolUse( + id=tool_call["id"], + name=tool_call["function"]["name"], + parameters=tool_call["function"]["arguments"], + is_error=True, + error_message=f"Could not interpret tool use parameters for id {tool_call['id']}: {tool_call['function']['arguments']}", # noqa: E501 + ) + ) + + return Message(role="assistant", content=content) + + +def openai_single_message_context_length_exceeded(error_dict: dict) -> None: + code = error_dict.get("code") + if code == "context_length_exceeded" or code == "string_above_max_length": + raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") + + +class InitialMessageTooLargeError(Exception): + """Custom error raised when the first input message in an exchange is too large.""" + + pass diff --git a/packages/exchange/src/exchange/token_usage_collector.py b/packages/exchange/src/exchange/token_usage_collector.py new file mode 100644 index 000000000..8f0801062 --- /dev/null +++ b/packages/exchange/src/exchange/token_usage_collector.py @@ -0,0 +1,27 @@ +from collections import defaultdict +from typing import Dict + +from exchange.providers.base import Usage + + +class _TokenUsageCollector: + def __init__(self) -> None: + self.usage_data = [] + + def collect(self, model: str, usage: Usage) -> None: + self.usage_data.append((model, usage)) + + def get_token_usage_group_by_model(self) -> Dict[str, Usage]: + usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0)) + for model, usage in self.usage_data: + usage_by_model = usage_group_by_model[model] + if usage is not None and usage.input_tokens is not None: + usage_by_model.input_tokens += usage.input_tokens + if usage is not None and usage.output_tokens is not None: + usage_by_model.output_tokens += usage.output_tokens + if usage is not None and usage.total_tokens is not None: + usage_by_model.total_tokens += usage.total_tokens + return usage_group_by_model + + +_token_usage_collector = _TokenUsageCollector() diff --git a/packages/exchange/src/exchange/tool.py b/packages/exchange/src/exchange/tool.py new file mode 100644 index 000000000..4ce9e7c50 --- /dev/null +++ b/packages/exchange/src/exchange/tool.py @@ -0,0 +1,55 @@ +import inspect +from typing import Any, Callable, Type + +from attrs import define + +from exchange.utils import json_schema, parse_docstring + + +@define +class Tool: + """A tool that can be used by a model. + + Attributes: + name (str): The name of the tool + description (str): A description of what the tool does + parameters dict[str, Any]: A json schema of the function signature + function (Callable): The python function that powers the tool + """ + + name: str + description: str + parameters: dict[str, Any] + function: Callable + + @classmethod + def from_function(cls: Type["Tool"], func: Any) -> "Tool": # noqa: ANN401 + """Create a tool instance from a function and its docstring + + The function must have a docstring - we require it to load the description + and parameter descriptions. This also supports a class instance with a __call__ + method. + """ + if inspect.isfunction(func) or inspect.ismethod(func): + name = func.__name__ + else: + name = func.__class__.__name__.lower() + func = func.__call__ + + description, param_descriptions = parse_docstring(func) + schema = json_schema(func) + + # Set the 'description' field of the schema to the arg's docstring description + for arg in param_descriptions: + arg_name, arg_description = arg["name"], arg["description"] + + if arg_name not in schema["properties"]: + raise ValueError(f"Argument {arg_name} found in docstring but not in schema") + schema["properties"][arg_name]["description"] = arg_description + + return cls( + name=name, + description=description, + parameters=schema, + function=func, + ) diff --git a/packages/exchange/src/exchange/utils.py b/packages/exchange/src/exchange/utils.py new file mode 100644 index 000000000..04d5ffa18 --- /dev/null +++ b/packages/exchange/src/exchange/utils.py @@ -0,0 +1,155 @@ +import inspect +import uuid +from importlib.metadata import entry_points +from typing import Any, Callable, Dict, List, Type, get_args, get_origin + +from griffe import ( + Docstring, + DocstringSection, + DocstringSectionParameters, + DocstringSectionText, +) + + +def create_object_id(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex[:24]}" + + +def compact(content: str) -> str: + """Replace any amount of whitespace with a single space""" + return " ".join(content.split()) + + +def parse_docstring(func: Callable) -> tuple[str, List[Dict]]: + """Get description and parameters from function docstring""" + function_args = list(inspect.signature(func).parameters.keys()) + text = str(func.__doc__) + docstring = Docstring(text) + + for style in ["google", "numpy", "sphinx"]: + parsed = docstring.parse(style) + + if not _check_section_is_present(parsed, DocstringSectionText): + continue + + if function_args and not _check_section_is_present(parsed, DocstringSectionParameters): + continue + break + else: # if we did not find a valid style in the for loop + raise ValueError( + f"Attempted to load from a function {func.__name__} with an invalid docstring. Parameter docs are required if the function has parameters. https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings" # noqa: E501 + ) + + description = None + parameters = [] + + for section in parsed: + if isinstance(section, DocstringSectionText): + description = compact(section.value) + elif isinstance(section, DocstringSectionParameters): + parameters = [arg.as_dict() for arg in section.value] + + docstring_args = [d["name"] for d in parameters] + if description is None: + raise ValueError("Docstring must include a description.") + + if not docstring_args == function_args: + extra_docstring_args = ", ".join(sorted(set(docstring_args) - set(function_args))) + extra_function_args = ", ".join(sorted(set(function_args) - set(docstring_args))) + if extra_docstring_args and extra_function_args: + raise ValueError( + f"Docstring args must match function args: docstring had extra {extra_docstring_args}; function had extra {extra_function_args}" # noqa: E501 + ) + elif extra_function_args: + raise ValueError(f"Docstring args must match function args: function had extra {extra_function_args}") + elif extra_docstring_args: + raise ValueError(f"Docstring args must match function args: docstring had extra {extra_docstring_args}") + else: + raise ValueError("Docstring args must match function args") + + return description, parameters + + +def _check_section_is_present( + parsed_docstring: List[DocstringSection], section_type: Type[DocstringSectionText] +) -> bool: + for section in parsed_docstring: + if isinstance(section, section_type): + return True + return False + + +def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401 + """Get the json schema for a function""" + signature = inspect.signature(func) + parameters = signature.parameters + + schema = { + "type": "object", + "properties": {}, + "required": [], + } + + for param_name, param in parameters.items(): + param_schema = {} + + if param.annotation is not inspect.Parameter.empty: + param_schema = _map_type_to_schema(param.annotation) + + if param.default is not inspect.Parameter.empty: + param_schema["default"] = param.default + + schema["properties"][param_name] = param_schema + + if param.default is inspect.Parameter.empty: + schema["required"].append(param_name) + + return schema + + +def _map_type_to_schema(py_type: Type) -> Dict[str, Any]: # noqa: ANN401 + origin = get_origin(py_type) + args = get_args(py_type) + + if origin is list or origin is tuple: + return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)} + elif origin is dict: + return { + "type": "object", + "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any), + } + elif py_type is int: + return {"type": "integer"} + elif py_type is bool: + return {"type": "boolean"} + elif py_type is float: + return {"type": "number"} + elif py_type is str: + return {"type": "string"} + else: + return {"type": "string"} + + +def load_plugins(group: str) -> dict: + """ + Load plugins based on a specified entry point group. + + This function iterates through all entry points registered under a specified group + + Args: + group (str): The entry point group to load plugins from. This should match the group specified + in the package setup where plugins are defined. + + Returns: + dict: A dictionary where each key is the entry point name, and the value is the loaded plugin object. + + Raises: + Exception: Propagates exceptions raised by entry point loading, which might occur if a plugin + is not found or if there are issues with the plugin's code. + """ + plugins = {} + # Access all entry points for the specified group and load each. + for entrypoint in entry_points(group=group): + plugin = entrypoint.load() # Load the plugin. + plugins[entrypoint.name] = plugin # Store the loaded plugin in the dictionary. + return plugins diff --git a/packages/exchange/tests/.ruff.toml b/packages/exchange/tests/.ruff.toml new file mode 100644 index 000000000..cddf42337 --- /dev/null +++ b/packages/exchange/tests/.ruff.toml @@ -0,0 +1,2 @@ +lint.select = ["E", "W", "F", "N"] +line-length = 120 \ No newline at end of file diff --git a/packages/exchange/tests/__init__.py b/packages/exchange/tests/__init__.py new file mode 100644 index 000000000..c2b89ac6d --- /dev/null +++ b/packages/exchange/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for exchange.""" diff --git a/packages/exchange/tests/conftest.py b/packages/exchange/tests/conftest.py new file mode 100644 index 000000000..684a446d7 --- /dev/null +++ b/packages/exchange/tests/conftest.py @@ -0,0 +1,36 @@ +import pytest + +from exchange.providers.base import Usage + + +@pytest.fixture +def dummy_tool(): + def _dummy_tool() -> str: + """An example tool""" + return "dummy response" + + return _dummy_tool + + +@pytest.fixture +def usage_factory(): + def _create_usage(input_tokens=100, output_tokens=200, total_tokens=300): + return Usage(input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens) + + return _create_usage + + +def read_file(filename: str) -> str: + """ + Read the contents of the file. + + Args: + filename (str): The path to the file, which can be relative or + absolute. If it is a plain filename, it is assumed to be in the + current working directory. + + Returns: + str: The contents of the file. + """ + assert filename == "test.txt" + return "hello exchange" diff --git a/packages/exchange/tests/providers/__init__.py b/packages/exchange/tests/providers/__init__.py new file mode 100644 index 000000000..4e13a800d --- /dev/null +++ b/packages/exchange/tests/providers/__init__.py @@ -0,0 +1 @@ +"""Tests for chat completion providers.""" diff --git a/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml b/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml new file mode 100644 index 000000000..3ac8a4fc0 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml @@ -0,0 +1,68 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + api-key: + - test_azure_api_key + connection: + - keep-alive + content-length: + - '139' + content-type: + - application/json + host: + - test.openai.azure.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview + response: + body: + string: '{"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":"stop","index":0,"logprobs":null,"message":{"content":"Hello! + How can I assist you today?","role":"assistant"}}],"created":1727230065,"id":"chatcmpl-ABBjN3AoYlxkP7Vg2lBvUhYeA6j5K","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":18,"total_tokens":27}} + + ' + headers: + Cache-Control: + - no-cache, must-revalidate + Content-Length: + - '825' + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 02:07:45 GMT + Set-Cookie: test_set_cookie + Strict-Transport-Security: + - max-age=31536000; includeSubDomains; preload + access-control-allow-origin: + - '*' + apim-request-id: + - 82e66ef8-ac07-4a43-b60f-9aecec1d8c81 + azureml-model-session: + - d145-20240919052126 + openai-organization: test_openai_org_key + x-accel-buffering: + - 'no' + x-content-type-options: + - nosniff + x-ms-client-request-id: + - 82e66ef8-ac07-4a43-b60f-9aecec1d8c81 + x-ms-rai-invoked: + - 'true' + x-ms-region: + - Switzerland North + x-ratelimit-remaining-requests: + - '79' + x-ratelimit-remaining-tokens: + - '79984' + x-request-id: + - 38db9001-8b16-4efe-84c9-620e10f18c3c + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml b/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml new file mode 100644 index 000000000..9da479790 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml @@ -0,0 +1,74 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. + Expect to need to read a file using read_file."}, {"role": "user", "content": + "What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools": + [{"type": "function", "function": {"name": "read_file", "description": "Read + the contents of the file.", "parameters": {"type": "object", "properties": {"filename": + {"type": "string", "description": "The path to the file, which can be relative + or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent + working directory."}}, "required": ["filename"]}}}]}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + api-key: + - test_azure_api_key + connection: + - keep-alive + content-length: + - '608' + content-type: + - application/json + host: + - test.openai.azure.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview + response: + body: + string: '{"choices":[{"content_filter_results":{},"finish_reason":"tool_calls","index":0,"logprobs":null,"message":{"content":null,"role":"assistant","tool_calls":[{"function":{"arguments":"{\n \"filename\": + \"test.txt\"\n}","name":"read_file"},"id":"call_a47abadDxlGKIWjvYYvGVAHa","type":"function"}]}}],"created":1727256650,"id":"chatcmpl-ABIeABbq5WVCq0e0AriGFaYDSih3P","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":16,"prompt_tokens":109,"total_tokens":125}} + + ' + headers: + Cache-Control: + - no-cache, must-revalidate + Content-Length: + - '769' + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 09:30:50 GMT + Set-Cookie: test_set_cookie + Strict-Transport-Security: + - max-age=31536000; includeSubDomains; preload + access-control-allow-origin: + - '*' + apim-request-id: + - 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339 + azureml-model-session: + - d145-20240919052126 + openai-organization: test_openai_org_key + x-accel-buffering: + - 'no' + x-content-type-options: + - nosniff + x-ms-client-request-id: + - 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339 + x-ms-rai-invoked: + - 'true' + x-ms-region: + - Switzerland North + x-ratelimit-remaining-requests: + - '79' + x-ratelimit-remaining-tokens: + - '79824' + x-request-id: + - 401bd803-b790-47b7-b098-98708d44f060 + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml b/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml new file mode 100644 index 000000000..88bc206ff --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml @@ -0,0 +1,68 @@ +interactions: +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: GET + uri: http://localhost:11434/ + response: + body: + string: Ollama is running + headers: + Content-Length: + - '17' + Content-Type: + - text/plain; charset=utf-8 + Date: + - Sun, 22 Sep 2024 23:40:13 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}], "model": "mistral-nemo"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '140' + content-type: + - application/json + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: POST + uri: http://localhost:11434/v1/chat/completions + response: + body: + string: "{\"id\":\"chatcmpl-429\",\"object\":\"chat.completion\",\"created\":1727048416,\"model\":\"mistral-nemo\",\"system_fingerprint\":\"fp_ollama\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Hello! + I'm here to help. How can I assist you today? Let's chat. \U0001F60A\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":23,\"total_tokens\":33}}\n" + headers: + Content-Length: + - '356' + Content-Type: + - application/json + Date: + - Sun, 22 Sep 2024 23:40:16 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml b/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml new file mode 100644 index 000000000..7271bf227 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml @@ -0,0 +1,75 @@ +interactions: +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: GET + uri: http://localhost:11434/ + response: + body: + string: Ollama is running + headers: + Content-Length: + - '17' + Content-Type: + - text/plain; charset=utf-8 + Date: + - Wed, 25 Sep 2024 09:23:08 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. + Expect to need to read a file using read_file."}, {"role": "user", "content": + "What are the contents of this file? test.txt"}], "model": "mistral-nemo", "tools": + [{"type": "function", "function": {"name": "read_file", "description": "Read + the contents of the file.", "parameters": {"type": "object", "properties": {"filename": + {"type": "string", "description": "The path to the file, which can be relative + or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent + working directory."}}, "required": ["filename"]}}}]}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '609' + content-type: + - application/json + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: POST + uri: http://localhost:11434/v1/chat/completions + response: + body: + string: '{"id":"chatcmpl-245","object":"chat.completion","created":1727256190,"model":"mistral-nemo","system_fingerprint":"fp_ollama","choices":[{"index":0,"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_z6fgu3z3","type":"function","function":{"name":"read_file","arguments":"{\"filename\":\"test.txt\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":112,"completion_tokens":21,"total_tokens":133}} + + ' + headers: + Content-Length: + - '425' + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 09:23:10 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml b/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml new file mode 100644 index 000000000..1a92eb36b --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml @@ -0,0 +1,80 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + authorization: + - Bearer test_openai_api_key + connection: + - keep-alive + content-length: + - '139' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"id\": \"chatcmpl-AAQTYi3DXJnltAfd5sUH1Wnzh69t3\",\n \"object\": + \"chat.completion\",\n \"created\": 1727048416,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": \"Hello! How can I assist you today?\",\n + \ \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\": + \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 18,\n \"completion_tokens\": + 9,\n \"total_tokens\": 27,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": + 0\n }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8c762399feb55739-SYD + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Sun, 22 Sep 2024 23:40:17 GMT + Server: + - cloudflare + Set-Cookie: test_set_cookie + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + content-length: + - '593' + openai-organization: test_openai_org_key + openai-processing-ms: + - '560' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=15552000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '200000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '199973' + x-ratelimit-reset-requests: + - 8.64s + x-ratelimit-reset-tokens: + - 8ms + x-request-id: + - req_22e26c840219cde3152eaba1ce89483b + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml b/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml new file mode 100644 index 000000000..30496fcb8 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml @@ -0,0 +1,90 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. + Expect to need to read a file using read_file."}, {"role": "user", "content": + "What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools": + [{"type": "function", "function": {"name": "read_file", "description": "Read + the contents of the file.", "parameters": {"type": "object", "properties": {"filename": + {"type": "string", "description": "The path to the file, which can be relative + or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent + working directory."}}, "required": ["filename"]}}}]}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + authorization: + - Bearer test_openai_api_key + connection: + - keep-alive + content-length: + - '608' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"id\": \"chatcmpl-ABIV2aZWVKQ774RAQ8KHYdNwkI5N7\",\n \"object\": + \"chat.completion\",\n \"created\": 1727256084,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n + \ \"id\": \"call_xXYlw4A7Ud1qtCopuK5gEJrP\",\n \"type\": + \"function\",\n \"function\": {\n \"name\": \"read_file\",\n + \ \"arguments\": \"{\\\"filename\\\":\\\"test.txt\\\"}\"\n }\n + \ }\n ],\n \"refusal\": null\n },\n \"logprobs\": + null,\n \"finish_reason\": \"tool_calls\"\n }\n ],\n \"usage\": + {\n \"prompt_tokens\": 107,\n \"completion_tokens\": 15,\n \"total_tokens\": + 122,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": 0\n + \ }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8c89f19fed997e43-SYD + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 09:21:25 GMT + Server: + - cloudflare + Set-Cookie: test_set_cookie + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + content-length: + - '844' + openai-organization: test_openai_org_key + openai-processing-ms: + - '266' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '200000' + x-ratelimit-remaining-requests: + - '9991' + x-ratelimit-remaining-tokens: + - '199952' + x-ratelimit-reset-requests: + - 1m9.486s + x-ratelimit-reset-tokens: + - 14ms + x-request-id: + - req_ff6b5d65c24f40e1faaf049c175e718d + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml b/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml new file mode 100644 index 000000000..1b9691d29 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml @@ -0,0 +1,86 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What does the first entry in the menu say?"}, {"role": + "assistant", "tool_calls": [{"id": "xyz", "type": "function", "function": {"name": + "screenshot", "arguments": "{}"}}]}, {"role": "tool", "content": [{"type": "text", + "text": "This tool result included an image that is uploaded in the next message."}], + "tool_call_id": "xyz"}, {"role": "user", "content": [{"type": "image_url", "image_url": + {"url": ""}}]}], + "model": "gpt-4o-mini"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + authorization: + - Bearer test_openai_api_key + connection: + - keep-alive + content-length: + - '78932' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"id\": \"chatcmpl-ABIA0YzOHlhqb02K8Ay4Jwsw6xOpk\",\n \"object\": + \"chat.completion\",\n \"created\": 1727254780,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": \"The first entry in the menu says \\\"Ask + Goose.\\\"\",\n \"refusal\": null\n },\n \"logprobs\": null,\n + \ \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": + 14230,\n \"completion_tokens\": 11,\n \"total_tokens\": 14241,\n \"completion_tokens_details\": + {\n \"reasoning_tokens\": 0\n }\n },\n \"system_fingerprint\": \"fp_e9627b5346\"\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8c89d1c45d98a883-SYD + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 08:59:41 GMT + Server: + - cloudflare + Set-Cookie: test_set_cookie + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + content-length: + - '613' + openai-organization: test_openai_org_key + openai-processing-ms: + - '1289' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '200000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '199177' + x-ratelimit-reset-requests: + - 8.64s + x-ratelimit-reset-tokens: + - 246ms + x-request-id: + - req_9503b21e31db78c4ebd2b71b304cea72 + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/conftest.py b/packages/exchange/tests/providers/conftest.py new file mode 100644 index 000000000..2b35958fb --- /dev/null +++ b/packages/exchange/tests/providers/conftest.py @@ -0,0 +1,129 @@ +import os +import re +from typing import Type, Tuple + +import pytest + +from exchange import Message, ToolUse, ToolResult, Tool +from exchange.providers import Usage, Provider +from tests.conftest import read_file + +OPENAI_API_KEY = "test_openai_api_key" +OPENAI_ORG_ID = "test_openai_org_key" +OPENAI_PROJECT_ID = "test_openai_project_id" + + +@pytest.fixture +def default_openai_env(monkeypatch): + """ + This fixture prevents OpenAIProvider.from_env() from erring on missing + environment variables. + + When running VCR tests for the first time or after deleting a cassette + recording, set required environment variables, so that real requests don't + fail. Subsequent runs use the recorded data, so don't need them. + """ + if "OPENAI_API_KEY" not in os.environ: + monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY) + + +AZURE_ENDPOINT = "https://test.openai.azure.com" +AZURE_DEPLOYMENT_NAME = "test-azure-deployment" +AZURE_API_VERSION = "2024-05-01-preview" +AZURE_API_KEY = "test_azure_api_key" + + +@pytest.fixture +def default_azure_env(monkeypatch): + """ + This fixture prevents AzureProvider.from_env() from erring on missing + environment variables. + + When running VCR tests for the first time or after deleting a cassette + recording, set required environment variables, so that real requests don't + fail. Subsequent runs use the recorded data, so don't need them. + """ + if "AZURE_CHAT_COMPLETIONS_HOST_NAME" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_HOST_NAME", AZURE_ENDPOINT) + if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", AZURE_DEPLOYMENT_NAME) + if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", AZURE_API_VERSION) + if "AZURE_CHAT_COMPLETIONS_KEY" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY) + + +@pytest.fixture(scope="module") +def vcr_config(): + """ + This scrubs sensitive data and gunzips bodies when in recording mode. + + Without this, you would leak cookies and auth tokens in the cassettes. + Also, depending on the request, some responses would be binary encoded + while others plain json. This ensures all bodies are human-readable. + """ + return { + "decode_compressed_response": True, + "filter_headers": [ + ("authorization", "Bearer " + OPENAI_API_KEY), + ("openai-organization", OPENAI_ORG_ID), + ("openai-project", OPENAI_PROJECT_ID), + ("cookie", None), + ], + "before_record_request": scrub_request_url, + "before_record_response": scrub_response_headers, + } + + +def scrub_request_url(request): + """ + This scrubs sensitive request data in provider-specific way. Note that headers + are case-sensitive! + """ + if "openai.azure.com" in request.uri: + request.uri = re.sub(r"https://[^/]+", AZURE_ENDPOINT, request.uri) + request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri) + request.headers["host"] = AZURE_ENDPOINT.replace("https://", "") + request.headers["api-key"] = AZURE_API_KEY + + return request + + +def scrub_response_headers(response): + """ + This scrubs sensitive response headers. Note they are case-sensitive! + """ + response["headers"]["openai-organization"] = OPENAI_ORG_ID + response["headers"]["Set-Cookie"] = "test_set_cookie" + return response + + +def complete(provider_cls: Type[Provider], model: str) -> Tuple[Message, Usage]: + provider = provider_cls.from_env() + system = "You are a helpful assistant." + messages = [Message.user("Hello")] + return provider.complete(model=model, system=system, messages=messages, tools=None) + + +def tools(provider_cls: Type[Provider], model: str) -> Tuple[Message, Usage]: + provider = provider_cls.from_env() + system = "You are a helpful assistant. Expect to need to read a file using read_file." + messages = [Message.user("What are the contents of this file? test.txt")] + return provider.complete(model=model, system=system, messages=messages, tools=(Tool.from_function(read_file),)) + + +def vision(provider_cls: Type[Provider], model: str) -> Tuple[Message, Usage]: + provider = provider_cls.from_env() + system = "You are a helpful assistant." + messages = [ + Message.user("What does the first entry in the menu say?"), + Message( + role="assistant", + content=[ToolUse(id="xyz", name="screenshot", parameters={})], + ), + Message( + role="user", + content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')], + ), + ] + return provider.complete(model=model, system=system, messages=messages, tools=None) diff --git a/packages/exchange/tests/providers/test_anthropic.py b/packages/exchange/tests/providers/test_anthropic.py new file mode 100644 index 000000000..a6f5bc689 --- /dev/null +++ b/packages/exchange/tests/providers/test_anthropic.py @@ -0,0 +1,174 @@ +import os +from unittest.mock import patch + +import httpx +import pytest +from exchange import Message, Text +from exchange.content import ToolResult, ToolUse +from exchange.providers.anthropic import AnthropicProvider +from exchange.tool import Tool + + +def example_fn(param: str) -> None: + """ + Testing function. + + Args: + param (str): Description of param1 + """ + pass + + +@pytest.fixture +@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_api_key"}) +def anthropic_provider(): + return AnthropicProvider.from_env() + + +def test_anthropic_response_to_text_message() -> None: + response = { + "content": [{"type": "text", "text": "Hello from Claude!"}], + } + message = AnthropicProvider.anthropic_response_to_message(response) + assert message.content[0].text == "Hello from Claude!" + + +def test_anthropic_response_to_tool_use_message() -> None: + response = { + "content": [ + { + "type": "tool_use", + "id": "1", + "name": "example_fn", + "input": {"param": "value"}, + } + ], + } + message = AnthropicProvider.anthropic_response_to_message(response) + assert message.content[0].id == "1" + assert message.content[0].name == "example_fn" + assert message.content[0].parameters == {"param": "value"} + + +def test_tools_to_anthropic_spec() -> None: + tools = (Tool.from_function(example_fn),) + expected_spec = [ + { + "name": "example_fn", + "description": "Testing function.", + "input_schema": { + "type": "object", + "properties": {"param": {"type": "string", "description": "Description of param1"}}, + "required": ["param"], + }, + } + ] + result = AnthropicProvider.tools_to_anthropic_spec(tools) + assert result == expected_spec + + +def test_message_text_to_anthropic_spec() -> None: + messages = [Message.user("Hello, Claude")] + expected_spec = [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello, Claude"}], + } + ] + result = AnthropicProvider.messages_to_anthropic_spec(messages) + assert result == expected_spec + + +def test_messages_to_anthropic_spec() -> None: + messages = [ + Message(role="user", content=[Text(text="Hello, Claude")]), + Message( + role="assistant", + content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]), + ] + actual_spec = AnthropicProvider.messages_to_anthropic_spec(messages) + # != + expected_spec = [ + {"role": "user", "content": [{"type": "text", "text": "Hello, Claude"}]}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "1", + "name": "example_fn", + "input": {"param": "value"}, + } + ], + }, + { + "role": "user", + "content": [{"type": "tool_result", "tool_use_id": "1", "content": "Result"}], + }, + ] + assert actual_spec == expected_spec + + +@patch("httpx.Client.post") +@patch("logging.warning") +@patch("logging.error") +def test_anthropic_completion(mock_error, mock_warning, mock_post, anthropic_provider): + mock_response = { + "content": [{"type": "text", "text": "Hello from Claude!"}], + "usage": {"input_tokens": 10, "output_tokens": 25}, + } + + # First attempts fail with status code 429, 2nd succeeds + def create_response(status_code, json_data=None): + response = httpx.Response(status_code) + response._content = httpx._content.json_dumps(json_data or {}).encode() + response._request = httpx.Request("POST", "https://api.anthropic.com/v1/messages") + return response + + mock_post.side_effect = [ + create_response(429), # 1st attempt + create_response(200, mock_response), # Final success + ] + + model = "claude-3-5-sonnet-20240620" + system = "You are a helpful assistant." + messages = [Message.user("Hello, Claude")] + + reply_message, reply_usage = anthropic_provider.complete(model=model, system=system, messages=messages) + + assert reply_message.content == [Text(text="Hello from Claude!")] + assert reply_usage.total_tokens == 35 + assert mock_post.call_count == 2 + mock_post.assert_any_call( + "https://api.anthropic.com/v1/messages", + json={ + "system": system, + "model": model, + "max_tokens": 4096, + "messages": [ + *[ + { + "role": msg.role, + "content": [{"type": "text", "text": msg.content[0].text}], + } + for msg in messages + ], + ], + }, + ) + + +@pytest.mark.integration +def test_anthropic_integration(): + provider = AnthropicProvider.from_env() + model = "claude-3-5-sonnet-20240620" # updated model to a known valid model + system = "You are a helpful assistant." + messages = [Message.user("Hello, Claude")] + + # Run the completion + reply = provider.complete(model=model, system=system, messages=messages) + + assert reply[0].content is not None + print("Completion content from Anthropic:", reply[0].content) diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py new file mode 100644 index 000000000..adafabedb --- /dev/null +++ b/packages/exchange/tests/providers/test_azure.py @@ -0,0 +1,48 @@ +import os + +import pytest + +from exchange import Text, ToolUse +from exchange.providers.azure import AzureProvider +from .conftest import complete, tools + +AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") + + +@pytest.mark.vcr() +def test_azure_complete(default_azure_env): + reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL) + + assert reply_message.content == [Text(text="Hello! How can I assist you today?")] + assert reply_usage.total_tokens == 27 + + +@pytest.mark.integration +def test_azure_complete_integration(): + reply = complete(AzureProvider, AZURE_MODEL) + + assert reply[0].content is not None + print("Completion content from Azure:", reply[0].content) + + +@pytest.mark.vcr() +def test_azure_tools(default_azure_env): + reply_message, reply_usage = tools(AzureProvider, AZURE_MODEL) + + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "call_a47abadDxlGKIWjvYYvGVAHa" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 125 + + +@pytest.mark.integration +def test_azure_tools_integration(): + reply = tools(AzureProvider, AZURE_MODEL) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py new file mode 100644 index 000000000..2525f650b --- /dev/null +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -0,0 +1,228 @@ +import logging +import os +from unittest.mock import patch + +import pytest +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.providers.bedrock import BedrockProvider +from exchange.tool import Tool + +logger = logging.getLogger(__name__) + + +@pytest.fixture +@patch.dict( + os.environ, + { + "AWS_REGION": "us-east-1", + "AWS_ACCESS_KEY_ID": "fake-access-key", + "AWS_SECRET_ACCESS_KEY": "fake-secret-key", + "AWS_SESSION_TOKEN": "fake-session-token", + }, +) +def bedrock_provider(): + return BedrockProvider.from_env() + + +@patch("time.time", return_value=1624250000) +def test_sign_and_get_headers(mock_time, bedrock_provider): + # Create sample values + method = "POST" + url = "https://bedrock-runtime.us-east-1.amazonaws.com/some/path" + payload = {"key": "value"} + service = "bedrock" + # Generate headers + headers = bedrock_provider.client.sign_and_get_headers( + method, + url, + payload, + service, + ) + # Assert that headers contain expected keys + assert "Authorization" in headers + assert "Content-Type" in headers + assert "X-Amz-date" in headers + assert "x-amz-content-sha256" in headers + assert "X-Amz-Security-Token" in headers + + +@patch("httpx.Client.post") +def test_complete(mock_post, bedrock_provider): + # Mocked response from the server + mock_response = { + "output": {"message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}}, + "usage": {"inputTokens": 10, "outputTokens": 15, "totalTokens": 25}, + } + mock_post.return_value.json.return_value = mock_response + + model = "test-model" + system = "You are a helpful assistant." + messages = [Message.user("Hello")] + tools = () + + reply_message, reply_usage = bedrock_provider.complete(model=model, system=system, messages=messages, tools=tools) + + # Assertions for reply message + assert reply_message.content[0].text == "Hello, world!" + assert reply_usage.total_tokens == 25 + + +def test_message_to_bedrock_spec_text(bedrock_provider): + message = Message(role="user", content=[Text("Hello, world!")]) + expected = {"role": "user", "content": [{"text": "Hello, world!"}]} + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_tool_use(bedrock_provider): + tool_use = ToolUse(id="tool-1", name="WordCount", parameters={"text": "Hello, world!"}) + message = Message(role="assistant", content=[tool_use]) + expected = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "WordCount", + "input": {"text": "Hello, world!"}, + } + } + ], + } + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_tool_result(bedrock_provider): + message = Message( + role="assistant", + content=[ToolUse(id="tool-1", name="WordCount", parameters={"text": "Hello, world!"})], + ) + expected = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "WordCount", + "input": {"text": "Hello, world!"}, + } + } + ], + } + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_tool_result_text(bedrock_provider): + tool_result = ToolResult(tool_use_id="tool-1", output="Error occurred", is_error=True) + message = Message(role="user", content=[tool_result]) + expected = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tool-1", + "content": [{"text": "Error occurred"}], + "status": "error", + } + } + ], + } + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_invalid(bedrock_provider): + with pytest.raises(Exception): + bedrock_provider.message_to_bedrock_spec(Message(role="user", content=[])) + + +def test_response_to_message_text(bedrock_provider): + response = {"role": "user", "content": [{"text": "Hello, world!"}]} + message = bedrock_provider.response_to_message(response) + assert message.role == "user" + assert message.content[0].text == "Hello, world!" + + +def test_response_to_message_tool_use(bedrock_provider): + response = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "WordCount", + "input": {"text": "Hello, world!"}, + } + } + ], + } + message = bedrock_provider.response_to_message(response) + assert message.role == "assistant" + assert message.content[0].name == "WordCount" + assert message.content[0].parameters == {"text": "Hello, world!"} + + +def test_response_to_message_tool_result(bedrock_provider): + response = { + "role": "user", + "content": [ + { + "toolResult": { + "toolResultId": "tool-1", + "content": [{"json": {"result": 2}}], + } + } + ], + } + message = bedrock_provider.response_to_message(response) + assert message.role == "user" + assert message.content[0].tool_use_id == "tool-1" + assert message.content[0].output == {"result": 2} + + +def test_response_to_message_invalid(bedrock_provider): + with pytest.raises(Exception): + bedrock_provider.response_to_message({}) + + +def test_tools_to_bedrock_spec(bedrock_provider): + def word_count(text: str): + return len(text.split()) + + tool = Tool( + name="WordCount", + description="Counts words.", + parameters={"text": "string"}, + function=word_count, + ) + expected = { + "tools": [ + { + "toolSpec": { + "name": "WordCount", + "description": "Counts words.", + "inputSchema": {"json": {"text": "string"}}, + } + } + ] + } + assert bedrock_provider.tools_to_bedrock_spec((tool,)) == expected + + +def test_tools_to_bedrock_spec_duplicate(bedrock_provider): + def word_count(text: str): + return len(text.split()) + + tool = Tool( + name="WordCount", + description="Counts words.", + parameters={"text": "string"}, + function=word_count, + ) + tool_duplicate = Tool( + name="WordCount", + description="Counts words.", + parameters={"text": "string"}, + function=word_count, + ) + tools = bedrock_provider.tools_to_bedrock_spec((tool, tool_duplicate)) + assert set(tool["toolSpec"]["name"] for tool in tools["tools"]) == {"WordCount"} diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py new file mode 100644 index 000000000..3c1421146 --- /dev/null +++ b/packages/exchange/tests/providers/test_databricks.py @@ -0,0 +1,49 @@ +import os +from unittest.mock import patch + +import pytest +from exchange import Message, Text +from exchange.providers.databricks import DatabricksProvider + + +@pytest.fixture +@patch.dict( + os.environ, + {"DATABRICKS_HOST": "http://test-host", "DATABRICKS_TOKEN": "test_token"}, +) +def databricks_provider(): + return DatabricksProvider.from_env() + + +@patch("httpx.Client.post") +@patch("time.sleep", return_value=None) +@patch("logging.warning") +@patch("logging.error") +def test_databricks_completion(mock_error, mock_warning, mock_sleep, mock_post, databricks_provider): + mock_response = { + "choices": [{"message": {"role": "assistant", "content": "Hello!"}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35}, + } + mock_post.return_value.json.return_value = mock_response + + model = "my-databricks-model" + system = "You are a helpful assistant." + messages = [Message.user("Hello")] + tools = () + + reply_message, reply_usage = databricks_provider.complete( + model=model, system=system, messages=messages, tools=tools + ) + + assert reply_message.content == [Text(text="Hello!")] + assert reply_usage.total_tokens == 35 + assert mock_post.call_count == 1 + mock_post.assert_called_once_with( + "serving-endpoints/my-databricks-model/invocations", + json={ + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": "Hello"}, + ] + }, + ) diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py new file mode 100644 index 000000000..47ad46b43 --- /dev/null +++ b/packages/exchange/tests/providers/test_google.py @@ -0,0 +1,147 @@ +import os +from unittest.mock import patch + +import httpx +import pytest +from exchange import Message, Text +from exchange.content import ToolResult, ToolUse +from exchange.providers.google import GoogleProvider +from exchange.tool import Tool + + +def example_fn(param: str) -> None: + """ + Testing function. + + Args: + param (str): Description of param1 + """ + pass + + +@pytest.fixture +@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) +def google_provider(): + return GoogleProvider.from_env() + + +def test_google_response_to_text_message() -> None: + response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]} + message = GoogleProvider.google_response_to_message(response) + assert message.content[0].text == "Hello from Gemini!" + + +def test_google_response_to_tool_use_message() -> None: + response = { + "candidates": [ + { + "content": { + "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}], + "role": "model", + } + } + ] + } + + message = GoogleProvider.google_response_to_message(response) + assert message.content[0].name == "example_fn" + assert message.content[0].parameters == {"param": "value"} + + +def test_tools_to_google_spec() -> None: + tools = (Tool.from_function(example_fn),) + expected_spec = { + "functionDeclarations": [ + { + "name": "example_fn", + "description": "Testing function.", + "parameters": { + "type": "object", + "properties": {"param": {"type": "string", "description": "Description of param1"}}, + "required": ["param"], + }, + } + ] + } + result = GoogleProvider.tools_to_google_spec(tools) + assert result == expected_spec + + +def test_message_text_to_google_spec() -> None: + messages = [Message.user("Hello, Gemini")] + expected_spec = [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}] + result = GoogleProvider.messages_to_google_spec(messages) + assert result == expected_spec + + +def test_messages_to_google_spec() -> None: + messages = [ + Message(role="user", content=[Text(text="Hello, Gemini")]), + Message( + role="assistant", + content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]), + ] + actual_spec = GoogleProvider.messages_to_google_spec(messages) + # != + expected_spec = [ + {"role": "user", "parts": [{"text": "Hello, Gemini"}]}, + {"role": "model", "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}]}, + {"role": "user", "parts": [{"functionResponse": {"name": "1", "response": {"content": "Result"}}}]}, + ] + + assert actual_spec == expected_spec + + +@patch("httpx.Client.post") +@patch("logging.warning") +@patch("logging.error") +def test_google_completion(mock_error, mock_warning, mock_post, google_provider): + mock_response = { + "candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}], + "usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13}, + } + + # First attempts fail with status code 429, 2nd succeeds + def create_response(status_code, json_data=None): + response = httpx.Response(status_code) + response._content = httpx._content.json_dumps(json_data or {}).encode() + response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/") + return response + + mock_post.side_effect = [ + create_response(429), # 1st attempt + create_response(200, mock_response), # Final success + ] + + model = "gemini-1.5-flash" + system = "You are a helpful assistant." + messages = [Message.user("Hello, Gemini")] + + reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages) + + assert reply_message.content == [Text(text="Hello from Gemini!")] + assert reply_usage.total_tokens == 13 + assert mock_post.call_count == 2 + mock_post.assert_any_call( + "models/gemini-1.5-flash:generateContent", + json={ + "system_instruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}], + }, + ) + + +@pytest.mark.integration +def test_google_integration(): + provider = GoogleProvider.from_env() + model = "gemini-1.5-flash" # updated model to a known valid model + system = "You are a helpful assistant." + messages = [Message.user("Hello, Gemini")] + + # Run the completion + reply = provider.complete(model=model, system=system, messages=messages) + + assert reply[0].content is not None + print("Completion content from Google:", reply[0].content) diff --git a/packages/exchange/tests/providers/test_ollama.py b/packages/exchange/tests/providers/test_ollama.py new file mode 100644 index 000000000..5a66c482c --- /dev/null +++ b/packages/exchange/tests/providers/test_ollama.py @@ -0,0 +1,48 @@ +import os + +import pytest + +from exchange import Text, ToolUse +from exchange.providers.ollama import OllamaProvider, OLLAMA_MODEL +from .conftest import complete, tools + +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL) + + +@pytest.mark.vcr() +def test_ollama_complete(): + reply_message, reply_usage = complete(OllamaProvider, OLLAMA_MODEL) + + assert reply_message.content == [Text(text="Hello! I'm here to help. How can I assist you today? Let's chat. 😊")] + assert reply_usage.total_tokens == 33 + + +@pytest.mark.integration +def test_ollama_complete_integration(): + reply = complete(OllamaProvider, OLLAMA_MODEL) + + assert reply[0].content is not None + print("Completion content from OpenAI:", reply[0].content) + + +@pytest.mark.vcr() +def test_ollama_tools(): + reply_message, reply_usage = tools(OllamaProvider, OLLAMA_MODEL) + + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "call_z6fgu3z3" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 133 + + +@pytest.mark.integration +def test_ollama_tools_integration(): + reply = tools(OllamaProvider, OLLAMA_MODEL) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} diff --git a/packages/exchange/tests/providers/test_openai.py b/packages/exchange/tests/providers/test_openai.py new file mode 100644 index 000000000..45bc62050 --- /dev/null +++ b/packages/exchange/tests/providers/test_openai.py @@ -0,0 +1,63 @@ +import os + +import pytest + +from exchange import Text, ToolUse +from exchange.providers.openai import OpenAiProvider +from .conftest import complete, vision, tools + +OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") + + +@pytest.mark.vcr() +def test_openai_complete(default_openai_env): + reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) + + assert reply_message.content == [Text(text="Hello! How can I assist you today?")] + assert reply_usage.total_tokens == 27 + + +@pytest.mark.integration +def test_openai_complete_integration(): + reply = complete(OpenAiProvider, OPENAI_MODEL) + + assert reply[0].content is not None + print("Completion content from OpenAI:", reply[0].content) + + +@pytest.mark.vcr() +def test_openai_tools(default_openai_env): + reply_message, reply_usage = tools(OpenAiProvider, OPENAI_MODEL) + + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "call_xXYlw4A7Ud1qtCopuK5gEJrP" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 122 + + +@pytest.mark.integration +def test_openai_tools_integration(): + reply = tools(OpenAiProvider, OPENAI_MODEL) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + + +@pytest.mark.vcr() +def test_openai_vision(default_openai_env): + reply_message, reply_usage = vision(OpenAiProvider, OPENAI_MODEL) + + assert reply_message.content == [Text(text='The first entry in the menu says "Ask Goose."')] + assert reply_usage.total_tokens == 14241 + + +@pytest.mark.integration +def test_openai_vision_integration(): + reply = vision(OpenAiProvider, OPENAI_MODEL) + + assert "ask goose" in reply[0].text.lower() diff --git a/packages/exchange/tests/providers/test_provider_utils.py b/packages/exchange/tests/providers/test_provider_utils.py new file mode 100644 index 000000000..5ad0135ea --- /dev/null +++ b/packages/exchange/tests/providers/test_provider_utils.py @@ -0,0 +1,245 @@ +from copy import deepcopy +import json +from unittest.mock import Mock +from attrs import asdict +import httpx +import pytest +from unittest.mock import patch + +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.providers.utils import ( + messages_to_openai_spec, + openai_response_to_message, + raise_for_status, + tools_to_openai_spec, +) +from exchange.tool import Tool + +OPEN_AI_TOOL_USE_RESPONSE = response = { + "choices": [ + { + "role": "assistant", + "message": { + "tool_calls": [ + { + "id": "1", + "function": { + "name": "example_fn", + "arguments": json.dumps( + { + "param": "value", + } + ), + # TODO: should this handle dict's as well? + }, + } + ], + }, + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "total_tokens": 35, + }, +} + + +def example_fn(param: str) -> None: + """ + Testing function. + + Args: + param (str): Description of param1 + """ + pass + + +def example_fn_two() -> str: + """ + Second testing function + + Returns: + str: Description of return value + """ + pass + + +def test_raise_for_status_success() -> None: + response = Mock(spec=httpx.Response) + response.status_code = 200 + + result = raise_for_status(response) + + assert result == response + + +def test_raise_for_status_failure_with_text() -> None: + response = Mock(spec=httpx.Response) + response.status_code = 404 + response.text = "Not Found: John Cena" + + try: + raise_for_status(response) + except httpx.HTTPStatusError as e: + assert e.response == response + assert str(e) == "404 Not Found: John Cena" + assert e.request is None + + +def test_raise_for_status_failure_without_text() -> None: + response = Mock(spec=httpx.Response) + response.status_code = 500 + response.text = "" + + try: + raise_for_status(response) + except httpx.HTTPStatusError as e: + assert e.response == response + assert str(e) == "500 Internal Server Error" + assert e.request is None + + +def test_messages_to_openai_spec() -> None: + messages = [ + Message(role="assistant", content=[Text("Hello!")]), + Message(role="user", content=[Text("How are you?")]), + Message( + role="assistant", + content=[ToolUse(id=1, name="tool1", parameters={"param1": "value1"})], + ), + Message(role="user", content=[ToolResult(tool_use_id=1, output="Result")]), + ] + + spec = messages_to_openai_spec(messages) + + assert spec == [ + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "How are you?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": 1, + "type": "function", + "function": { + "name": "tool1", + "arguments": '{"param1": "value1"}', + }, + } + ], + }, + { + "role": "tool", + "content": "Result", + "tool_call_id": 1, + }, + ] + + +def test_tools_to_openai_spec() -> None: + tools = (Tool.from_function(example_fn), Tool.from_function(example_fn_two)) + assert len(tools_to_openai_spec(tools)) == 2 + + +def test_tools_to_openai_spec_duplicate() -> None: + tools = (Tool.from_function(example_fn), Tool.from_function(example_fn)) + with pytest.raises(ValueError): + tools_to_openai_spec(tools) + + +def test_tools_to_openai_spec_single() -> None: + tools = Tool.from_function(example_fn) + expected_spec = [ + { + "type": "function", + "function": { + "name": "example_fn", + "description": "Testing function.", + "parameters": { + "type": "object", + "properties": { + "param": { + "type": "string", + "description": "Description of param1", + } + }, + "required": ["param"], + }, + }, + }, + ] + result = tools_to_openai_spec((tools,)) + assert result == expected_spec + + +def test_tools_to_openai_spec_empty() -> None: + tools = () + expected_spec = [] + assert tools_to_openai_spec(tools) == expected_spec + + +def test_openai_response_to_message_text() -> None: + response = { + "choices": [ + { + "role": "assistant", + "message": {"content": "Hello from John Cena!"}, + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "total_tokens": 35, + }, + } + + message = openai_response_to_message(response) + + actual = asdict(message) + expect = asdict( + Message( + role="assistant", + content=[Text("Hello from John Cena!")], + ) + ) + actual.pop("id") + expect.pop("id") + assert actual == expect + + +def test_openai_response_to_message_valid_tooluse() -> None: + response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE) + message = openai_response_to_message(response) + actual = asdict(message) + expect = asdict( + Message( + role="assistant", + content=[ToolUse(id=1, name="example_fn", parameters={"param": "value"})], + ) + ) + actual.pop("id") + actual["content"][0].pop("id") + expect.pop("id") + expect["content"][0].pop("id") + assert actual == expect + + +def test_openai_response_to_message_invalid_func_name() -> None: + response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE) + response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] = "invalid fn" + message = openai_response_to_message(response) + assert message.content[0].name == "invalid fn" + assert json.loads(message.content[0].parameters) == {"param": "value"} + assert message.content[0].is_error + assert message.content[0].error_message.startswith("The provided function name") + + +@patch("json.loads", side_effect=json.JSONDecodeError("error", "doc", 0)) +def test_openai_response_to_message_json_decode_error(mock_json) -> None: + response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE) + message = openai_response_to_message(response) + assert message.content[0].name == "example_fn" + assert message.content[0].is_error + assert message.content[0].error_message.startswith("Could not interpret tool use") diff --git a/packages/exchange/tests/test_exchange.py b/packages/exchange/tests/test_exchange.py new file mode 100644 index 000000000..f01ef4694 --- /dev/null +++ b/packages/exchange/tests/test_exchange.py @@ -0,0 +1,763 @@ +from typing import List, Tuple + +import pytest + +from exchange.checkpoint import Checkpoint, CheckpointData +from exchange.content import Text, ToolResult, ToolUse +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators import PassiveModerator +from exchange.providers import Provider, Usage +from exchange.tool import Tool + + +def dummy_tool() -> str: + """An example tool""" + return "dummy response" + + +too_long_output = "x" * (2**20 + 1) +too_long_token_output = "word " * 128000 + + +def no_overlapping_checkpoints(exchange: Exchange) -> bool: + """Assert that there are no overlapping checkpoints in the exchange.""" + for i, checkpoint in enumerate(exchange.checkpoint_data.checkpoints): + for other_checkpoint in exchange.checkpoint_data.checkpoints[i + 1 :]: + if not checkpoint.end_index < other_checkpoint.start_index: + return False + return True + + +def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]: + return [(checkpoint.start_index, checkpoint.end_index) for checkpoint in checkpoints] + + +class MockProvider(Provider): + def __init__(self, sequence: List[Message], usage_dicts: List[dict]): + # We'll use init to provide a preplanned reply sequence + self.sequence = sequence + self.call_count = 0 + self.usage_dicts = usage_dicts + + @staticmethod + def get_usage(data: dict) -> Usage: + usage = data.pop("usage") + input_tokens = usage.get("input_tokens") + output_tokens = usage.get("output_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Message: + output = self.sequence[self.call_count] + usage = self.get_usage(self.usage_dicts[self.call_count]) + self.call_count += 1 + return (output, usage) + + +def test_reply_with_unsupported_tool(): + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="unsupported_tool", parameters={})], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=(Tool.from_function(dummy_tool),), + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert isinstance(content, ToolResult) and content.is_error and "no tool exists" in content.output.lower() + + +def test_invalid_tool_parameters(): + """Test handling of invalid tool parameters response""" + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="dummy_tool", parameters="invalid json")], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test invalid parameters")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert isinstance(content, ToolResult) and content.is_error and "invalid json" in content.output.lower() + + +def test_max_tool_use_when_limit_reached(): + """Test the max_tool_use parameter in the reply method.""" + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="dummy_tool", parameters={})], + ), + Message( + role="assistant", + content=[ToolUse(id="2", name="dummy_tool", parameters={})], + ), + Message( + role="assistant", + content=[ToolUse(id="3", name="dummy_tool", parameters={})], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test max tool use")])) + + response = ex.reply(max_tool_use=3) + + assert ex.provider.call_count == 3 + assert "reached the limit" in response.content[0].text.lower() + + assert isinstance(ex.messages[-2].content[0], ToolResult) and ex.messages[-2].content[0].tool_use_id == "3" + + assert ex.messages[-1].role == "assistant" + + +def test_tool_output_too_long_character_error(): + """Test tool handling when output exceeds character limit.""" + + def long_output_tool_char() -> str: + return too_long_output + + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="long_output_tool_char", parameters={})], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(long_output_tool_char)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test long output char")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert ( + isinstance(content, ToolResult) + and content.is_error + and "output that was too long to handle" in content.output.lower() + ) + + +def test_tool_output_too_long_token_error(): + """Test tool handling when output exceeds token limit.""" + + def long_output_tool_token() -> str: + return too_long_token_output + + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="long_output_tool_token", parameters={})], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(long_output_tool_token)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test long output token")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert ( + isinstance(content, ToolResult) + and content.is_error + and "output that was too long to handle" in content.output.lower() + ) + + +@pytest.fixture(scope="function") +def normal_exchange() -> Exchange: + ex = Exchange( + provider=MockProvider( + sequence=[ + Message(role="assistant", content=[Text(text="Message 1")]), + Message(role="assistant", content=[Text(text="Message 2")]), + Message(role="assistant", content=[Text(text="Message 3")]), + Message(role="assistant", content=[Text(text="Message 4")]), + Message(role="assistant", content=[Text(text="Message 5")]), + ], + usage_dicts=[ + {"usage": {"total_tokens": 10, "input_tokens": 5, "output_tokens": 5}}, + {"usage": {"total_tokens": 28, "input_tokens": 10, "output_tokens": 18}}, + {"usage": {"total_tokens": 33, "input_tokens": 28, "output_tokens": 5}}, + {"usage": {"total_tokens": 40, "input_tokens": 32, "output_tokens": 8}}, + {"usage": {"total_tokens": 50, "input_tokens": 40, "output_tokens": 10}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=(Tool.from_function(dummy_tool),), + moderator=PassiveModerator(), + checkpoint_data=CheckpointData(), + ) + return ex + + +@pytest.fixture(scope="function") +def resumed_exchange() -> Exchange: + messages = [ + Message(role="user", content=[Text(text="User message 1")]), + Message(role="assistant", content=[Text(text="Assistant Message 1")]), + Message(role="user", content=[Text(text="User message 2")]), + Message(role="assistant", content=[Text(text="Assistant Message 2")]), + Message(role="user", content=[Text(text="User message 3")]), + Message(role="assistant", content=[Text(text="Assistant Message 3")]), + ] + provider = MockProvider( + sequence=[ + Message(role="assistant", content=[Text(text="Assistant Message 4")]), + ], + usage_dicts=[ + {"usage": {"total_tokens": 40, "input_tokens": 32, "output_tokens": 8}}, + ], + ) + ex = Exchange( + provider=provider, + messages=messages, + tools=[], + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + checkpoint_data=CheckpointData(), + moderator=PassiveModerator(), + ) + return ex + + +def test_checkpoints_on_exchange(normal_exchange): + """Test checkpoints on an exchange.""" + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + + # Check if checkpoints are created correctly + checkpoints = ex.checkpoint_data.checkpoints + assert len(checkpoints) == 6 + for i in range(len(ex.messages)): + # asserting that each message has a corresponding checkpoint + assert checkpoints[i].start_index == i + assert checkpoints[i].end_index == i + + # Check if the messages are ordered correctly + assert [msg.content[0].text for msg in ex.messages] == [ + "User message", + "Message 1", + "User message", + "Message 2", + "User message", + "Message 3", + ] + assert no_overlapping_checkpoints(ex) + + +def test_checkpoints_on_resumed_exchange(resumed_exchange) -> None: + ex = resumed_exchange + ex.pop_last_message() + ex.reply() + + checkpoints = ex.checkpoint_data.checkpoints + assert len(checkpoints) == 2 + assert len(ex.messages) == 6 + assert checkpoints[0].token_count == 32 + assert checkpoints[0].start_index == 0 + assert checkpoints[0].end_index == 4 + assert checkpoints[1].token_count == 8 + assert checkpoints[1].start_index == 5 + assert checkpoints[1].end_index == 5 + assert no_overlapping_checkpoints(ex) + + +def test_pop_last_checkpoint_on_resumed_exchange(resumed_exchange) -> None: + ex = resumed_exchange + ex.add(Message(role="user", content=[Text(text="Assistant Message 4")])) + ex.reply() + ex.pop_last_checkpoint() + + assert len(ex.messages) == 7 + assert len(ex.checkpoint_data.checkpoints) == 1 + + ex.pop_last_checkpoint() + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert no_overlapping_checkpoints(ex) + + +def test_pop_last_checkpoint_on_normal_exchange(normal_exchange) -> None: + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.pop_last_checkpoint() + ex.pop_last_checkpoint() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert no_overlapping_checkpoints(ex) + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.pop_last_checkpoint() + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + ex.reply() + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert no_overlapping_checkpoints(ex) + + +def test_pop_first_message_no_messages(): + ex = Exchange( + provider=MockProvider(sequence=[], usage_dicts=[]), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + + with pytest.raises(ValueError) as e: + ex.pop_first_message() + assert str(e.value) == "There are no messages to pop" + + +def test_pop_first_message_checkpoint_with_many_messages(resumed_exchange): + ex = resumed_exchange + ex.pop_last_message() + ex.reply() + + assert len(ex.messages) == 6 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.checkpoints[0].start_index == 0 + assert ex.checkpoint_data.checkpoints[0].end_index == 4 + assert ex.checkpoint_data.checkpoints[1].start_index == 5 + assert ex.checkpoint_data.checkpoints[1].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 0 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 5 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 4 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 3 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 3 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 4 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 5 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + assert no_overlapping_checkpoints(ex) + + with pytest.raises(ValueError) as e: + ex.pop_first_message() + + assert str(e.value) == "There are no messages to pop" + + +def test_varied_message_manipulation(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message 1")])) + ex.reply() + + ex.pop_first_message() + + ex.add(Message(role="user", content=[Text(text="User message 2")])) + ex.reply() + + assert len(ex.messages) == 3 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert ex.checkpoint_data.message_index_offset == 1 + # (start, end) + # (1, 1), (2, 2), (3, 3) + # actual_index_in_messages_arr = any checkpoint index - offset + assert no_overlapping_checkpoints(ex) + for i in range(3): + assert ex.checkpoint_data.checkpoints[i].start_index == i + 1 + assert ex.checkpoint_data.checkpoints[i].end_index == i + 1 + + ex.pop_last_message() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + for i in range(2): + assert ex.checkpoint_data.checkpoints[i].start_index == i + 1 + assert ex.checkpoint_data.checkpoints[i].end_index == i + 1 + + ex.add(Message(role="assistant", content=[Text(text="Assistant message")])) + ex.add(Message(role="user", content=[Text(text="User message 3")])) + ex.reply() + + assert len(ex.messages) == 5 + assert len(ex.checkpoint_data.checkpoints) == 4 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(1, 1), (2, 2), (3, 4), (5, 5)] + + ex.pop_last_checkpoint() + + assert len(ex.messages) == 4 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(1, 1), (2, 2), (3, 4)] + + ex.pop_first_message() + + assert len(ex.messages) == 3 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2), (3, 4)] + + ex.pop_last_message() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)] + + ex.pop_last_message() + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)] + + ex.add(Message(role="assistant", content=[Text(text="Assistant message")])) + ex.add(Message(role="user", content=[Text(text="User message 5")])) + ex.pop_last_checkpoint() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + + ex.add(Message(role="user", content=[Text(text="User message 6")])) + ex.reply() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2), (3, 3)] + + ex.pop_last_message() + + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)] + + ex.pop_first_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + + ex.add(Message(role="user", content=[Text(text="User message 7")])) + ex.pop_last_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + + +def test_pop_last_message_when_no_checkpoints_but_messages_present(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + + ex.pop_last_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + + +def test_pop_first_message_when_no_checkpoints_but_message_present(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + + with pytest.raises(ValueError) as e: + ex.pop_first_message() + + assert str(e.value) == "There must be at least one checkpoint to pop the first message" + + +def test_pop_first_checkpoint_size_n(resumed_exchange): + ex = resumed_exchange + ex.pop_last_message() # needed because the last message is an assistant message + ex.reply() + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 5 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert len(ex.messages) == 1 + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert len(ex.messages) == 0 + + +def test_pop_first_checkpoint_size_1(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert len(ex.messages) == 1 + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert len(ex.messages) == 0 + + +def test_pop_first_checkpoint_no_checkpoints(normal_exchange): + ex = normal_exchange + + with pytest.raises(ValueError) as e: + ex.pop_first_checkpoint() + + assert str(e.value) == "There are no checkpoints to pop" + + +def test_prepend_checkpointed_message_empty_exchange(normal_exchange): + ex = normal_exchange + ex.prepend_checkpointed_message(Message(role="assistant", content=[Text(text="Assistant message")]), 10) + + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 0 + assert ex.checkpoint_data.checkpoints[0].end_index == 0 + + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert len(ex.messages) == 3 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_checkpoint() + + assert ex.checkpoint_data.message_index_offset == 1 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert len(ex.messages) == 2 + assert no_overlapping_checkpoints(ex) + + ex.prepend_checkpointed_message(Message(role="assistant", content=[Text(text="Assistant message")]), 10) + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert len(ex.messages) == 3 + assert no_overlapping_checkpoints(ex) + + +def test_generate_successful_response_on_first_try(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text("Hello")])) + ex.generate() + + +def test_rewind_in_normal_exchange(normal_exchange): + ex = normal_exchange + ex.rewind() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + + ex.add(Message(role="user", content=[Text("Hello")])) + ex.generate() + ex.add(Message(role="user", content=[Text("Hello")])) + + # testing if it works with a user text message at the end + ex.rewind() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + + ex.add(Message(role="user", content=[Text("Hello")])) + ex.generate() + + # testing if it works with a non-user text message at the end + ex.rewind() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + + +def test_rewind_with_tool_usage(): + # simulating a real exchange with tool usage + ex = Exchange( + provider=MockProvider( + sequence=[ + Message.assistant("Hello!"), + Message( + role="assistant", + content=[ToolUse(id="1", name="dummy_tool", parameters={})], + ), + Message( + role="assistant", + content=[ToolUse(id="2", name="dummy_tool", parameters={})], + ), + Message.assistant("Done!"), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 27, "output_tokens": 44}}, + {"usage": {"input_tokens": 50, "output_tokens": 56}}, + {"usage": {"input_tokens": 60, "output_tokens": 76}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + ex.add(Message(role="user", content=[Text(text="test")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="kick it off!")])) + ex.reply() + + # removing the last message to simulate not getting a response + ex.pop_last_message() + + # calling rewind to last user message + ex.rewind() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert no_overlapping_checkpoints(ex) + assert ex.messages[0].content[0].text == "test" + assert type(ex.messages[1].content[0]) is Text + assert ex.messages[1].role == "assistant" diff --git a/packages/exchange/tests/test_exchange_collect_usage.py b/packages/exchange/tests/test_exchange_collect_usage.py new file mode 100644 index 000000000..590dc709b --- /dev/null +++ b/packages/exchange/tests/test_exchange_collect_usage.py @@ -0,0 +1,33 @@ +from unittest.mock import MagicMock +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators.passive import PassiveModerator +from exchange.providers.base import Provider +from exchange.tool import Tool +from exchange.token_usage_collector import _TokenUsageCollector + +MODEL_NAME = "test-model" + + +def create_exchange(mock_provider, dummy_tool): + return Exchange( + provider=mock_provider, + model=MODEL_NAME, + system="test-system", + tools=(Tool.from_function(dummy_tool),), + messages=[], + moderator=PassiveModerator(), + ) + + +def test_exchange_generate_collect_usage(usage_factory, dummy_tool, monkeypatch): + mock_provider = MagicMock(spec=Provider) + mock_usage_collector = MagicMock(spec=_TokenUsageCollector) + usage = usage_factory() + mock_provider.complete.return_value = (Message.assistant("msg"), usage) + exchange = create_exchange(mock_provider, dummy_tool) + + monkeypatch.setattr("exchange.exchange._token_usage_collector", mock_usage_collector) + exchange.generate() + + mock_usage_collector.collect.assert_called_once_with(MODEL_NAME, usage) diff --git a/packages/exchange/tests/test_exchange_frozen.py b/packages/exchange/tests/test_exchange_frozen.py new file mode 100644 index 000000000..a3095b3a3 --- /dev/null +++ b/packages/exchange/tests/test_exchange_frozen.py @@ -0,0 +1,48 @@ +import pytest +from attr.exceptions import FrozenInstanceError +from exchange.content import Text +from exchange.exchange import Exchange +from exchange.moderators import PassiveModerator +from exchange.message import Message +from exchange.providers import Provider, Usage +from exchange.tool import Tool + + +class MockProvider(Provider): + def complete(self, model, system, messages, tools=None): + return Message(role="assistant", content=[Text(text="This is a mock response.")]), Usage.from_dict( + {"total_tokens": 35} + ) + + +def test_exchange_immutable(dummy_tool): + # Create an instance of Exchange + provider = MockProvider() + # intentionally setting a list instead of tuple on tools, it should be converted + exchange = Exchange( + provider=provider, + model="test-model", + system="test-system", + tools=(Tool.from_function(dummy_tool),), + messages=[Message(role="user", content=[Text(text="Hello!")])], + moderator=PassiveModerator(), + ) + + # Try to directly modify a field (should raise an error) + with pytest.raises(FrozenInstanceError): + exchange.system = "" + + with pytest.raises(AttributeError): + exchange.tools.append("anything") + + # Replace method should return a new instance with deepcopy of messages + new_exchange = exchange.replace(system="changed") + + assert new_exchange.system == "changed" + assert len(exchange.messages) == 1 + assert len(new_exchange.messages) == 1 + + # Ensure that the messages are deep copied + new_exchange.messages[0].content[0].text = "Changed!" + assert exchange.messages[0].content[0].text == "Hello!" + assert new_exchange.messages[0].content[0].text == "Changed!" diff --git a/packages/exchange/tests/test_image.png b/packages/exchange/tests/test_image.png new file mode 100644 index 000000000..3488b8a51 Binary files /dev/null and b/packages/exchange/tests/test_image.png differ diff --git a/packages/exchange/tests/test_integration.py b/packages/exchange/tests/test_integration.py new file mode 100644 index 000000000..1eb198082 --- /dev/null +++ b/packages/exchange/tests/test_integration.py @@ -0,0 +1,89 @@ +import os +import pytest +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators import ContextTruncate +from exchange.providers import get_provider +from exchange.providers.ollama import OLLAMA_MODEL +from exchange.tool import Tool +from tests.conftest import read_file + +too_long_chars = "x" * (2**20 + 1) + +cases = [ + # Set seed and temperature for more determinism, to avoid flakes + (get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL), dict(seed=3, temperature=0.1)), + (get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini"), dict()), + (get_provider("azure"), os.getenv("AZURE_MODEL", "gpt-4o-mini"), dict()), + (get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()), + (get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()), + (get_provider("google"), "gemini-1.5-flash", dict()), +] + + +@pytest.mark.integration +@pytest.mark.parametrize("provider,model,kwargs", cases) +def test_simple(provider, model, kwargs): + provider = provider.from_env() + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant.", + generation_args=kwargs, + ) + + ex.add(Message.user("Who is the most famous wizard from the lord of the rings")) + + response = ex.reply() + + # It's possible this can be flakey, but in experience so far haven't seen it + assert "gandalf" in response.text.lower() + + +@pytest.mark.integration +@pytest.mark.parametrize("provider,model,kwargs", cases) +def test_tools(provider, model, kwargs, tmp_path): + provider = provider.from_env() + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant. Expect to need to read a file using read_file.", + tools=(Tool.from_function(read_file),), + generation_args=kwargs, + ) + + ex.add(Message.user("What are the contents of this file? test.txt")) + + response = ex.reply() + + assert "hello exchange" in response.text.lower() + + +@pytest.mark.integration +@pytest.mark.parametrize("provider,model,kwargs", cases) +def test_tool_use_output_chars(provider, model, kwargs): + provider = provider.from_env() + + def get_password() -> str: + """Return the password for authentication""" + return too_long_chars + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant. Expect to need to authenticate using get_password.", + tools=(Tool.from_function(get_password),), + generation_args=kwargs, + ) + + ex.add(Message.user("Can you authenticate this session by responding with the password")) + + ex.reply() + + # Without our error handling, this would raise + # string too long. Expected a string with maximum length 1048576, but got a string with length ... diff --git a/packages/exchange/tests/test_integration_vision.py b/packages/exchange/tests/test_integration_vision.py new file mode 100644 index 000000000..20f165ade --- /dev/null +++ b/packages/exchange/tests/test_integration_vision.py @@ -0,0 +1,44 @@ +import os + +import pytest +from exchange.content import ToolResult, ToolUse +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators import ContextTruncate +from exchange.providers import get_provider + +cases = [ + (get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini")), +] + + +@pytest.mark.integration # skipped in CI/CD +@pytest.mark.parametrize("provider,model", cases) +def test_simple(provider, model): + provider = provider.from_env() + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant.", + ) + + ex.add(Message.user("What does the first entry in the menu say?")) + ex.add( + Message( + role="assistant", + content=[ToolUse(id="xyz", name="screenshot", parameters={})], + ) + ) + ex.add( + Message( + role="user", + content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')], + ) + ) + + response = ex.reply() + + # It's possible this can be flakey, but in experience so far haven't seen it + assert "ask goose" in response.text.lower() diff --git a/packages/exchange/tests/test_message.py b/packages/exchange/tests/test_message.py new file mode 100644 index 000000000..d5442eb75 --- /dev/null +++ b/packages/exchange/tests/test_message.py @@ -0,0 +1,96 @@ +import subprocess +from pathlib import Path +import pytest + +from exchange.message import Message +from exchange.content import Text, ToolUse, ToolResult + + +def test_user_message(): + user_message = Message.user("abcd") + assert user_message.role == "user" + assert user_message.text == "abcd" + + +def test_assistant_message(): + assistant_message = Message.assistant("abcd") + assert assistant_message.role == "assistant" + assert assistant_message.text == "abcd" + + +def test_message_tool_use(): + from exchange.content import ToolUse + + tu1 = ToolUse(id="1", name="tool", parameters={}) + tu2 = ToolUse(id="2", name="tool", parameters={}) + message = Message(role="assistant", content=[tu1, tu2]) + assert len(message.tool_use) == 2 + assert message.tool_use[0].name == "tool" + + +def test_message_tool_result(): + from exchange.content import ToolResult + + tr1 = ToolResult(tool_use_id="1", output="result") + tr2 = ToolResult(tool_use_id="2", output="result") + message = Message(role="user", content=[tr1, tr2]) + assert len(message.tool_result) == 2 + assert message.tool_result[0].output == "result" + + +def test_message_load(tmpdir): + # To emulate the expected relative lookup, we need to create a mock code dir + # and run the load in a subprocess + test_dir = Path(tmpdir) + + # Create a temporary Jinja template file in the test_dir + template_content = "hello {{ name }} {% include 'relative.jinja' %}" + template_path = test_dir / "template.jinja" + template_path.write_text(template_content) + + relative_content = "and {{ name2 }}" + relative_path = test_dir / "relative.jinja" + relative_path.write_text(relative_content) + + # Create a temporary Python file in the sub_dir that calls the load method with a relative path + python_file_content = """ +from exchange.message import Message + +def test_function(): + message = Message.load('template.jinja', name="a", name2="b") + assert message.text == "hello a and b" + assert message.role == "user" + +test_function() +""" + python_file_path = test_dir / "test_script.py" + python_file_path.write_text(python_file_content) + + # Execute the temporary Python file to test the relative lookup functionality + result = subprocess.run(["python3", str(python_file_path)]) + + assert result.returncode == 0 + + +def test_message_validation(): + # Valid user message + message = Message(role="user", content=[Text(text="Hello")]) + assert message.text == "Hello" + + # Valid assistant message + message = Message(role="assistant", content=[Text(text="Hello")]) + assert message.text == "Hello" + + # Invalid message: user with tool_use + with pytest.raises(ValueError): + Message( + role="user", + content=[Text(text=""), ToolUse(id="1", name="tool", parameters={})], + ) + + # Invalid message: assistant with tool_result + with pytest.raises(ValueError): + Message( + role="assistant", + content=[Text(text=""), ToolResult(tool_use_id="1", output="result")], + ) diff --git a/packages/exchange/tests/test_summarizer.py b/packages/exchange/tests/test_summarizer.py new file mode 100644 index 000000000..fa7281920 --- /dev/null +++ b/packages/exchange/tests/test_summarizer.py @@ -0,0 +1,227 @@ +import pytest +from exchange import Exchange, Message +from exchange.content import ToolResult, ToolUse +from exchange.moderators.passive import PassiveModerator +from exchange.moderators.summarizer import ContextSummarizer +from exchange.providers import Usage + + +class MockProvider: + def complete(self, model, system, messages, tools): + assistant_message_text = "Summarized content here." + output_tokens = len(assistant_message_text) + total_input_tokens = sum(len(msg.text) for msg in messages) + if not messages or messages[-1].role == "assistant": + message = Message.user(assistant_message_text) + else: + message = Message.assistant(assistant_message_text) + total_tokens = total_input_tokens + output_tokens + usage = Usage( + input_tokens=total_input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + return message, usage + + +@pytest.fixture +def exchange_instance(): + ex = Exchange( + provider=MockProvider(), + model="test-model", + system="test-system", + messages=[ + Message.user("Hi, can you help me with my homework?"), + Message.assistant("Of course! What do you need help with?"), + Message.user("I need help with math problems."), + Message.assistant("Sure, I can help with that. Let's get started."), + Message.user("Can you also help with my science homework?"), + Message.assistant("Yes, I can help with science too."), + Message.user("That's great! How about history?"), + Message.assistant("Of course, I can help with history as well."), + Message.user("Thanks! You're very helpful."), + Message.assistant("You're welcome! I'm here to help."), + ], + moderator=PassiveModerator(), + ) + return ex + + +@pytest.fixture +def summarizer_instance(): + return ContextSummarizer(max_tokens=300) + + +def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer): + # Pre-checks + assert len(exchange_instance.messages) == 10 + + exchange_instance.generate() + + # the exchange instance has a PassiveModerator so the messages are not truncated nor summarized + assert len(exchange_instance.messages) == 11 + assert len(exchange_instance.checkpoint_data.checkpoints) == 2 + + # we now tell the summarizer to summarize the exchange + summarizer_instance.rewrite(exchange_instance) + + assert exchange_instance.checkpoint_data.total_token_count <= 200 + assert len(exchange_instance.messages) == 2 + + # Assert that summarized content is the first message + first_message = exchange_instance.messages[0] + assert first_message.role == "user" or first_message.role == "assistant" + assert any("summarized" in content.text.lower() for content in first_message.content) + + # Ensure roles alternate in the output + for i in range(1, len(exchange_instance.messages)): + assert ( + exchange_instance.messages[i - 1].role != exchange_instance.messages[i].role + ), "Messages must alternate between user and assistant" + + +MESSAGE_SEQUENCE = [ + Message.user("Hi, can you help me with my homework?"), + Message.assistant("Of course! What do you need help with?"), + Message.user("I need help with math problems."), + Message.assistant("Sure, I can help with that. Let's get started."), + Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"), + Message( + role="assistant", + content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]), + Message( + role="assistant", + content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})], + ), + Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]), + Message( + role="assistant", + content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})], + ), + Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]), + Message( + role="assistant", + content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})], + ), + Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]), + Message( + role="assistant", + content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]), + Message.assistant("I'm done calculating the answers to your math questions."), + Message.user("Can you also help with my science homework?"), + Message.assistant("Yes, I can help with science too."), + Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"), + Message( + role="assistant", + content=[ToolUse(id="6", name="speed_of_light", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]), + Message( + role="assistant", + content=[ToolUse(id="7", name="photon_frequency", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]), + Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]), + Message( + role="user", + content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")], + ), + Message.assistant("I'm done calculating the answers to your science questions."), + Message.user("That's great! How about history?"), + Message.assistant("Of course, I can help with history as well."), + Message.user("Thanks! You're very helpful."), + Message.assistant("You're welcome! I'm here to help."), +] + + +class AnotherMockProvider: + def __init__(self): + self.sequence = MESSAGE_SEQUENCE + self.current_index = 1 + self.summarize_next = False + self.summarized_count = 0 + + def complete(self, model, system, messages, tools): + system_prompt_tokens = 100 + input_token_count = system_prompt_tokens + + message = self.sequence[self.current_index] + if self.summarize_next: + text = "Summary message here" + self.summarize_next = False + self.summarized_count += 1 + return Message.assistant(text=text), Usage( + # in this case, input tokens can really be whatever + input_tokens=40, + output_tokens=len(text) * 2, + total_tokens=40 + len(text) * 2, + ) + + if len(messages) > 0 and type(messages[0].content[0]) is ToolResult: + raise ValueError("ToolResult should not be the first message") + + if len(messages) == 1 and messages[0].text == "a": + # adding a +1 for the "a" + return Message.assistant("Getting system prompt size"), Usage( + input_tokens=80 + 1, output_tokens=20, total_tokens=system_prompt_tokens + 1 + ) + + for i in range(len(messages)): + if type(messages[i].content[0]) in (ToolResult, ToolUse): + input_token_count += 10 + else: + input_token_count += len(messages[i].text) * 2 + + if type(message.content[0]) in (ToolResult, ToolUse): + output_tokens = 10 + else: + output_tokens = len(message.text) * 2 + + total_tokens = input_token_count + output_tokens + if total_tokens > 300: + self.summarize_next = True + usage = Usage( + input_tokens=input_token_count, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + self.current_index += 2 + return message, usage + + +@pytest.fixture +def conversation_exchange_instance(): + ex = Exchange( + provider=AnotherMockProvider(), + model="test-model", + system="test-system", + moderator=ContextSummarizer(max_tokens=300), + # TODO: make it work with an offset so we don't have to send off requests basically + # at every generate step + ) + return ex + + +def test_summarizer_generic_conversation(conversation_exchange_instance: Exchange): + i = 0 + while i < len(MESSAGE_SEQUENCE): + next_message = MESSAGE_SEQUENCE[i] + conversation_exchange_instance.add(next_message) + message = conversation_exchange_instance.generate() + if message.text != "Summary message here": + i += 2 + checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints + assert conversation_exchange_instance.checkpoint_data.total_token_count == 570 + assert len(checkpoints) == 10 + assert len(conversation_exchange_instance.messages) == 10 + assert checkpoints[0].start_index == 20 + assert checkpoints[0].end_index == 20 + assert checkpoints[-1].start_index == 29 + assert checkpoints[-1].end_index == 29 + assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20 + assert conversation_exchange_instance.provider.summarized_count == 12 + assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 diff --git a/packages/exchange/tests/test_token_usage_collector.py b/packages/exchange/tests/test_token_usage_collector.py new file mode 100644 index 000000000..d277f63e9 --- /dev/null +++ b/packages/exchange/tests/test_token_usage_collector.py @@ -0,0 +1,24 @@ +from exchange.token_usage_collector import _TokenUsageCollector + + +def test_collect(usage_factory): + usage_collector = _TokenUsageCollector() + usage_collector.collect("model1", usage_factory(100, 1000, 1100)) + usage_collector.collect("model1", usage_factory(200, 2000, 2200)) + usage_collector.collect("model2", usage_factory(400, 4000, 4400)) + usage_collector.collect("model3", usage_factory(500, 5000, 5500)) + usage_collector.collect("model3", usage_factory(600, 6000, 6600)) + assert usage_collector.get_token_usage_group_by_model() == { + "model1": usage_factory(300, 3000, 3300), + "model2": usage_factory(400, 4000, 4400), + "model3": usage_factory(1100, 11000, 12100), + } + + +def test_collect_with_non_input_or_output_token(usage_factory): + usage_collector = _TokenUsageCollector() + usage_collector.collect("model1", usage_factory(100, None, None)) + usage_collector.collect("model1", usage_factory(None, 2000, None)) + assert usage_collector.get_token_usage_group_by_model() == { + "model1": usage_factory(100, 2000, 0), + } diff --git a/packages/exchange/tests/test_tool.py b/packages/exchange/tests/test_tool.py new file mode 100644 index 000000000..847e79fb5 --- /dev/null +++ b/packages/exchange/tests/test_tool.py @@ -0,0 +1,161 @@ +import attrs +from exchange.tool import Tool + + +def get_current_weather(location: str) -> None: + """Get the current weather in a given location + + Args: + location (str): The city and state, e.g. San Francisco, CA + """ + pass + + +def test_load(): + tool = Tool.from_function(get_current_weather) + + expected = { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + "required": ["location"], + }, + "function": get_current_weather, + } + + assert attrs.asdict(tool) == expected + + +def another_function( + param1: int, + param2: str, + param3: bool, + param4: float, + param5: list[int], + param6: dict[str, float], +) -> None: + """ + This is another example function with various types + + Args: + param1 (int): Description for param1 + param2 (str): Description for param2 + param3 (bool): Description for param3 + param4 (float): Description for param4 + param5 (list[int]): Description for param5 + param6 (dict[str, float]): Description for param6 + """ + pass + + +def test_load_types(): + tool = Tool.from_function(another_function) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + "param2": {"type": "string", "description": "Description for param2"}, + "param3": {"type": "boolean", "description": "Description for param3"}, + "param4": {"type": "number", "description": "Description for param4"}, + "param5": { + "type": "array", + "items": {"type": "integer"}, + "description": "Description for param5", + }, + "param6": { + "type": "object", + "additionalProperties": {"type": "number"}, + "description": "Description for param6", + }, + }, + "required": ["param1", "param2", "param3", "param4", "param5", "param6"], + } + assert tool.parameters == expected_schema + + +def numpy_function(param1: int, param2: str) -> None: + """ + This function uses numpy style docstrings + + Parameters + ---------- + param1 : int + Description for param1 + param2 : str + Description for param2 + """ + pass + + +def test_load_numpy_style(): + tool = Tool.from_function(numpy_function) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + "param2": {"type": "string", "description": "Description for param2"}, + }, + "required": ["param1", "param2"], + } + assert tool.parameters == expected_schema + + +def sphinx_function(param1: int, param2: str, param3: bool) -> None: + """ + This function uses sphinx style docstrings + + :param param1: Description for param1 + :type param1: int + :param param2: Description for param2 + :type param2: str + :param param3: Description for param3 + :type param3: bool + """ + pass + + +def test_load_sphinx_style(): + tool = Tool.from_function(sphinx_function) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + "param2": {"type": "string", "description": "Description for param2"}, + "param3": {"type": "boolean", "description": "Description for param3"}, + }, + "required": ["param1", "param2", "param3"], + } + assert tool.parameters == expected_schema + + +class FunctionLike: + def __init__(self, state: int) -> None: + self.state = state + + def __call__(self, param1: int) -> int: + """Example + + Args: + param1 (int): Description for param1 + """ + return self.state + param1 + + +def test_load_stateful_class(): + tool = Tool.from_function(FunctionLike(1)) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + }, + "required": ["param1"], + } + assert tool.parameters == expected_schema + assert tool.function(2) == 3 diff --git a/packages/exchange/tests/test_truncate.py b/packages/exchange/tests/test_truncate.py new file mode 100644 index 000000000..3875303e7 --- /dev/null +++ b/packages/exchange/tests/test_truncate.py @@ -0,0 +1,132 @@ +import pytest +from exchange import Exchange +from exchange.content import ToolResult, ToolUse +from exchange.message import Message +from exchange.moderators.truncate import ContextTruncate +from exchange.providers import Provider, Usage + +MAX_TOKENS = 300 +SYSTEM_PROMPT_TOKENS = 100 + +MESSAGE_SEQUENCE = [ + Message.user("Hi, can you help me with my homework?"), + Message.assistant("Of course! What do you need help with?"), + Message.user("I need help with math problems."), + Message.assistant("Sure, I can help with that. Let's get started."), + Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"), + Message( + role="assistant", + content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]), + Message( + role="assistant", + content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})], + ), + Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]), + Message( + role="assistant", + content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})], + ), + Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]), + Message( + role="assistant", + content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})], + ), + Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]), + Message( + role="assistant", + content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]), + Message.assistant("I'm done calculating the answers to your math questions."), + Message.user("Can you also help with my science homework?"), + Message.assistant("Yes, I can help with science too."), + Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"), + Message( + role="assistant", + content=[ToolUse(id="6", name="speed_of_light", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]), + Message( + role="assistant", + content=[ToolUse(id="7", name="photon_frequency", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]), + Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]), + Message( + role="user", + content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")], + ), + Message.assistant("I'm done calculating the answers to your science questions."), + Message.user("That's great! How about history?"), + Message.assistant("Of course, I can help with history as well."), + Message.user("Thanks! You're very helpful."), + Message.assistant("You're welcome! I'm here to help."), +] + + +class TruncateLinearProvider(Provider): + def __init__(self): + self.sequence = MESSAGE_SEQUENCE + self.current_index = 1 + self.summarize_next = False + self.summarized_count = 0 + + def complete(self, model, system, messages, tools): + input_token_count = SYSTEM_PROMPT_TOKENS + + message = self.sequence[self.current_index] + + if len(messages) > 0 and type(messages[0].content[0]) is ToolResult: + raise ValueError("ToolResult should not be the first message") + + if len(messages) == 1 and messages[0].text == "a": + # adding a +1 for the "a" + return Message.assistant("Getting system prompt size"), Usage( + input_tokens=80 + 1, output_tokens=20, total_tokens=SYSTEM_PROMPT_TOKENS + 1 + ) + + for i in range(len(messages)): + if type(messages[i].content[0]) in (ToolResult, ToolUse): + input_token_count += 10 + else: + input_token_count += len(messages[i].text) * 2 + + if type(message.content[0]) in (ToolResult, ToolUse): + output_tokens = 10 + else: + output_tokens = len(message.text) * 2 + + total_tokens = input_token_count + output_tokens + usage = Usage( + input_tokens=input_token_count, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + self.current_index += 2 + return message, usage + + +@pytest.fixture +def conversation_exchange_instance(): + ex = Exchange( + provider=TruncateLinearProvider(), + model="test-model", + system="test-system", + moderator=ContextTruncate(max_tokens=500), + ) + return ex + + +def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchange): + i = 0 + while i < len(MESSAGE_SEQUENCE): + next_message = MESSAGE_SEQUENCE[i] + conversation_exchange_instance.add(next_message) + message = conversation_exchange_instance.generate() + if message.text != "Summary message here": + i += 2 + # ensure the total token count is not anything exhorbitant + assert conversation_exchange_instance.checkpoint_data.total_token_count < 700 + assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 diff --git a/packages/exchange/tests/test_utils.py b/packages/exchange/tests/test_utils.py new file mode 100644 index 000000000..6bc00f9e0 --- /dev/null +++ b/packages/exchange/tests/test_utils.py @@ -0,0 +1,125 @@ +import pytest +from exchange import utils +from unittest.mock import patch +from exchange.message import Message +from exchange.content import Text, ToolResult +from exchange.providers.utils import messages_to_openai_spec, encode_image + + +def test_encode_image(): + image_path = "tests/test_image.png" + encoded_image = encode_image(image_path) + + # Adjust this string based on the actual initial part of your base64-encoded image. + expected_start = "iVBORw0KGgo" + assert encoded_image.startswith(expected_start) + + +def test_create_object_id() -> None: + prefix = "test" + object_id = utils.create_object_id(prefix) + assert object_id.startswith(prefix + "_") + assert len(object_id) == len(prefix) + 1 + 24 # prefix + _ + 24 chars + + +def test_compact() -> None: + content = "This is \n\n a test" + compacted = utils.compact(content) + assert compacted == "This is a test" + + +def test_parse_docstring() -> None: + def dummy_func(a, b, c): + """ + This function does something. + + Args: + a (int): The first parameter. + b (str): The second parameter. + c (list): The third parameter. + """ + pass + + description, parameters = utils.parse_docstring(dummy_func) + assert description == "This function does something." + assert parameters == [ + {"name": "a", "annotation": "int", "description": "The first parameter."}, + {"name": "b", "annotation": "str", "description": "The second parameter."}, + {"name": "c", "annotation": "list", "description": "The third parameter."}, + ] + + +def test_parse_docstring_no_description() -> None: + def dummy_func(a, b, c): + """ + Args: + a (int): The first parameter. + b (str): The second parameter. + c (list): The third parameter. + """ + pass + + with pytest.raises(ValueError) as e: + utils.parse_docstring(dummy_func) + + assert "Attempted to load from a function" in str(e.value) + + +def test_json_schema() -> None: + def dummy_func(a: int, b: str, c: list) -> None: + pass + + schema = utils.json_schema(dummy_func) + + assert schema == { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"}, + "c": {"type": "string"}, + }, + "required": ["a", "b", "c"], + } + + +def test_load_plugins() -> None: + class DummyEntryPoint: + def __init__(self, name, plugin): + self.name = name + self.plugin = plugin + + def load(self): + return self.plugin + + with patch("exchange.utils.entry_points") as entry_points_mock: + entry_points_mock.return_value = [ + DummyEntryPoint("plugin1", object()), + DummyEntryPoint("plugin2", object()), + ] + + plugins = utils.load_plugins("dummy_group") + + assert "plugin1" in plugins + assert "plugin2" in plugins + assert len(plugins) == 2 + + +def test_messages_to_openai_spec(): + # Use provided test image + png_path = "tests/test_image.png" + + # Create a list of messages as input + messages = [ + Message(role="user", content=[Text(text="Hello, Assistant!")]), + Message(role="assistant", content=[Text(text="Here is a text with tool usage")]), + Message( + role="tool", + content=[ToolResult(tool_use_id="1", output=f'"image:{png_path}')], + ), + ] + + # Call the function + output = messages_to_openai_spec(messages) + + assert "This tool result included an image that is uploaded in the next message." in str(output) + assert "{'role': 'user', 'content': [{'type': 'image_url'" in str(output) diff --git a/pyproject.toml b/pyproject.toml index 0e55a9a6c..4db2570f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,10 +5,10 @@ version = "0.9.3" readme = "README.md" requires-python = ">=3.10" dependencies = [ + "ai-exchange", "attrs>=23.2.0", "rich>=13.7.1", "ruamel-yaml>=0.18.6", - "ai-exchange>=0.9.3", "click>=8.1.7", "prompt-toolkit>=3.0.47", ] @@ -53,7 +53,6 @@ dev-dependencies = [ "mkdocs-gen-files>=0.5.0", "mkdocs-git-authors-plugin>=0.9.0", "mkdocs-git-committers-plugin>=0.2.3", - "mkdocs-git-revision-date-localized-plugin", "mkdocs-git-revision-date-localized-plugin>=1.2.9", "mkdocs-glightbox>=0.4.0", "mkdocs-include-markdown-plugin>=6.2.2", @@ -66,3 +65,9 @@ dev-dependencies = [ "pytest-mock>=3.14.0", "pytest>=8.3.2" ] + +[tool.uv.sources] +ai-exchange = { workspace = true } + +[tool.uv.workspace] +members = ["packages/*"]