diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0fb1701e0..dbe5ab6dd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -5,7 +5,7 @@ on: branches: [main] jobs: - build: + exchange: runs-on: ubuntu-latest steps: @@ -19,9 +19,82 @@ jobs: - name: Ruff run: | - uvx ruff check - uvx ruff format --check + uvx ruff check packages/exchange + uvx ruff format packages/exchange --check - name: Run tests + working-directory: ./packages/exchange run: | uv run pytest tests -m 'not integration' + + goose: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Install UV + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Source Cargo Environment + run: source $HOME/.cargo/env + + - name: Ruff + run: | + uvx ruff check src tests + uvx ruff format src tests --check + + - name: Run tests + run: | + uv run pytest tests -m 'not integration' + + + # This runs integration tests of the OpenAI API, using Ollama to host models. + # This lets us test PRs from forks which can't access secrets like API keys. + ollama: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: + # Only test the lastest python version. + - "3.12" + ollama-model: + # For quicker CI, use a smaller, tool-capable model than the default. + - "qwen2.5:0.5b" + + steps: + - uses: actions/checkout@v4 + + - name: Install UV + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Source Cargo Environment + run: source $HOME/.cargo/env + + - name: Set up Python + run: uv python install ${{ matrix.python-version }} + + - name: Install Ollama + run: curl -fsSL https://ollama.com/install.sh | sh + + - name: Start Ollama + run: | + # Run the background, in a way that survives to the next step + nohup ollama serve > ollama.log 2>&1 & + + # Block using the ready endpoint + time curl --retry 5 --retry-connrefused --retry-delay 1 -sf http://localhost:11434 + + # Tests use OpenAI which does not have a mechanism to pull models. Run a + # simple prompt to (pull and) test the model first. + - name: Test Ollama model + run: ollama run $OLLAMA_MODEL hello || cat ollama.log + env: + OLLAMA_MODEL: ${{ matrix.ollama-model }} + + - name: Run Ollama tests + run: uv run pytest tests -m integration -k ollama + working-directory: ./packages/exchange + env: + OLLAMA_MODEL: ${{ matrix.ollama-model }} diff --git a/.github/workflows/deploy_docs.yaml b/.github/workflows/deploy_docs.yaml new file mode 100644 index 000000000..06cdfb82b --- /dev/null +++ b/.github/workflows/deploy_docs.yaml @@ -0,0 +1,25 @@ +name: Deploy MkDocs + +on: + push: + branches: + - main # Trigger deployment on pushes to main + + paths: + - 'docs/**' + - 'mkdocs.yml' + - '.github/workflows/deploy_docs.yaml' + +jobs: + deploy: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install UV + uses: astral-sh/setup-uv@v3 + + - name: Build the documentation + run: uv run mkdocs gh-deploy --force diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 000000000..969ebb7e7 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,50 @@ +name: Publish + +# A release on goose will also publish exchange, if it has updated +# This means in some cases we may need to make a bump in goose without other changes to release exchange +on: + release: + types: [published] + +jobs: + publish: + permissions: + id-token: write + contents: read + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Get current version from pyproject.toml + id: get_version + run: | + echo "VERSION=$(grep -m 1 'version =' "pyproject.toml" | awk -F'"' '{print $2}')" >> $GITHUB_ENV + + - name: Extract tag version + id: extract_tag + run: | + TAG_VERSION=$(echo "${{ github.event.release.tag_name }}" | sed -E 's/v(.*)/\1/') + echo "TAG_VERSION=$TAG_VERSION" >> $GITHUB_ENV + + - name: Check if tag matches version from pyproject.toml + id: check_tag + run: | + if [ "${{ env.TAG_VERSION }}" != "${{ env.VERSION }}" ]; then + echo "::error::Tag version (${{ env.TAG_VERSION }}) does not match version in pyproject.toml (${{ env.VERSION }})." + exit 1 + fi + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v1 + with: + version: "latest" + + - name: Build Package + run: | + uv build -o dist --package goose-ai + uv build -o dist --package ai-exchange + + - name: Publish package to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + skip-existing: true diff --git a/.github/workflows/pypi_release.yaml b/.github/workflows/pypi_release.yaml deleted file mode 100644 index 98758fb96..000000000 --- a/.github/workflows/pypi_release.yaml +++ /dev/null @@ -1,47 +0,0 @@ -name: PYPI Release - -on: - push: - tags: - - 'v*' - -jobs: - pypi_release: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Install UV - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - - name: Source Cargo Environment - run: source $HOME/.cargo/env - - - name: Build with UV - run: uvx --from build pyproject-build --installer uv - - - name: Check version - id: check_version - run: | - PACKAGE_NAME=$(grep '^name =' pyproject.toml | sed -E 's/name = "(.*)"/\1/') - TAG_VERSION=$(echo "$GITHUB_REF" | sed -E 's/refs\/tags\/v(.+)/\1/') - CURRENT_VERSION=$(curl -s https://pypi.org/pypi/$PACKAGE_NAME/json | jq -r .info.version) - PROJECT_VERSION=$(grep '^version =' pyproject.toml | sed -E 's/version = "(.*)"/\1/') - if [ "$TAG_VERSION" != "$PROJECT_VERSION" ]; then - echo "Tag version does not match version in pyproject.toml" - exit 1 - fi - if python -c "from packaging.version import parse as parse_version; exit(0 if parse_version('$TAG_VERSION') > parse_version('$CURRENT_VERSION') else 1)"; then - echo "new_version=true" >> $GITHUB_OUTPUT - else - exit 1 - fi - - - name: Publish - uses: pypa/gh-action-pypi-publish@v1.4.2 - if: steps.check_version.outputs.new_version == 'true' - with: - user: __token__ - password: ${{ secrets.PYPI_TOKEN_TEMP }} - packages_dir: ./dist/ diff --git a/.gitignore b/.gitignore index 8733789a9..f799b7221 100644 --- a/.gitignore +++ b/.gitignore @@ -20,12 +20,10 @@ parts/ sdist/ var/ wheels/ -pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg -MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -49,7 +47,6 @@ coverage.xml *.cover *.py,cover .hypothesis/ -.pytest_cache/ # Translations *.mo @@ -58,8 +55,6 @@ coverage.xml # Django stuff: *.log local_settings.py -db.sqlite3 -db.sqlite3-journal # Flask stuff: instance/ @@ -68,7 +63,11 @@ instance/ # Scrapy stuff: .scrapy +# Sphinx documentation +docs/_build/ + # PyBuilder +.pybuilder/ target/ # Jupyter Notebook @@ -88,7 +87,8 @@ ipython_config.py # install all needed dependencies. #Pipfile.lock -# PEP 582; used by e.g. github.com/David-OConnor/pyflow + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff @@ -100,12 +100,8 @@ celerybeat.pid # Environments .env +.env.* .venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ # Spyder project settings .spyderproject @@ -114,21 +110,12 @@ venv.bak/ # Rope project settings .ropeproject -# Ignore mkdocs site files generated locally for testing/validation, but generated -# at buildtime in production -site/ -docs/docs/notebooks* +# mkdocs documentation +/site # mypy .mypy_cache/ .dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# Hermit -.hermit # VSCode .vscode @@ -136,8 +123,5 @@ dmypy.json # Autogenerated docs files docs/docs/reference -## goose session files -.goose - -# ignore lockfile +# uv lock file uv.lock diff --git a/.goosehints b/.goosehints index f445dda0e..8b6535a63 100644 --- a/.goosehints +++ b/.goosehints @@ -1,3 +1,3 @@ This is a python CLI app that uses UV. Read CONTRIBUTING.md for information on how to build and test it as needed. Some key concepts are that it is run as a command line interface, dependes on the "ai-exchange" package, and has the concept of toolkits which are ways that its behavior can be extended. Look in src/goose and tests. -Once the user has UV installed it should be able to be used effectively along with uvx to run tasks as needed \ No newline at end of file +Once the user has UV installed it should be able to be used effectively along with uvx to run tasks as needed diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 000000000..e69de29bb diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index fcb9c1b82..223a1a870 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -163,4 +163,4 @@ def example_history(self): ... ``` -[exchange]: https://github.com/squareup/exchange +[exchange]: https://github.com/block-open-source/goose/tree/main/packages/exchange diff --git a/CHANGELOG.md b/CHANGELOG.md index df7906769..18cef1199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,32 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.9.3] - 2024-09-25 + +- feat: auto save sessions before next user input (#94) +- fix: removed the diff when default profile changes (#92) +- feat: add shell-completions subcommand (#76) +- chore: update readme! (#96) +- chore: update docs again (#77) +- fix: remove overly general match for long running commands (#87) +- fix: default ollama to tested model (#88) +- fix: Resize file in screen toolkit (#81) +- fix: enhance shell() to know when it is interactive (#66) +- docs: document how to run goose fully from source from any dir (#83) +- feat: track cost and token usage in log file (#80) +- chore: add link to docs in read me (#85) +- docs: add in ollama (#82) +- chore: add just command for releasing goose (#55) +- feat: support markdown plans (#79) +- feat: add version options (#74) +- docs: fixing exchange url to public version (#67) +- docs: Update CONTRIBUTING.md (#69) +- chore: create mkdocs for goose (#70) +- docs: fix broken link (#71) +- feat: give commands the ability to execute logic (#63) +- feat: jira toolkit (#59) +- feat: run goose in a docker-style sandbox (#44) + ## [0.9.0] - 2024-09-10 This also updates the minimum version of exchange to 0.9.0. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c5cb699dc..783a010a6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,141 +1,83 @@ # Contributing -We welcome Pull Requests for general contributions. If you have a larger new feature or any questions on how -to develop a fix, we recommend you open an issue before starting. +We welcome Pull Requests for general contributions. If you have a larger new feature or any questions on how to develop a fix, we recommend you open an [issue][issues] before starting. ## Prerequisites -We provide a shortcut to standard commands using [just][just] in our `justfile`. +Goose uses [uv][uv] for dependency management, and formats with [ruff][ruff]. +Clone goose and make sure you have installed `uv` to get started. When you use +`uv` below in your local goose directly, it will automatically setup the virtualenv +and install dependencies. -* *goose* uses [uv][uv] for dependency management, and formats with [ruff][ruff] - install UV first: https://pypi.org/project/uv/ +We provide a shortcut to standard commands using [just][just] in our `justfile`. -## Developing and testing +## Development -Now that you have a local environment, you can make edits and run our tests. +Now that you have a local environment, you can make edits and run our tests! -```sh -uv run pytest tests -m "not integration" -``` +### Run Goose -or, as a shortcut, +If you've made edits and want to try them out, use -```sh -just test +``` +uv run goose session start ``` -## Running goose from source - -`uv run goose session start` - -will run a fresh goose session (can use the usual goose commands with `uv run` prefixed) - -## Running ai-exchange from source - -goose depends heavily on the https://github.com/square/exchange project, you can clone that repo and then work on both by running: - -```sh -uv add --editable - -then when you run goose with `uv run goose` it will be running it all from source. - -## Evaluations - -Given that so much of *goose* involves interactions with LLMs, our unit tests only go so far to -confirming things work as intended. - -We're currently developing a suite of evalutions, to make it easier to make improvements to *goose* more confidently. +or other `goose` commands. -In the meantime, we typically incubate any new additions that change the behavior of the *goose* -through **opt-in** plugins - `Toolkit`s, `Moderator`s, and `Provider`s. We welcome contributions of plugins -that add new capabilities to *goose*. We recommend sending in several examples of the new capabilities -in action with your pull request. +If you want to run your local changes but in another directory, you can use the path in +the virtualenv created by uv: -Additions to the [developer toolkit][developer] change the core performance, and so will need to be measured carefully. +``` +alias goosedev=`uv run which goose` +``` -## Build a Toolkit +You can then run `goosedev` from another dir and it will use your current changes. -To add a toolkit, start out with a plugin as mentioned above. In your code (which doesn't necessarily need to be -in the goose package thanks to [plugin metadata][plugin]!), create a class that derives from Toolkit. +### Run Tests -```python -import os -import platform +To run the test suite against your edges, use `pytest`: -from goose.toolkit.base import Toolkit, tool +```sh +uv run pytest tests -m "not integration" +``` +or, as a shortcut, -class Demo(Toolkit): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) +```sh +just test +``` - # Provide any additional tools as needed! - # The docstring of the tool provides instructions to the LLM, so they are important to tune - # you do not have to provide any tools, but any function decorated with @tool will be available - @tool - def authenticate(self, user: str): - """Output an authentication code for this user +## Exchange - Args: - user (str): The username to authenticate for - """ - # notifier supports any rich renderable https://rich.readthedocs.io/en/stable/introduction.html#quick-start - self.notifier.log(f"[bold red]auth: {str(hash(user))}[/]") +The lower level generation behind goose is powered by the [`exchange`][ai-exchange] package, also in this repo. - # Provide any system instructions for the model - # This can be generated dynamically, and is run at startup time - def system(self) -> str: - print("new") - return f"""**You must preceed your first message by using the authenticate tool for the current user** +Thanks to `uv` workspaces, any changes you make to `exchange` will be reflected in using your local goose. To run tests +for exchange, head to `packages/exchange` and run tests just like above - ``` - platform: {platform.system()} - cwd: {os.getcwd()} - user: {os.environ.get('USER')} - ``` - """ +```sh +uv run pytest tests -m "not integration" ``` -To make the toolkit available, add it as a plugin. For example in a pyproject.toml -``` -[project.entry-points."goose.toolkit"] -developer = "goose.toolkit.developer:Developer" -github = "goose.toolkit.github:Github" -# Add a line like this - the key becomes the name used in profiles -demo = "goose.toolkit.demo:Demo" -``` +## Evaluations -And then to setup a profile that uses it, add something to ~/.config/goose/profiles.yaml -```yaml -default: - provider: openai - processor: gpt-4o - accelerator: gpt-4o-mini - moderator: passive - toolkits: - - name: developer - requires: {} -demo: - provider: openai - processor: gpt-4o - accelerator: gpt-4o-mini - moderator: passive - toolkits: - - developer - - demo -``` +Given that so much of Goose involves interactions with LLMs, our unit tests only go so far to confirming things work as intended. -And now you can run goose with this new profile to use the new toolkit! +We're currently developing a suite of evaluations, to make it easier to make improvements to Goose more confidently. -```sh -goose session start --profile demo -``` +In the meantime, we typically incubate any new additions that change the behavior of the Goose through **opt-in** plugins - `Toolkit`s, `Moderator`s, and `Provider`s. We welcome contributions of plugins that add new capabilities to *goose*. We recommend sending in several examples of the new capabilities in action with your pull request. + +Additions to the [developer toolkit][developer] change the core performance, and so will need to be measured carefully. ## Conventional Commits This project follows the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification for PR titles. Conventional Commits make it easier to understand the history of a project and facilitate automation around versioning and changelog generation. -[developer]: src/goose/toolkit/developer.py +[issues]: https://github.com/block-open-source/goose/issues +[goose-plugins]: https://github.com/block-open-source/goose-plugins +[ai-exchange]: https://github.com/block-open-source/goose/tree/main/packages/exchange +[developer]: https://github.com/block-open-source/goose/blob/dfecf829a83021b697bf2ecc1dbdd57d31727ddd/src/goose/toolkit/developer.py [uv]: https://docs.astral.sh/uv/ [ruff]: https://docs.astral.sh/ruff/ [just]: https://github.com/casey/just -[plugin]: https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata +[adding-toolkit]: https://block-open-source.github.io/goose/configuration.html#adding-a-toolkit diff --git a/README.md b/README.md index 375723424..6f980ff76 100644 --- a/README.md +++ b/README.md @@ -1,242 +1,191 @@

-goose +Goose is your on-machine developer agent, automating engineering tasks seamlessly within your IDE or terminal

-

goose is a programming agent that runs on your machine.

+

+ Goose Drawing +

+

+ Generated by Goose from its VincentVanCode toolkit. +

- + + + + + + + + + + + Discord +

-Usage • -Configuration • -Tips • -FAQ • -Open Source +Unique features 🤖 • + Block Employees on Goose Block Emoji • +Quick start guide 🚀 • +Getting involved! 👋

-`goose` assists in solving a wide range of programming and operational tasks. It is a live virtual developer you can interact with, guide, and learn from. +> [!TIP] +> **Quick install:** +> ``` +> pipx install goose-ai +> ``` -To solve problems, `goose` breaks down instructions into sequences of tasks and carries them out using tools. Its ability to connect its changes with real outcomes (e.g. errors) and course correct is its most powerful and exciting feature. `goose` is free open source software and is built to be extensible and customizable. +**Goose** is a developer agent that supercharges your software development by automating an array of coding tasks directly within your terminal or IDE. Guided by you, it can intelligently assess your project's needs, generate the required code or modifications, and implement these changes on its own. Goose can **interact with a multitude of tools via external APIs** such as Jira, GitHub, Slack, infrastructure and data pipelines, and more -- if your task uses a **shell command or can be carried out by a Python script, Goose can do it for you too!** Like semi-autonomous driving, Goose handles the heavy lifting, allowing you to focus on other priorities. Simply set it on a task and return later to find it completed, boosting your productivity with less manual effort. -![goose_demo](https://github.com/user-attachments/assets/0794eaba-97ab-40ef-af64-6fc7f68eb8e2) +

+

-## Usage -### Installation +## Unique features of Goose compared to other AI assistants -To install `goose`, we recommend `pipx` +- **Autonomy**: A copilot should be able to also fly the plane at times, which in the development world means running code, debugging tests, installing dependencies, not just providing text output and autocomplete or search. Goose moves beyond just generating code snippets by (1) **using the shell** and (2) by seeing what happens with the code it writes and starting a feedback loop to solve harder problems, **refining solutions iteratively like a human developer**. Your code's best wingman. -First make sure you've [installed pipx][pipx] - for example +- **Extensibility**: Open-source and fully customizable, Goose integrates with your workflow and allows you to extend it for even more control. **Toolkits let you add new capabilities to Goose.** They are anything you can implement as a Python function (e.g. API requests, deployments, search, etc). We have a growing library of toolkits to use, but more importantly you can create your own. This gives Goose the ability to run these commands and decide if and when a tool is needed to complete your request! **Creating your own toolkits give you a way to bring your own private context into Goose's capabilities.** And you can use *any* LLM you want under the hood, as long as it supports tool use. -``` sh -brew install pipx -pipx ensurepath -``` +## What Block employees have to say about Goose -Then you can install `goose` with +> With Goose, I feel like I am Maverick. +> +> Thanks a ton for creating this. 🙏 +> I have been having way too much fun with it today. -```sh -pipx install goose-ai -``` -#### IDEs -There is an early version of a VS Code extension with goose support you can try here: https://github.com/square/goose-vscode - more to come soon. +-- P, Machine Learning Engineer -### LLM provider access setup -`goose` works on top of LLMs (you need to bring your own LLM). By default, `goose` uses `openai` as LLM provider. You need to set OPENAI_API_KEY as an environment variable if you would like to use `openai`. -```sh -export OPENAI_API_KEY=your_open_api_key -``` -Otherwise, please refer Configuration to customise `goose` +> I wanted to construct some fake data for an API with a large request body and business rules I haven't memorized. So I told Goose which object to update and a test to run that calls the vendor. Got it to use the errors descriptions from the vendor response to keep correcting the request until it was successful. So good! -### Start `goose` session -From your terminal, navigate to the directory you'd like to start from and run: -```sh -goose session start -``` +-- J, Software Engineer -You will see a prompt `G❯`: -``` -G❯ type your instructions here exactly as you would tell a developer. -``` -Now you are interact with `goose` in conversational sessions - something like a natural language driven code interpreter. -The default toolkit lets it take actions through shell commands and file edits. -You can interrupt `goose` at any time to help redirect its efforts. +> I asked Goose to write up a few Google Scripts that mimic Clockwise's functionality (particularly, creating blocks on my work calendar based on events in my personal calendar, as well as color-coding calendar entries based on type and importance). Took me under an hour. If you haven't tried Goose yet, I highly encourage you to do so! -### Exit `goose` session -If you are looking to exit, use `CTRL+D`, although `goose` should help you figure that out if you forget. See below for some examples. +-- M, Software Engineer -### Resume `goose` session -When you exit a session, it will save the history in `~/.config/goose/sessions` directory and you can resume it later on: +> If anyone was looking for another reason to check it out: I just asked Goose to break a string-array into individual string resources across eleven localizations, and it performed amazingly well and saved me a bunch of time doing it manually or figuring out some way to semi-automate it. -``` sh -goose session resume -``` +-- A, Android Engineer -## Configuration -`goose` can detect what LLM and toolkits it can work with from the configuration file `~/.config/goose/profiles.yaml` automatically. +> Hi team, thank you for much for making Goose, it's so amazing. Our team is working on migrating Dashboard components to React components. I am working with Goose to help the migration. -### Configuration options -Example: +-- K, Software Engineer -```yaml -default: - provider: openai - processor: gpt-4o - accelerator: gpt-4o-mini - moderator: truncate - toolkits: - - name: developer - requires: {} - - name: screen - requires: {} -``` -You can edit this configuration file to use different LLMs and toolkits in `goose`. `goose can also be extended to support any LLM or combination of LLMs +> Got Goose to update a dependency, run tests, make a branch and a commit... it was 🤌. Not that complicated but I was impressed it figured out how to run tests from the README. + +-- J, Software Engineer -#### provider -Provider of LLM. LLM providers that currently are supported by `goose`: -| Provider | Required environment variable(s) to access provider | -| :----- | :------------------------------ | -| openai | `OPENAI_API_KEY` | -| anthropic | `ANTHROPIC_API_KEY` | -| databricks | `DATABRICKS_HOST` and `DATABRICKS_TOKEN` | +> Wanted to document what I had Goose do -- took about 30 minutes end to end! I created a custom CLI command in the `gh CLI` library to download in-line comments on PRs about code changes (currently they aren't directly viewable). I don't know Go *that well* and I definitely didn't know where to start looking in the code base or how to even test the new command was working and Goose did it all for me 😁 +-- L, Software Engineer -#### processor -Model for complex, multi-step tasks such as writing code and executing commands. Example: `gpt-4o`. You should choose the model based the provider you configured. -#### accelerator -Small model for fast, lightweight tasks. Example: `gpt-4o-mini`. You should choose the model based the provider you configured. +> Hi Team, just wanted to share my experience of using Goose as a non-engineer! ... I just asked Goose to ensure that my environment is up to date and copied over a guide into my prompt. Goose managed everything flawlessly, keeping me informed at every step... I was truly impressed with how well it works and how easy it was to get started! 😍 -#### moderator -Rules designed to control or manage the output of the model. Moderators that currently are supported by `goose`: +-- M, Product Manager -- `passive`: does not actively intervene in every response -- `truncate`: truncates the first contexts when the contexts exceed the max token size +**See more of our use-cases in our [docs][use-cases]!** -#### toolkits +## Quick start guide -`goose` can be extended with toolkits, and out of the box there are some available: +### Installation -* `developer`: for general-purpose development capabilities, including plan management, shell execution, and file operations, with default shell strategies like using ripgrep. -* `screen`: for letting goose take a look at your screen to help debug or work on designs (gives goose eyes) -* `github`: for awareness and suggestions on how to use github -* `repo_context`: for summarizing and understanding a repository you are working in. -* `jira`: for working with JIRA (issues, backlogs, tasks, bugs etc) +To install Goose, use `pipx`. First ensure [pipx][pipx] is installed: +``` sh +brew install pipx +pipx ensurepath +``` +You can also place `.goosehints` in `~/.config/goose/.goosehints` if you like for always loaded hints personal to you. -#### Configuring goose per repo +Then install Goose: -If you are using the `developer` toolkit, `goose` adds the content from `.goosehints` - file in working directory to the system prompt of the `developer` toolkit. The hints -file is meant to provide additional context about your project. The context can be -user-specific or at the project level in which case, you -can commit it to git. `.goosehints` file is Jinja templated so you could have something -like this: +```sh +pipx install goose-ai ``` -Here is an overview of how to contribute: -{% include 'CONTRIBUTING.md' %} -The following justfile shows our common commands: -```just -{% include 'justfile' %} +### Running Goose + +#### Start a session + +From your terminal, navigate to the directory you'd like to start from and run: + +```sh +goose session start ``` -### Examples -#### provider as `anthropic` +You will see the Goose prompt `G❯`: -```yaml -default: - provider: anthropic - processor: claude-3-5-sonnet-20240620 - accelerator: claude-3-5-sonnet-20240620 -... ``` -#### provider as `databricks` -```yaml -default: - provider: databricks - processor: databricks-meta-llama-3-1-70b-instruct - accelerator: databricks-meta-llama-3-1-70b-instruct - moderator: passive - toolkits: - - name: developer - requires: {} +G❯ type your instructions here exactly as you would tell a developer. ``` -## Tips - -Here are some collected tips we have for working efficiently with `goose` +Now you are interacting with Goose in conversational sessions - something like a natural language driven code interpreter. The default toolkit allows Goose to take actions through shell commands and file edits. You can interrupt Goose with `CTRL+D` or `ESC+Enter` at any time to help redirect its efforts. -- **`goose` can and will edit files**. Use a git strategy to avoid losing anything - such as staging your - personal edits and leaving `goose` edits unstaged until reviewed. Or consider using individual commits which can be reverted. -- **`goose` can and will run commands**. You can ask it to check with you first if you are concerned. It will check commands for safety as well. -- You can interrupt `goose` with `CTRL+C` to correct it or give it more info. -- `goose` works best when solving concrete problems - experiment with how far you need to break that problem - down to get `goose` to solve it. Be specific! E.g. it will likely fail to `"create a banking app"`, - but probably does a good job if prompted with `"create a Fastapi app with an endpoint for deposit and withdrawal - and with account balances stored in mysql keyed by id"` -- If `goose` doesn't have enough context to start with, it might go down the wrong direction. Tell it - to read files that you are referring to or search for objects in code. Even better, ask it to summarize - them for you, which will help it set up its own next steps. -- Refer to any objects in files with something that is easy to search for, such as `"the MyExample class" -- `goose` *loves* to know how to run tests to get a feedback loop going, just like you do. If you tell it how you test things locally and quickly, it can make use of that when working on your project -- You can use `goose` for tasks that would require scripting at times, even looking at your screen and correcting designs/helping you fix bugs, try asking it to help you in a way you would ask a person. -- `goose` will make mistakes, and go in the wrong direction from times, feel free to correct it, or start again. -- You can tell `goose` to run things for you continuously (and it will iterate, try, retry) but you can also tell it to check with you before doing things (and then later on tell it to go off on its own and do its best to solve). -- `goose` can run anywhere, doesn't have to be in a repo, just ask it! +#### Exit the session +If you are looking to exit, use `CTRL+D`, although Goose should help you figure that out if you forget. -### Examples +#### Resume a session -Here are some examples that have been used: +When you exit a session, it will save the history in `~/.config/goose/sessions` directory and you can resume it later on: -``` -G❯ Looking at the in progress changes in this repo, help me finish off the feature. CONTRIBUTING.md shows how to run the tests. +``` sh +goose session resume ``` -``` -G❯ In this golang project, I want you to add open telemetry to help me get started with it. Look in the moneymovements module, run the `just test` command to check things work. -``` +To see more documentation on the CLI commands currently available to Goose check out the documentation [here][cli]. If you’d like to develop your own CLI commands for Goose, check out the [Contributing document][contributing]. -``` -G❯ This project uses an old version of jooq. Upgrade to the latest version, and ensure there are no incompatibilities by running all tests. Dependency versions are in gradle/libs.versions.toml and to run gradle, use the binary located in bin/gradle -``` +### Next steps -``` -G❯ This is a fresh checkout of a golang project. I do not have my golang environment set up. Set it up and run tests for this project, and ensure they pass. Use the zookeeper jar included in this repository rather than installing zookeeper via brew. -``` +Learn how to modify your Goose profiles.yaml file to add and remove functionality (toolkits) and providing context to get the most out of Goose in our [Getting Started Guide][getting-started]. -``` -G❯ In this repo, I want you to look at how to add a new provider for azure. -Some hints are in this github issue: https://github.com/square/exchange/issues -/4 (you can use gh cli to access it). -``` +**Want to move out of the terminal and into an IDE?** -``` -G❯ I want you to help me increase the test coverage in src/java... use mvn test to run the unit tests to check it works. -``` +We have some experimental IDE integrations for VSCode and JetBrains IDEs: +* https://github.com/square/goose-vscode +* https://github.com/Kvadratni/goose-intellij + +## Getting involved! + +There is a lot to do! If you're interested in contributing, a great place to start is picking a `good-first-issue`-labelled ticket from our [issues list][gh-issues]. More details on how to develop Goose can be found in our [Contributing Guide][contributing]. We are a friendly, collaborative group and look forward to working together![^1] -## FAQ -**Q:** Why did I get error message of "The model `gpt-4o` does not exist or you do not have access to it.` when I talked goose? +Check out and contribute to more experimental features in [Goose Plugins][goose-plugins]! -**A:** You can find out the LLM provider and models in the configuration file `~/.config/goose/profiles.yaml` here to check whether your LLM provider account has access to the models. For example, after you have made a successful payment of $5 or more (usage tier 1), you'll be able to access the GPT-4, GPT-4 Turbo, GPT-4o models via the OpenAI API. [How can I access GPT-4, GPT-4 Turbo, GPT-4o, and GPT-4o mini?](https://help.openai.com/en/articles/7102672-how-can-i-access-gpt-4-gpt-4-turbo-gpt-4o-and-gpt-4o-mini). +Let us know what you think in our [Discussions][discussions] or the [**`#goose`** channel on Discord][goose-channel]. -## Open Source +[^1]: Yes, Goose is open source and always will be. Goose is released under the ASL2.0 license meaning you are free to use it however you like. See [LICENSE.md][license] for more details. -Yes, `goose` is open source and always will be. `goose` is released under the ASL2.0 license meaning you can use it however you like. -See LICENSE.md for more details. -To run `goose` from source, please see `CONTRIBUTING.md` for instructions on how to set up your environment and you can then run `uv run `goose` session start`. +[goose-plugins]: https://github.com/block-open-source/goose-plugins [pipx]: https://github.com/pypa/pipx?tab=readme-ov-file#install-pipx +[contributing]: https://github.com/block-open-source/goose/blob/main/CONTRIBUTING.md +[license]: https://github.com/block-open-source/goose/blob/main/LICENSE + +[goose-docs]: https://block-open-source.github.io/goose/ +[toolkits]: https://block-open-source.github.io/goose/plugins/available-toolkits.html +[configuration]: https://block-open-source.github.io/goose/configuration.html +[cli]: https://block-open-source.github.io/goose/plugins/cli.html +[providers]: https://block-open-source.github.io/goose/providers.html +[use-cases]: https://block-open-source.github.io/goose/guidance/applications.html +[getting-started]: https://block-open-source.github.io/goose/guidance/getting-started.html + +[discord-invite]: https://discord.gg/7GaTvbDwga +[gh-issues]: https://github.com/block-open-source/goose/issues +[van-code]: https://github.com/block-open-source/goose-plugins/blob/de98cd6c29f8e7cd3b6ace26535f24ac57c9effa/src/goose_plugins/toolkits/artify.py +[discussions]: https://github.com/block-open-source/goose/discussions +[goose-channel]: https://discord.com/channels/1287729918100246654/1287729920319033345 diff --git a/docs/assets/bg.png b/docs/assets/bg.png new file mode 100644 index 000000000..ce4acf5ba Binary files /dev/null and b/docs/assets/bg.png differ diff --git a/docs/assets/bg2.png b/docs/assets/bg2.png new file mode 100644 index 000000000..f9227edc8 Binary files /dev/null and b/docs/assets/bg2.png differ diff --git a/docs/assets/bg3.png b/docs/assets/bg3.png new file mode 100644 index 000000000..661cb3f12 Binary files /dev/null and b/docs/assets/bg3.png differ diff --git a/docs/assets/bg4.png b/docs/assets/bg4.png new file mode 100644 index 000000000..ec4344fb8 Binary files /dev/null and b/docs/assets/bg4.png differ diff --git a/docs/assets/docs.css b/docs/assets/docs.css new file mode 100644 index 000000000..67d86ca28 --- /dev/null +++ b/docs/assets/docs.css @@ -0,0 +1,213 @@ +/*@media only screen and (min-width: 76.25em) {*/ +/* .md-main__inner {*/ +/* max-width: none;*/ +/* }*/ +/* .md-sidebar--primary {*/ +/* left: 0;*/ +/* }*/ +/* .md-sidebar--secondary {*/ +/* right: 0;*/ +/* margin-left: 0;*/ +/* -webkit-transform: none;*/ +/* transform: none;*/ +/* }*/ +/*}*/ + +body { + --md-code-fg-color: white !important; + --md-code-bg-color: rgba(0, 0, 0, .5) !important; + --shadow-color: #FF9E9E; + --shadow-color-light: white; +} + +#__mermaid_0 { + font-size: 16px; +} + +.md-typeset code { + border-radius: 5px; +} + +/* Reduce the space between the term and definition in a definition list + Reads better for flags and their documention in CLI options lists */ +.md-typeset dd { + margin-top: 0.125em; +} + +/* We want syntax highlighting in fenced codeblocks describing shell commands + because it's nice to see comments dimmed and quoted strings highlighted. */ +.md-typeset .language-bash { + /* We don't need to syntax-highlight numbers in bash blocks */ + --md-code-hl-number-color: var(--md-code-fg-color); + /* We don't need to syntax-highlight shell-native functions (like `cd`) in bash blocks */ + --md-code-hl-constant-color: var(--md-code-fg-color); +} + +.highlight .kc, .highlight .n { + color: rgba(255, 255, 255, 0.8); +} + +body .md-sidebar--primary .md-sidebar__scrollwrap { + border-right: 1px solid #454755; +} + +.md-container { + opacity: 1 !important; +} + +@media screen and (min-width: 76.25em) { + .md-main::before { + content: ""; + + background-size: cover !important; + background-repeat: no-repeat !important; + background-attachment: fixed !important; + background-position: center !important; + + position: absolute; + z-index: -99999; + top: 0; + right: 0; + bottom: 0; + left: 0; + opacity: .06; + + animation: changeBg 15s infinite ease-in-out; + + -webkit-transition: background 15s linear; + -moz-transition: background 15s linear; + -o-transition: background 15s linear; + -ms-transition: background 15s linear; + transition: background 15s linear; + + animation-duration: 15s; + animation-iteration-count: infinite; + animation-direction: alternate; + } +} + +@keyframes changeBg { + 0% { + background-image: var(--bg1); + } + 25% { + background-image: var(--bg2); + } + 50% { + background-image: var(--bg3); + } + 75% { + background-image: var(--bg4); + } + 100% { + background-image: var(--bg1); + } +} + +.md-header { + background-color: rgba(14, 20, 24, 0.9) !important; +} + +.md-tabs { + background-color: rgba(14, 20, 24, 0.6) !important; +} + +@media screen and (min-width: 76.25em) { + .md-nav--lifted > .md-nav__list > .md-nav__item--active > .md-nav__link { + background: none; + box-shadow: none; + } +} + +.md-nav__toggle.md-toggle--indeterminate~.md-nav, .md-nav__toggle:checked~.md-nav, .md-nav__toggle~.md-nav { + -webkit-transition-property: none; + -moz-transition-property: none; + -o-transition-property: none; + transition-property: none; +} + +@media screen and (min-width: 60em) { + .md-nav--secondary .md-nav__title { + background: none; + box-shadow: none; + } +} + +/*add a subtle breathing effect to admonitions border*/ +.admonition { + animation: pulsate 10s infinite; + border-radius: 7px; +} + +@keyframes pulsate { + 0% { + -webkit-box-shadow: inset 0 0 .075rem rgb(138, 163, 255); + -moz-box-shadow: inset 0 0 .075rem rgb(138, 163, 255); + box-shadow: inset 0 0 .075rem rgb(138, 163, 255); + } + 50% { + border-color: rgba(255, 255, 255, .5); + -webkit-box-shadow: inset 0 0 .075rem rgba(255, 255, 255, .5); + -moz-box-shadow: inset 0 0 .075rem rgba(255, 255, 255, .5); + box-shadow: inset 0 0 .075rem rgba(255, 255, 255, .5); + } + 100% { + -webkit-box-shadow: inset 0 0 .075rem rgb(138, 163, 255); + -moz-box-shadow: inset 0 0 .075rem rgb(138, 163, 255); + box-shadow: inset 0 0 .075rem rgb(138, 163, 255); + } +} + +/*pop code elements a tad*/ +code { + border: .075rem solid rgba(0, 0, 0, .3); +} + +img { + border-radius: 10px; +} + +.neon { + color: white; + animation: neon 3s infinite; + margin: calc(50vh - 40px) auto 0 auto; + font-size: 25px; + text-transform: uppercase; + font-family: "Archivo Black", "Archivo", sans-serif; + font-weight: normal; + display: block; + height: auto; + text-align: center; +} + +@keyframes neon { + 0% { + text-shadow: -1px -1px 1px var(--shadow-color-light), -1px 1px 1px var(--shadow-color-light), 1px -1px 1px var(--shadow-color-light), 1px 1px 1px var(--shadow-color-light), + 0 0 3px var(--shadow-color-light), 0 0 10px var(--shadow-color-light), 0 0 20px var(--shadow-color-light), + 0 0 30px var(--shadow-color), 0 0 20px var(--shadow-color), 0 0 25px var(--shadow-color), 0 0 35px var(--shadow-color), 0 0 25px var(--shadow-color), 0 0 25px var(--shadow-color); + } + 50% { + text-shadow: -1px -1px 1px var(--shadow-color-light), -1px 1px 1px var(--shadow-color-light), 1px -1px 1px var(--shadow-color-light), 1px 1px 1px var(--shadow-color-light), + 0 0 5px var(--shadow-color-light), 0 0 15px var(--shadow-color-light), 0 0 25px var(--shadow-color-light), + 0 0 40px var(--shadow-color), 0 0 25px var(--shadow-color), 0 0 30px var(--shadow-color), 0 0 40px var(--shadow-color), 0 0 30px var(--shadow-color), 0 0 30px var(--shadow-color); + } + 100% { + text-shadow: -1px -1px 1px var(--shadow-color-light), -1px 1px 1px var(--shadow-color-light), 1px -1px 1px var(--shadow-color-light), 1px 1px 1px var(--shadow-color-light), + 0 0 3px var(--shadow-color-light), 0 0 10px var(--shadow-color-light), 0 0 20px var(--shadow-color-light), + 0 0 30px var(--shadow-color), 0 0 20px var(--shadow-color), 0 0 25px var(--shadow-color), 0 0 35px var(--shadow-color), 0 0 25px var(--shadow-color), 0 0 25px var(--shadow-color); + } +} + +.md-nav__item--section>.md-nav__link[for] { + color: white; +} + +/* this is the top nav item side left */ +.md-nav--lifted>.md-nav__list>.md-nav__item>[for] { + color: white; + position: absolute; /* otherwise scroll overflow does not look great */ +} + +.md-nav__link--active { + color: white !important; +} \ No newline at end of file diff --git a/docs/assets/docs.js b/docs/assets/docs.js new file mode 100644 index 000000000..b99200c65 --- /dev/null +++ b/docs/assets/docs.js @@ -0,0 +1,261 @@ +const backgrounds = [ + "/assets/bg.png", + "/assets/bg2.png", + "/assets/bg3.png", + "/assets/bg4.png", +]; + +// this is to preload the images so the transition is smooth. +// otherwise, on transition image will flicker without smooth transition. + +var hiddenContainer = document.createElement('div'); +hiddenContainer.style.display = 'none'; +document.body.appendChild(hiddenContainer); + +let index = 1; +for (let bg of backgrounds) { + let img = []; + img[index] = new Image(); + img[index].src = bg; + + hiddenContainer.appendChild(img[index]); + index++; +} + +function shuffleBackgrounds() { + let images = []; // preload + let index = 1; + for (let bg of shuffle(backgrounds)) { + document.body.style.setProperty("--bg" + index, "url(" + bg + ")"); + index++; + } +} + +function shuffle(array) { + let currentIndex = array.length, randomIndex; + + // While there remain elements to shuffle. + while (currentIndex !== 0) { + + // Pick a remaining element. + randomIndex = Math.floor(Math.random() * currentIndex); + currentIndex--; + + // And swap it with the current element. + [array[currentIndex], array[randomIndex]] = [ + array[randomIndex], array[currentIndex]]; + } + + return array; +} + +shuffleBackgrounds(); + +window.onload = function () { + onLoad(); +}; + +let origOpen = XMLHttpRequest.prototype.open; +XMLHttpRequest.prototype.open = function () { + let url = arguments[1]; // The second argument is the URL + + this.addEventListener('loadend', function (e) { + if (url.includes("https://codeserver.sq.dev/api/v1/health")) { + return; + } + + console.log(url); + + // make sure we have a full render before calling onLoad + setTimeout(() => { + onLoad(); + }, 100); + }); + origOpen.apply(this, arguments); +}; + +let healthCheckTimer = null; + +function onLoad() { + + console.log("onLoad"); + let interactiveCodePage = false; + document.querySelectorAll("a").forEach((e) => { + if (e.innerText === "Run Code") { + interactiveCodePage = true; + } + }); + + if (interactiveCodePage) { + document.querySelectorAll("code").forEach((e) => { + e.style.maxHeight = "40vh"; + }); + } + + if (healthCheckTimer) { + console.log("clearing health check timer"); + clearInterval(healthCheckTimer); + } + + healthCheckTimer = setInterval(() => { + // if the tab is not visible, don't check + if (document.hidden) { + return; + } + + // check if https://codeserver.sq.dev/api/v1/health is up + const xhr = new XMLHttpRequest(); + xhr.open("GET", "https://codeserver.sq.dev/api/v1/health", true); + xhr.send(); + xhr.timeout = 1000; + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + document.querySelectorAll("a").forEach((e) => { + if (e.innerText === "Code Server is down") { + e.innerText = "Run Code"; + e.onclick = function () { + }; + } + }); + } else { + document.querySelectorAll("a").forEach((e) => { + if (e.innerText === "Run Code") { + e.innerText = "Code Server is down"; + e.onclick = function () { + alert("Code Server is down.\n\nRun\n\nsq dev up codeserver\n\nto start the code server in your local development environment."); + return false; + }; + } + }); + } + } + }; + }, 1000); +} + +let codeServerRequestLoading = false; + +// register button listener +// this is for code running examples +document.addEventListener("click", function (e) { + if (e.target.innerText !== "Run Code") { + return; + } + + if (codeServerRequestLoading) { + alert("Please wait for the previous request to finish."); + return; + } + +// console.log("e", e.target); +// console.log("parent", e.target.parentElement); +// console.log("parent pu", getPreviousUntil(e.target.parentElement, '.tabbed-block')); +// console.log("parent > 1", e.target.parentElement.previousElementSibling); +// console.log("parent > 2", e.target.parentElement.previousElementSibling.previousElementSibling); +// console.log("parent > tabbed-block", e.target.parentElement.previousElementSibling.previousElementSibling); +// console.log("parent > tabbed-block", e.target.parentElement.previousElementSibling.previousElementSibling.querySelectorAll('.tabbed-block')); + + // const codeClass = e.target.parentElement.previousElementSibling.previousElementSibling.querySelectorAll('.tabbed-block'); + const codeClass = getPreviousUntil(e.target.parentElement, '.tabbed-content')[0].querySelectorAll('.tabbed-block'); + +// console.log("ele", codeClass); + + let language = ""; + codeClass.forEach((e) => { + console.log(window.getComputedStyle(e).display); + // this is the visible code block + if (window.getComputedStyle(e).display === "block") { + console.log(e); + const codeBlock = e.querySelector('code'); + if (codeBlock) { + language = codeBlock.closest('div').className.split(" ")[0].split("-")[1]; + + console.log("code block", codeBlock); + console.log(codeBlock.closest('div')); + } + } + }); + +// console.log("Language", language); + + document.getElementById("loader")?.remove(); + document.getElementById("output")?.remove(); + document.getElementById("output-error")?.remove(); + + const output = document.createElement("pre"); + output.id = "loader"; + output.innerHTML = "Loading..."; + e.target.parentElement.appendChild(output); + + const code = e.target.parentElement.previousSibling.previousSibling.querySelector('.language-' + CSS.escape(language)); + if (code) { + const xhr = new XMLHttpRequest(); + xhr.open("POST", "https://codeserver.sq.dev/api/v1/code/" + language + "/run", true); + xhr.setRequestHeader("Content-Type", "application/x-www-form-urlencoded"); + xhr.send(code.innerText); + + codeServerRequestLoading = true; + xhr.onreadystatechange = function () { + document.getElementById("loader")?.remove(); + + if (xhr.readyState === 4) { + + if (xhr.status === 200) { + + // code result + const output = document.createElement("div"); + output.id = "output"; + output.innerHTML = "

" + toTitleCase(language) + " Execution

" + xhr.responseText + "
"; + e.target.parentElement.appendChild(output); + output.scrollIntoView({behavior: "smooth", block: "center", inline: "nearest"}); + } else { + console.log("Error", xhr.statusText); + const output = document.createElement("pre"); + output.id = "output-error"; + output.innerHTML = "" + xhr.responseText + ""; + e.target.parentElement.appendChild(output); + output.scrollIntoView({behavior: "smooth", block: "center", inline: "nearest"}); + } + } + + codeServerRequestLoading = false; + + }; + } + + // do something +}); + +function toTitleCase(str) { + return str.replace( + /\w\S*/g, + function (txt) { + return txt.charAt(0).toUpperCase() + txt.substr(1).toLowerCase(); + } + ); +} + +const getPreviousUntil = function (elem, selector) { + + // Setup siblings array and get previous sibling + const siblings = []; + let prev = elem.previousElementSibling; + + // Loop through all siblings + while (prev) { + + // If the matching item is found, quit + if (selector && prev.matches(selector)) break; + + // Otherwise, push to array + siblings.push(prev); + + // Get the previous sibling + prev = prev.previousElementSibling; + + } + + return siblings; + +}; diff --git a/docs/assets/goose-in-action.gif b/docs/assets/goose-in-action.gif new file mode 100644 index 000000000..bd4d71831 Binary files /dev/null and b/docs/assets/goose-in-action.gif differ diff --git a/docs/assets/goose-in-action.mp4 b/docs/assets/goose-in-action.mp4 new file mode 100644 index 000000000..83f3270e8 Binary files /dev/null and b/docs/assets/goose-in-action.mp4 differ diff --git a/docs/assets/goose.png b/docs/assets/goose.png new file mode 100644 index 000000000..6c9e8e9ef Binary files /dev/null and b/docs/assets/goose.png differ diff --git a/docs/assets/logo.gif b/docs/assets/logo.gif new file mode 100644 index 000000000..90eea64ac Binary files /dev/null and b/docs/assets/logo.gif differ diff --git a/docs/assets/logo.ico b/docs/assets/logo.ico new file mode 100644 index 000000000..9e2a5c4af Binary files /dev/null and b/docs/assets/logo.ico differ diff --git a/docs/assets/logo.png b/docs/assets/logo.png new file mode 100644 index 000000000..3ee041891 Binary files /dev/null and b/docs/assets/logo.png differ diff --git a/docs/configuration.md b/docs/configuration.md new file mode 100644 index 000000000..4e7d81f26 --- /dev/null +++ b/docs/configuration.md @@ -0,0 +1,149 @@ +# Configuring Goose + +## Profiles + +If you need to customize goose, one way is via editing: `~/.config/goose/profiles.yaml`. + +It will look by default something like (and when you run `goose session start` without the `--profile` flag it will use the `default` profile): + +```yaml +default: + provider: open-ai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: passive + toolkits: + - name: developer + requires: {} +``` + +### Fields + +#### provider + +Provider of LLM. LLM providers that currently are supported by Goose: + +| Provider | Required environment variable(s) to access provider | +| ----- | ------------------------------ | +| openai | `OPENAI_API_KEY` | +| anthropic | `ANTHROPIC_API_KEY` | +| databricks | `DATABRICKS_HOST` and `DATABRICKS_TOKEN` | + +#### processor + +This is the model used for the main Goose loop and main tools -- it should be be capable of complex, multi-step tasks such as writing code and executing commands. Example: `gpt-4o`. You should choose the model based the provider you configured. + +#### accelerator + +Small model for fast, lightweight tasks. Example: `gpt-4o-mini`. You should choose the model based the provider you configured. + +#### moderator + +Rules designed to control or manage the output of the model. Moderators that currently are supported by Goose: + +- `passive`: does not actively intervene in every response +- `truncate`: truncates the first contexts when the contexts exceed the max token size + +### Example `profiles.yaml` files + +#### provider as `anthropic` + +```yaml + +default: + provider: anthropic + processor: claude-3-5-sonnet-20240620 + accelerator: claude-3-5-sonnet-20240620 +``` + +#### provider as `databricks` + +```yaml +default: + provider: databricks + processor: databricks-meta-llama-3-1-70b-instruct + accelerator: databricks-meta-llama-3-1-70b-instruct + moderator: passive + toolkits: + - name: developer + requires: {} +``` + +You can tell it to use another provider for example for Anthropic: + +```yaml +default: + provider: anthropic + processor: claude-3-5-sonnet-20240620 + accelerator: claude-3-5-sonnet-20240620 + moderator: passive + toolkits: + - name: developer + requires: {} +``` + +this will then use the claude-sonnet model, you will need to set the `ANTHROPIC_API_KEY` to your anthropic API key. + +You can also customize Goose's behavior through toolkits. These are set up automatically for you in the same `~/.config/goose/profiles.yaml` file, but you can include or remove toolkits as you see fit. + +For example, Goose's `unit-test-gen` command sets up a new profile in this file for you: + +```yaml +unit-test-gen: + provider: openai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: passive + toolkits: + - name: developer + requires: {} + - name: unit-test-gen + requires: {} + - name: java + requires: {} +``` + +[jinja-guide]: https://jinja.palletsprojects.com/en/3.1.x/ + + +## Adding a toolkit +To make a toolkit available to Goose, add it to your project's pyproject.toml. For example in the Goose pyproject.toml file: +``` +[project.entry-points."goose.toolkit"] +developer = "goose.toolkit.developer:Developer" +github = "goose.toolkit.github:Github" +# Add a line like this - the key becomes the name used in profiles +my-new-toolkit = "goose.toolkit.my_toolkits:MyNewToolkit" # this is the path to the class that implements the toolkit +``` + +Then to set up a profile that uses it, add something to `~/.config/goose/profiles.yaml`: +```yaml +my-profile: + provider: openai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: passive + toolkits: # new toolkit gets added here + - developer + - my-new-toolkit +``` + +And now you can run Goose with this new profile to use the new toolkit! + +```sh +goose session start --profile my-profile +``` + +Or, if you're developing a new toolkit and want to test it: +```sh +uv run goose session start --profile my-profile +``` + +## Tuning Goose to your repo + +Goose ships with the ability to read in the contents of a file named `.goosehints` from your repo. If you find yourself repeating the same information across sessions to Goose, this file is the right place to add this information. + +This file will be read into the Goose system prompt if it is present in the current working directory. + +> [!NOTE] +> `.goosehints` follows [jinja templating rules][jinja-guide] in case you want to leverage templating to insert file contents or variables. \ No newline at end of file diff --git a/docs/contributing.md b/docs/contributing.md new file mode 120000 index 000000000..44fcc6343 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1 @@ +../CONTRIBUTING.md \ No newline at end of file diff --git a/docs/css/code_select.css b/docs/css/code_select.css new file mode 100644 index 000000000..856995da0 --- /dev/null +++ b/docs/css/code_select.css @@ -0,0 +1,5 @@ +.highlight .gp, +.highlight .go { + /* Generic.Prompt, Generic.Output */ + user-select: none; +} \ No newline at end of file diff --git a/docs/guidance/applications.md b/docs/guidance/applications.md new file mode 100644 index 000000000..9f974d9f7 --- /dev/null +++ b/docs/guidance/applications.md @@ -0,0 +1,15 @@ +## Uses of Goose so Far + +We've been using Goose to help us with a variety of tasks. Here are some examples: + +- Conduct code migrations like: + - Ember to React + - Ruby to Kotlin + - Prefect-1 to Prefect-2 +- Dive into a new project in an unfamiliar coding language +- Transition a code-base from field-based injection to constructor-based injection in a dependency injection framework +- Conduct performance benchmarks for a build command using a build automation tool +- Increasing code coverage above a specific threshold +- Scaffolding an API for data retention +- Creating Datadog monitors +- Removing or adding feature flags diff --git a/docs/guidance/getting-started.md b/docs/guidance/getting-started.md new file mode 100644 index 000000000..cac0cf81f --- /dev/null +++ b/docs/guidance/getting-started.md @@ -0,0 +1,141 @@ +# Your first run with Goose + +This page contains two sections that will help you get started with Goose: + +1. [Configuring Goose with the `profiles.yaml` file](#configuring-goose-with-the-profilesyaml-file): how to set up Goose with the right LLMs and toolkits. +2. [Working with Goose](#working-with-goose): how to guide Goose through a task, and how to provide context for Goose to work with. + +## Configuring Goose with the `profiles.yaml` file +On the first run, Goose will detect what LLMs are available from your environment, and generate a configuration file at `~/.config/goose/profiles.yaml`. You can edit those profiles to further configure goose. + +Here’s what the default `profiles.yaml` could look like if Goose detects an OpenAI API key: + +```yaml +default: + provider: openai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: truncate + toolkits: + - name: developer + requires: {} +``` + +You can edit this configuration file to use different LLMs and toolkits in Goose. Check out the [configuration docs][configuration] to better understand the different fields of the `profiles.yaml` file! You can add new profiles with different settings to change how goose works from one section to the next - use `goose session start --profile {profile}` to select which to use. + +### LLM provider access setup + +Goose works on top of LLMs. You'll need to configure one before using it. By default, Goose uses `openai` as the LLM provider but you can customize it as needed. You need to set OPENAI_API_KEY as an environment variable if you would like to use `openai`. + +To learn more about providers and modes of access, check out the [provider docs][providers]. +```sh +export OPENAI_API_KEY=your_open_api_key +``` + +## Working with Goose + +Goose works best with some amount of context or instructions for a given task. You can guide goose through gathering the context it needs by giving it instructions or asking it to explore with its tools. But to make this easier, context in Goose can be extended a few additional ways: + +1. User-directed input +2. A `.goosehints` file +3. Toolkits +4. Plans + +### User-directed input + +Directing Goose to read a specific file before requesting changes ensures that the file's contents are loaded into its operational context. Similarly, asking Goose to summarize the current project before initiating changes across multiple files provides a detailed overview of the project structure, including the locations of specific classes, functions, and other components. + +### `.goosehints` + +If you are using the `developer` toolkit, `goose` adds the content from `.goosehints` file in the working directory to the system prompt. The hints file is meant to provide additional context about your project. The context can be user-specific or at the project level in which case, you can commit it to git. `.goosehints` file is Jinja templated so you could have something like this: + +``` +Here is an overview of how to contribute: +{% include 'CONTRIBUTING.md' %} + +The following justfile shows our common commands: +{% include 'justfile' %} + +Write all code comments in French +``` + +### Toolkits + +Toolkits expand Goose’s capabilities and tailor its functionality to specific development tasks. Toolkits provide Goose with additional contextual information and interactive abilities, allowing for a more comprehensive and efficient workflow. + +Here are some out-of-the-box examples: + +* `developer`: for general-purpose development capabilities, including plan management, shell execution, and file operations, with default shell strategies like using ripgrep. +* `screen`: for letting goose take a look at your screen to help debug or work on designs (gives goose eyes) +* `github`: for suggestions on how to use Github +* `repo_context`: for summarizing and understanding a repository you are working in. +* `jira`: for working with JIRA (issues, backlogs, tasks, bugs etc.) + +You can see the current toolkits available to Goose [here][available-toolkits]. There's also a [public plugins repository where toolkits are defined][goose-plugins] for Goose that has toolkits you can try out. + +### Plans + +Goose creates plans for itself to execute to achieve its goals. In some cases, you may already have a plan in mind for Goose — this is where you can define your own `plan.md` file, and it will set the first message and also hard code Goose's initial plan. + +The plan.md file can be text in any format and uses `jinja` templating, and the last group of lines that start with “-” will be considered the plan. + +Here are some examples: + +#### Basic example plan + +```md +Your goal is to refactor this fastapi application to use a sqlite database. Use `pytest -s -v -x` to run the tests when needed. + +- Use ripgrep to find the fastapi app and its tests in this directory +- read the files you found +- Add sqlalchemy and alembic as dependencies with poetry +- Run alembic init to set up the basic configuration +- Add sqlite dependency with Poetry +- Create new module for database code and include sqlalchemy and alembic setup +- Define an accounts table with SQLAlchemy +- Implement CRUD operations for accounts table +- Update main.py to integrate with SQLite database and use CRUD operation +- Use alembic to create the table +- Use conftest to set up a test database with a new DB URL +- Run existing test suite and ensure all tests pass. Do not edit the test case behavior, instead use tests to find issues. +``` + +The starting plan is specified with the tasks. Each list entry is a different step in the plan. This is a pretty detailed set of tasks, but is really just a break-down of the conversation we had in the previous section. + +The kickoff message is what gets set as the first user message when goose starts running (with the plan). This message should contain the overall goal of the tasks and could also contain extra context you want to include for this problem. In our case, we are just mentioning the test command we want to use to run the tests. + +To run Goose with this plan: + +``` sh +goose session start --plan plan.md +``` + +#### Injecting arguments into a plan + +You can also inject arguments into your plan. `plan.md` files can be templated with `jinja` and can include variables that are passed in when you start the session. + +The kickoff message gives Goose directions to use poetry and a dependency, and then a plan is to open a file, run a test, and set up a repo: + +```md +Here is the python repo + +- use {{ dep }} +- use poetry + +Here is the plan: + +- Open a file +- Run a test +- Set up {{ repo }} +``` + +To run Goose with this plan with the arguments `dep=pytest,repo=github`, you would run the following command: + +```sh +goose session start --plan plan.md --args dep=pytest,repo=github +``` + +[configuration]: ../configuration.md +[available-toolkits]: ../plugins/available-toolkits.md +[providers]: ../plugins/providers.md +[goose-plugins]:https://github.com/block-open-source/goose-plugins diff --git a/docs/guidance/goose-in-action.md b/docs/guidance/goose-in-action.md new file mode 100644 index 000000000..43def53cd --- /dev/null +++ b/docs/guidance/goose-in-action.md @@ -0,0 +1,21 @@ +# Goose in action + +This page is frequently updated with the latest use-cases and applications of Goose! + +## Goose as a Github Action + +**What it does**: + +An early version of a GitHub action that uses Goose to automatically address issues in your repository. It operates in the background to attempt fixes or enhancements based on issue descriptions. + +The action attempts to fix issues described in GitHub. It takes the issue's title and body as input and tries to resolve the issue programmatically. + +If the action successfully fixes the issue, it will automatically create a pull request with the fix. If it cannot confidently fix the issue, no pull request is created. + +**Where you can find it**: https://github.com/marketplace/actions/goose-ai-developer-agent + +**How you can do something similar**: + +1. Decide what specific task you want Goose to automate. This could be anything from auto-linting code, updating dependencies, auto-merging approved pull requests, or even automating responses to issue comments. +2. In the `action.yml`, specify any inputs your action needs (like GitHub tokens, configuration files, specific command inputs) and outputs it may produce. +3. Write the script (e.g., Python or JavaScript) that Goose will use to perform the tasks. This involves setting up the Goose environment, handling GitHub API requests, and processing the task-specific logic. diff --git a/docs/guidance/tips.md b/docs/guidance/tips.md new file mode 100644 index 000000000..edd827fae --- /dev/null +++ b/docs/guidance/tips.md @@ -0,0 +1,20 @@ +## Tips for working with Goose: + +Here are some collected tips we have for working efficiently with Goose + +- **Goose can and will edit files**. Use a git strategy to avoid losing anything - such as staging your +personal edits and leaving Goose edits unstaged until reviewed. Or consider using individual commits which can be reverted. +- **Goose can and will run commands**. You can ask it to check with you first if you are concerned. It will check commands for safety as well. +- You can interrupt Goose with `CTRL+C` to correct it or give it more info. +- Goose works best when solving concrete problems - experiment with how far you need to break that problem +down to get Goose to solve it. Be specific! E.g. it will likely fail to `"create a banking app"`, +but probably does a good job if prompted with `"create a Fastapi app with an endpoint for deposit and withdrawal and with account balances stored in mysql keyed by id"` +- If Goose doesn't have enough context to start with, it might go down the wrong direction. Tell it +to read files that you are referring to or search for objects in code. Even better, ask it to summarize +them for you, which will help it set up its own next steps. +- Refer to any objects in files with something that is easy to search for, such as `"the MyExample class" +- Goose *loves* to know how to run tests to get a feedback loop going, just like you do. If you tell it how you test things locally and quickly, it can make use of that when working on your project +- You can use Goose for tasks that would require scripting at times, even looking at your screen and correcting designs/helping you fix bugs, try asking it to help you in a way you would ask a person. +- Goose will make mistakes, and go in the wrong direction from times, feel free to correct it, or start again. +- You can tell Goose to run things for you continuously (and it will iterate, try, retry) but you can also tell it to check with you before doing things (and then later on tell it to go off on its own and do its best to solve). +- Goose can run anywhere, doesn't have to be in a repo, just ask it! \ No newline at end of file diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 000000000..ec8e7e5e5 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,185 @@ +

+Goose is your on-machine developer agent, automating engineering tasks seamlessly within your IDE or terminal +

+ +

+ Goose Drawing +

+

+ Generated by Goose from its VincentVanCode toolkit. +

+ +

+ + + + + + + + + + + Discord + +

+ +

+Unique features 🤖 • + Block Employees on Goose Block Emoji • +Quick start guide 🚀 • +Getting involved! 👋 +

+ +>[!TIP] +> **Quick install:** +> ``` +> pipx install goose-ai +> ``` + +**Goose** is a developer agent that supercharges your software development by automating an array of coding tasks directly within your terminal or IDE. Guided by you, it can intelligently assess your project's needs, generate the required code or modifications, and implement these changes on its own. Goose can **interact with a multitude of tools via external APIs** such as Jira, GitHub, Slack, infrastructure and data pipelines, and more -- if your task uses a **shell command or can be carried out by a Python script, Goose can do it for you too!** Like semi-autonomous driving, Goose handles the heavy lifting, allowing you to focus on other priorities. Simply set it on a task and return later to find it completed, boosting your productivity with less manual effort. + +## Unique features of Goose compared to other AI assistants + +- **Autonomy**: A copilot should be able to also fly the plane at times, which in the development world means running code, debugging tests, installing dependencies, not just providing text output and autocomplete or search. Goose moves beyond just generating code snippets by (1) **using the shell** and (2) by seeing what happens with the code it writes and starting a feedback loop to solve harder problems, **refining solutions iteratively like a human developer**. Your code's best wingman. + +- **Extensibility**: Open-source and fully customizable, Goose integrates with your workflow and allows you to extend it for even more control. **Toolkits let you add new capabilities to Goose.** They are anything you can implement as a Python function (e.g. API requests, deployments, search, etc). We have a growing library of toolkits to use, but more importantly you can create your own. This gives Goose the ability to run these commands and decide if and when a tool is needed to complete your request! **Creating your own toolkits give you a way to bring your own private context into Goose's capabilities.** And you can use *any* LLM you want under the hood, as long as it supports tool use. + +## What Block employees have to say about Goose + +> With Goose, I feel like I am Maverick. +> +> Thanks a ton for creating this. 🙏 +> I have been having way too much fun with it today. + +-- P, Machine Learning Engineer + + +> I wanted to construct some fake data for an API with a large request body and business rules I haven't memorized. So I told Goose which object to update and a test to run that calls the vendor. Got it to use the errors descriptions from the vendor response to keep correcting the request until it was successful. So good! + +-- J, Software Engineer + + +> I asked Goose to write up a few Google Scripts that mimic Clockwise's functionality (particularly, creating blocks on my work calendar based on events in my personal calendar, as well as color-coding calendar entries based on type and importance). Took me under an hour. If you haven't tried Goose yet, I highly encourage you to do so! + +-- M, Software Engineer + + +> If anyone was looking for another reason to check it out: I just asked Goose to break a string-array into individual string resources across eleven localizations, and it performed amazingly well and saved me a bunch of time doing it manually or figuring out some way to semi-automate it. + +-- A, Android Engineer + + +> Hi team, thank you for much for making Goose, it's so amazing. Our team is working on migrating Dashboard components to React components. I am working with Goose to help the migration. + +-- K, Software Engineer + + +> Got Goose to update a dependency, run tests, make a branch and a commit... it was 🤌. Not that complicated but I was impressed it figured out how to run tests from the README. + +-- J, Software Engineer + + +> Wanted to document what I had Goose do -- took about 30 minutes end to end! I created a custom CLI command in the `gh CLI` library to download in-line comments on PRs about code changes (currently they aren't directly viewable). I don't know Go *that well* and I definitely didn't know where to start looking in the code base or how to even test the new command was working and Goose did it all for me 😁 + +-- L, Software Engineer + + +> Hi Team, just wanted to share my experience of using Goose as a non-engineer! ... I just asked Goose to ensure that my environment is up to date and copied over a guide into my prompt. Goose managed everything flawlessly, keeping me informed at every step... I was truly impressed with how well it works and how easy it was to get started! 😍 + +-- M, Product Manager + +**See more of our use-cases in our [docs][use-cases]!** + +## Quick start guide + +### Installation + +To install Goose, use `pipx`. First ensure [pipx][pipx] is installed: + +``` sh +brew install pipx +pipx ensurepath +``` +You can also place `.goosehints` in `~/.config/goose/.goosehints` if you like for always loaded hints personal to you. + +Then install Goose: + +```sh +pipx install goose-ai +``` + +### Running Goose + +#### Start a session + +From your terminal, navigate to the directory you'd like to start from and run: + +```sh +goose session start +``` + +You will see the Goose prompt `G❯`: + +``` +G❯ type your instructions here exactly as you would tell a developer. +``` + +Now you are interacting with Goose in conversational sessions - something like a natural language driven code interpreter. The default toolkit allows Goose to take actions through shell commands and file edits. You can interrupt Goose with `CTRL+D` or `ESC+Enter` at any time to help redirect its efforts. + +#### Exit the session + +If you are looking to exit, use `CTRL+D`, although Goose should help you figure that out if you forget. + +#### Resume a session + +When you exit a session, it will save the history in `~/.config/goose/sessions` directory and you can resume it later on: + +``` sh +goose session resume +``` + +To see more documentation on the CLI commands currently available to Goose check out the documentation [here][cli]. If you’d like to develop your own CLI commands for Goose, check out the [Contributing document][contributing]. + +### Next steps + +Learn how to modify your Goose profiles.yaml file to add and remove functionality (toolkits) and providing context to get the most out of Goose in our [Getting Started Guide][getting-started]. + +**Want to move out of the terminal and into an IDE?** + +We have some experimental IDE integrations for VSCode and JetBrains IDEs: +* https://github.com/square/goose-vscode +* https://github.com/Kvadratni/goose-intellij + +## Getting involved! + +There is a lot to do! If you're interested in contributing, a great place to start is picking a `good-first-issue`-labelled ticket from our [issues list][gh-issues]. More details on how to develop Goose can be found in our [Contributing Guide][contributing]. We are a friendly, collaborative group and look forward to working together![^1] + + +Check out and contribute to more experimental features in [Goose Plugins][goose-plugins]! + +Let us know what you think in our [Discussions][discussions] or the [**`#goose`** channel on Discord][goose-channel]. + +[^1]: Yes, Goose is open source and always will be. Goose is released under the ASL2.0 license meaning you are free to use it however you like. See [LICENSE.md][license] for more details. + + + +[goose-plugins]: https://github.com/block-open-source/goose-plugins + +[pipx]: https://github.com/pypa/pipx?tab=readme-ov-file#install-pipx +[contributing]: https://github.com/block-open-source/goose/blob/main/CONTRIBUTING.md +[license]: https://github.com/block-open-source/goose/blob/main/LICENSE + +[goose-docs]: https://block-open-source.github.io/goose/ +[toolkits]: https://block-open-source.github.io/goose/plugins/available-toolkits.html +[configuration]: https://block-open-source.github.io/goose/configuration.html +[cli]: https://block-open-source.github.io/goose/plugins/cli.html +[providers]: https://block-open-source.github.io/goose/providers.html +[use-cases]: https://block-open-source.github.io/goose/guidance/applications.html +[getting-started]: https://block-open-source.github.io/goose/guidance/getting-started.html + +[discord-invite]: https://discord.gg/7GaTvbDwga +[gh-issues]: https://github.com/block-open-source/goose/issues +[van-code]: https://github.com/block-open-source/goose-plugins/blob/de98cd6c29f8e7cd3b6ace26535f24ac57c9effa/src/goose_plugins/toolkits/artify.py +[discussions]: https://github.com/block-open-source/goose/discussions +[goose-channel]: https://discord.com/channels/1287729918100246654/1287729920319033345 diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 000000000..641f0e096 --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,18 @@ +# Installation + +To install Goose, use `pipx`.First ensure [pipx][pipx] is installed: + +``` sh +brew install pipx +pipx ensurepath +``` + +Then install Goose: + +```sh +pipx install goose-ai +``` + +[pipx]: https://github.com/pypa/pipx?tab=readme-ov-file#install-pipx + +You can then run `goose` from the command line with `goose session start`. \ No newline at end of file diff --git a/docs/plugins/available-toolkits.md b/docs/plugins/available-toolkits.md new file mode 100644 index 000000000..a89bb8af8 --- /dev/null +++ b/docs/plugins/available-toolkits.md @@ -0,0 +1,60 @@ +# Available Toolkits in Goose + +Goose provides a variety of toolkits designed to help developers with different tasks. Here's an overview of each available toolkit and its functionalities: + +## 1. Developer Toolkit + +The **Developer** toolkit offers general-purpose development capabilities, including: + +- **System Configuration Details:** Retrieves system configuration details. +- **Task Management:** Update the plan by overwriting all current tasks. +- **File Operations:** + - `patch_file`: Patch a file by replacing specific content. + - `read_file`: Read the content of a specified file. + - `write_file`: Write content to a specified file. +- **Shell Command Execution:** Execute shell commands with safety checks. + +## 2. GitHub Toolkit + +The **GitHub** toolkit provides detailed configuration and procedural guidelines for GitHub operations. + +## 3. Lint Toolkit + +The **Lint** toolkit ensures that all toolkits have proper documentation. It performs the following checks: + +- Toolkit must have a docstring. +- The first line of the docstring should contain more than 5 words and fewer than 12 words. +- The first letter of the docstring should be capitalized. + +## 4. RepoContext Toolkit + +The **RepoContext** toolkit provides context about the current repository. It includes: + +- **Repository Size:** Get the size of the repository. +- **Monorepo Check:** Determine if the repository is a monorepo. +- **Project Summarization:** Summarize the current project based on the repository or the current project directory. + +## 5. Screen Toolkit + +The **Screen** toolkit assists users in taking screenshots for debugging or designing purposes. It provides: + +- **Take Screenshot:** Capture a screenshot and provide the path to the screenshot file. +- **System Instructions:** Instructions on how to work with screenshots. + +## 6. SummarizeRepo Toolkit + +The **SummarizeRepo** toolkit helps in summarizing a repository. It includes: + +- **Summarize Repository:** Clone the repository (if not already cloned) and summarize the files based on specified extensions. + +## 7. SummarizeProject Toolkit + +The **SummarizeProject** toolkit generates or retrieves a summary of a project directory based on specified file extensions. It includes: + +- **Get Project Summary:** Generate or retrieve a summary of the project in the specified directory. + +## 8. SummarizeFile Toolkit + +The **SummarizeFile** toolkit helps in summarizing a specific file. It includes: + +- **Summarize File:** Summarize the contents of a specified file with optional instructions. diff --git a/docs/plugins/cli.md b/docs/plugins/cli.md new file mode 100644 index 000000000..5d27563c9 --- /dev/null +++ b/docs/plugins/cli.md @@ -0,0 +1,63 @@ +# Goose CLI Commands + +Goose provides a command-line interface (CLI) with various commands to manage sessions, toolkits, and more. Below is a list of the available commands and their descriptions: + +## Goose CLI + +### `version` + +**Usage:** +```sh + goose version +``` + +Lists the version of Goose and any associated plugins. + +### `session` + +#### `start` + +**Usage:** +```sh + goose session start [--profile PROFILE] [--plan PLAN] +``` + +Starts a new Goose session. + +#### `resume` + +**Usage:** +```sh + goose session resume [NAME] [--profile PROFILE] +``` + +Resumes an existing Goose session. + +#### `list` + +**Usage:** +```sh + goose session list +``` + +Lists all Goose sessions. + +#### `clear` + +**Usage:** +```sh + goose session clear [--keep KEEP] +``` + +Deletes old Goose sessions, keeping the most recent ones as specified by the `--keep` option. + +### `toolkit` + +#### `list` + +**Usage:** +```sh + goose toolkit list +``` + +Lists all available toolkits with their descriptions. diff --git a/docs/plugins/creating-a-new-toolkit.md b/docs/plugins/creating-a-new-toolkit.md new file mode 100644 index 000000000..8574c125a --- /dev/null +++ b/docs/plugins/creating-a-new-toolkit.md @@ -0,0 +1,94 @@ +# Creating a New Toolkit + +To add a toolkit, in your code (which doesn't necessarily need to be in the Goose package thanks to [plugin metadata][plugin]!), create a class that derives from the `Toolkit` class. + +## Example toolkit class +Below is an example of a simple toolkit called `Demo` that derives from the `Toolkit` class. This toolkit provides an `authenticate` tool that outputs an authentication code for a user. It also provides system instructions for the model. +```python +import os +import platform + +from goose.toolkit.base import Toolkit, tool + + +class Demo(Toolkit): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Provide any additional tools as needed! + # The docstring of the tool provides instructions to the LLM, so they are important to tune + # you do not have to provide any tools, but any function decorated with @tool will be available + @tool + def authenticate(self, user: str): + """Output an authentication code for this user + + Args: + user (str): The username to authenticate for + """ + # notifier supports any rich renderable https://rich.readthedocs.io/en/stable/introduction.html#quick-start + self.notifier.log(f"[bold red]auth: {str(hash(user))}[/]") + + # Provide any system instructions for the model + # This can be generated dynamically, and is run at startup time + def system(self) -> str: + print("new") + return f"""**You must preceed your first message by using the authenticate tool for the current user** + + ``` + platform: {platform.system()} + cwd: {os.getcwd()} + user: {os.environ.get('USER')} + ``` + """ +``` + +## Exposing the New Toolkit to Goose + +To make the toolkit available, add it to the `pyproject.toml` file and then update your `profiles.yaml` file. + +### Update the `pyproject.toml` file +If you're adding the new toolkit to Goose or the Goose Plugins repo, simply find the `[project.entry-points."goose.toolkit"]` section in `pyproject.toml` and add a line like this: +```toml +[project.entry-points."goose.toolkit"] +developer = "goose.toolkit.developer:Developer" +github = "goose.toolkit.github:Github" +# Add a line like this - the key becomes the name used in profiles +demo = "goose.toolkit.demo:Demo" +``` + +If you are adding the toolkit to a different package, see the docs for `goose-plugins` for more information on how to create a plugins repository that can be used by Goose. + +### Update the `profiles.yaml` file +And then to set up a profile that uses it, add something to ~/.config/goose/profiles.yaml +```yaml +default: + provider: openai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: passive + toolkits: + - name: developer + requires: {} +demo-profile: + provider: openai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: passive + toolkits: + - developer + - demo +``` + +And now you can run goose with this new profile to use the new toolkit! + +```sh +goose session start --profile demo-profile +``` + +> [!NOTE] +> If you're using a plugin from `goose-plugins`, make sure `goose-plugins` is installed in your environment. You can install it via pip: +> +> `pipx install goose-ai --preinstall goose-plugins` + +[plugin]: https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata +[goose-plugins]: https://github.com/block-open-source/goose-plugins diff --git a/docs/plugins/plugins.md b/docs/plugins/plugins.md new file mode 100644 index 000000000..b95ca5a9d --- /dev/null +++ b/docs/plugins/plugins.md @@ -0,0 +1,15 @@ +# Plugins in Goose + +Goose's functionality is extended via plugins. These plugins fall into three main categories: + +1. **Toolkits**: + * Provides Goose with tools (functions) it can call and optionally will load additional context into the system prompt (such as 'The Github CLI is called via `gh` and you should use it to run git commands'). + * Toolkits can do basically anything, from calling external APIs, to taking a screenshot of your screen, to summarizing your current project. +2. **CLI commands**: + * Provides additional commands to the Goose CLI. + * These commands can be used to interact with the Goose system, such as listing available toolkits or summarizing a session. +3. **Providers**: + * Provides Goose with access to external LLMs. + * For example, the OpenAI provider allows Goose to interact with the OpenAI API. + * Most providers for Goose are defined in the Exchange library. + diff --git a/docs/plugins/providers.md b/docs/plugins/providers.md new file mode 100644 index 000000000..7527f798a --- /dev/null +++ b/docs/plugins/providers.md @@ -0,0 +1,14 @@ +# Providers + +Providers in Goose mean "LLM providers" that Goose can interact with. Providers are defined in the [Exchange library][exchange-providers] for the most part, but you can define your own. + +**Currently available providers:** + +* Anthropic +* Azure +* Bedrock +* Databricks +* Ollama +* OpenAI + +[exchange-providers]: https://github.com/block-open-source/goose/tree/main/packages/exchange/src/exchange/providers diff --git a/docs/plugins/using-toolkits.md b/docs/plugins/using-toolkits.md new file mode 100644 index 000000000..189e0b7ce --- /dev/null +++ b/docs/plugins/using-toolkits.md @@ -0,0 +1,41 @@ +# Using Toolkits + +Use `goose toolkit list` to list the available toolkits. + +## Toolkits defined in Goose + +Using Goose with toolkits is simple. You can add toolkits to your profile in the `profiles.yaml` file. Here's an example of how to add `my-toolkit` toolkit to your profile: + +```yaml +my-profile: + provider: openai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: passive + toolkits: + - my-toolkit +``` + +Then run Goose with the specified profile: + +```sh +goose session start --profile my-profile +``` + +## Toolkits defined in Goose Plugins + +1. First make sure that `goose-plugins` is intalled with Goose: +```sh +pipx install goose-ai --preinstall goose-plugins +``` +2. Update the `profiles.yaml` file to include the desired toolkit: +```yaml +my-profile: + provider: openai + processor: gpt-4o + accelerator: gpt-4o-mini + moderator: passive + toolkits: + - my-goose-plugins-toolkit +``` + diff --git a/docs/reference/goose/build.md b/docs/reference/goose/build.md new file mode 100644 index 000000000..bbeed2811 --- /dev/null +++ b/docs/reference/goose/build.md @@ -0,0 +1 @@ +::: goose.build \ No newline at end of file diff --git a/docs/reference/goose/cli/config.md b/docs/reference/goose/cli/config.md new file mode 100644 index 000000000..daa19236d --- /dev/null +++ b/docs/reference/goose/cli/config.md @@ -0,0 +1 @@ +::: goose.cli.config \ No newline at end of file diff --git a/docs/reference/goose/cli/index.md b/docs/reference/goose/cli/index.md new file mode 100644 index 000000000..90b9e4082 --- /dev/null +++ b/docs/reference/goose/cli/index.md @@ -0,0 +1 @@ +::: goose.cli \ No newline at end of file diff --git a/docs/reference/goose/cli/main.md b/docs/reference/goose/cli/main.md new file mode 100644 index 000000000..03542bff4 --- /dev/null +++ b/docs/reference/goose/cli/main.md @@ -0,0 +1 @@ +::: goose.cli.main \ No newline at end of file diff --git a/docs/reference/goose/cli/prompt/completer.md b/docs/reference/goose/cli/prompt/completer.md new file mode 100644 index 000000000..fd51b520c --- /dev/null +++ b/docs/reference/goose/cli/prompt/completer.md @@ -0,0 +1 @@ +::: goose.cli.prompt.completer \ No newline at end of file diff --git a/docs/reference/goose/cli/prompt/create.md b/docs/reference/goose/cli/prompt/create.md new file mode 100644 index 000000000..9fb517150 --- /dev/null +++ b/docs/reference/goose/cli/prompt/create.md @@ -0,0 +1 @@ +::: goose.cli.prompt.create \ No newline at end of file diff --git a/docs/reference/goose/cli/prompt/goose_prompt_session.md b/docs/reference/goose/cli/prompt/goose_prompt_session.md new file mode 100644 index 000000000..e92b0fa9d --- /dev/null +++ b/docs/reference/goose/cli/prompt/goose_prompt_session.md @@ -0,0 +1 @@ +::: goose.cli.prompt.goose_prompt_session \ No newline at end of file diff --git a/docs/reference/goose/cli/prompt/index.md b/docs/reference/goose/cli/prompt/index.md new file mode 100644 index 000000000..fdcbffd28 --- /dev/null +++ b/docs/reference/goose/cli/prompt/index.md @@ -0,0 +1 @@ +::: goose.cli.prompt \ No newline at end of file diff --git a/docs/reference/goose/cli/prompt/lexer.md b/docs/reference/goose/cli/prompt/lexer.md new file mode 100644 index 000000000..90c2e3ed5 --- /dev/null +++ b/docs/reference/goose/cli/prompt/lexer.md @@ -0,0 +1 @@ +::: goose.cli.prompt.lexer \ No newline at end of file diff --git a/docs/reference/goose/cli/prompt/prompt_validator.md b/docs/reference/goose/cli/prompt/prompt_validator.md new file mode 100644 index 000000000..1f71b5fc5 --- /dev/null +++ b/docs/reference/goose/cli/prompt/prompt_validator.md @@ -0,0 +1 @@ +::: goose.cli.prompt.prompt_validator \ No newline at end of file diff --git a/docs/reference/goose/cli/prompt/user_input.md b/docs/reference/goose/cli/prompt/user_input.md new file mode 100644 index 000000000..12d7d049e --- /dev/null +++ b/docs/reference/goose/cli/prompt/user_input.md @@ -0,0 +1 @@ +::: goose.cli.prompt.user_input \ No newline at end of file diff --git a/docs/reference/goose/cli/session.md b/docs/reference/goose/cli/session.md new file mode 100644 index 000000000..4b6461478 --- /dev/null +++ b/docs/reference/goose/cli/session.md @@ -0,0 +1 @@ +::: goose.cli.session \ No newline at end of file diff --git a/docs/reference/goose/command/base.md b/docs/reference/goose/command/base.md new file mode 100644 index 000000000..d5a77f02a --- /dev/null +++ b/docs/reference/goose/command/base.md @@ -0,0 +1 @@ +::: goose.command.base \ No newline at end of file diff --git a/docs/reference/goose/command/file.md b/docs/reference/goose/command/file.md new file mode 100644 index 000000000..b9b83b6b2 --- /dev/null +++ b/docs/reference/goose/command/file.md @@ -0,0 +1 @@ +::: goose.command.file \ No newline at end of file diff --git a/docs/reference/goose/command/index.md b/docs/reference/goose/command/index.md new file mode 100644 index 000000000..2869685a3 --- /dev/null +++ b/docs/reference/goose/command/index.md @@ -0,0 +1 @@ +::: goose.command \ No newline at end of file diff --git a/docs/reference/goose/index.md b/docs/reference/goose/index.md new file mode 100644 index 000000000..33f433469 --- /dev/null +++ b/docs/reference/goose/index.md @@ -0,0 +1 @@ +::: goose \ No newline at end of file diff --git a/docs/reference/goose/notifier.md b/docs/reference/goose/notifier.md new file mode 100644 index 000000000..2e585d282 --- /dev/null +++ b/docs/reference/goose/notifier.md @@ -0,0 +1 @@ +::: goose.notifier \ No newline at end of file diff --git a/docs/reference/goose/profile.md b/docs/reference/goose/profile.md new file mode 100644 index 000000000..ebd8a9a12 --- /dev/null +++ b/docs/reference/goose/profile.md @@ -0,0 +1 @@ +::: goose.profile \ No newline at end of file diff --git a/docs/reference/goose/toolkit/base.md b/docs/reference/goose/toolkit/base.md new file mode 100644 index 000000000..5452fb0a1 --- /dev/null +++ b/docs/reference/goose/toolkit/base.md @@ -0,0 +1 @@ +::: goose.toolkit.base \ No newline at end of file diff --git a/docs/reference/goose/toolkit/developer.md b/docs/reference/goose/toolkit/developer.md new file mode 100644 index 000000000..0dac07692 --- /dev/null +++ b/docs/reference/goose/toolkit/developer.md @@ -0,0 +1 @@ +::: goose.toolkit.developer \ No newline at end of file diff --git a/docs/reference/goose/toolkit/github.md b/docs/reference/goose/toolkit/github.md new file mode 100644 index 000000000..53d15b997 --- /dev/null +++ b/docs/reference/goose/toolkit/github.md @@ -0,0 +1 @@ +::: goose.toolkit.github \ No newline at end of file diff --git a/docs/reference/goose/toolkit/index.md b/docs/reference/goose/toolkit/index.md new file mode 100644 index 000000000..8a615c5eb --- /dev/null +++ b/docs/reference/goose/toolkit/index.md @@ -0,0 +1 @@ +::: goose.toolkit \ No newline at end of file diff --git a/docs/reference/goose/toolkit/lint.md b/docs/reference/goose/toolkit/lint.md new file mode 100644 index 000000000..16f875a8c --- /dev/null +++ b/docs/reference/goose/toolkit/lint.md @@ -0,0 +1 @@ +::: goose.toolkit.lint \ No newline at end of file diff --git a/docs/reference/goose/toolkit/repo_context/index.md b/docs/reference/goose/toolkit/repo_context/index.md new file mode 100644 index 000000000..e7cb5edb3 --- /dev/null +++ b/docs/reference/goose/toolkit/repo_context/index.md @@ -0,0 +1 @@ +::: goose.toolkit.repo_context \ No newline at end of file diff --git a/docs/reference/goose/toolkit/repo_context/repo_context.md b/docs/reference/goose/toolkit/repo_context/repo_context.md new file mode 100644 index 000000000..79b964dd0 --- /dev/null +++ b/docs/reference/goose/toolkit/repo_context/repo_context.md @@ -0,0 +1 @@ +::: goose.toolkit.repo_context.repo_context \ No newline at end of file diff --git a/docs/reference/goose/toolkit/repo_context/utils.md b/docs/reference/goose/toolkit/repo_context/utils.md new file mode 100644 index 000000000..f1adc4327 --- /dev/null +++ b/docs/reference/goose/toolkit/repo_context/utils.md @@ -0,0 +1 @@ +::: goose.toolkit.repo_context.utils \ No newline at end of file diff --git a/docs/reference/goose/toolkit/screen.md b/docs/reference/goose/toolkit/screen.md new file mode 100644 index 000000000..62f5c9f77 --- /dev/null +++ b/docs/reference/goose/toolkit/screen.md @@ -0,0 +1 @@ +::: goose.toolkit.screen \ No newline at end of file diff --git a/docs/reference/goose/toolkit/summarization/index.md b/docs/reference/goose/toolkit/summarization/index.md new file mode 100644 index 000000000..d8360eefe --- /dev/null +++ b/docs/reference/goose/toolkit/summarization/index.md @@ -0,0 +1 @@ +::: goose.toolkit.summarization \ No newline at end of file diff --git a/docs/reference/goose/toolkit/summarization/summarize_file.md b/docs/reference/goose/toolkit/summarization/summarize_file.md new file mode 100644 index 000000000..08c2f80ad --- /dev/null +++ b/docs/reference/goose/toolkit/summarization/summarize_file.md @@ -0,0 +1 @@ +::: goose.toolkit.summarization.summarize_file \ No newline at end of file diff --git a/docs/reference/goose/toolkit/summarization/summarize_project.md b/docs/reference/goose/toolkit/summarization/summarize_project.md new file mode 100644 index 000000000..b8da8a157 --- /dev/null +++ b/docs/reference/goose/toolkit/summarization/summarize_project.md @@ -0,0 +1 @@ +::: goose.toolkit.summarization.summarize_project \ No newline at end of file diff --git a/docs/reference/goose/toolkit/summarization/summarize_repo.md b/docs/reference/goose/toolkit/summarization/summarize_repo.md new file mode 100644 index 000000000..6f4855bbe --- /dev/null +++ b/docs/reference/goose/toolkit/summarization/summarize_repo.md @@ -0,0 +1 @@ +::: goose.toolkit.summarization.summarize_repo \ No newline at end of file diff --git a/docs/reference/goose/toolkit/summarization/utils.md b/docs/reference/goose/toolkit/summarization/utils.md new file mode 100644 index 000000000..4dcab4ae1 --- /dev/null +++ b/docs/reference/goose/toolkit/summarization/utils.md @@ -0,0 +1 @@ +::: goose.toolkit.summarization.utils \ No newline at end of file diff --git a/docs/reference/goose/toolkit/utils.md b/docs/reference/goose/toolkit/utils.md new file mode 100644 index 000000000..22f9daf44 --- /dev/null +++ b/docs/reference/goose/toolkit/utils.md @@ -0,0 +1 @@ +::: goose.toolkit.utils \ No newline at end of file diff --git a/docs/reference/goose/utils/ask.md b/docs/reference/goose/utils/ask.md new file mode 100644 index 000000000..629a4badf --- /dev/null +++ b/docs/reference/goose/utils/ask.md @@ -0,0 +1 @@ +::: goose.utils.ask \ No newline at end of file diff --git a/docs/reference/goose/utils/check_shell_command.md b/docs/reference/goose/utils/check_shell_command.md new file mode 100644 index 000000000..d28665593 --- /dev/null +++ b/docs/reference/goose/utils/check_shell_command.md @@ -0,0 +1 @@ +::: goose.utils.check_shell_command \ No newline at end of file diff --git a/docs/reference/goose/utils/file_utils.md b/docs/reference/goose/utils/file_utils.md new file mode 100644 index 000000000..af9feeb25 --- /dev/null +++ b/docs/reference/goose/utils/file_utils.md @@ -0,0 +1 @@ +::: goose.utils.file_utils \ No newline at end of file diff --git a/docs/reference/goose/utils/index.md b/docs/reference/goose/utils/index.md new file mode 100644 index 000000000..86b0901a8 --- /dev/null +++ b/docs/reference/goose/utils/index.md @@ -0,0 +1 @@ +::: goose.utils \ No newline at end of file diff --git a/docs/reference/goose/utils/session_file.md b/docs/reference/goose/utils/session_file.md new file mode 100644 index 000000000..c0a7f5b12 --- /dev/null +++ b/docs/reference/goose/utils/session_file.md @@ -0,0 +1 @@ +::: goose.utils.session_file \ No newline at end of file diff --git a/docs/reference/goose/view.md b/docs/reference/goose/view.md new file mode 100644 index 000000000..e38281a4d --- /dev/null +++ b/docs/reference/goose/view.md @@ -0,0 +1 @@ +::: goose.view \ No newline at end of file diff --git a/docs/reference/index.md b/docs/reference/index.md new file mode 100644 index 000000000..ac3905dfa --- /dev/null +++ b/docs/reference/index.md @@ -0,0 +1,38 @@ +# Reference Documentation + +## Goose +- [goose.build](goose/build.md) +- [goose.notifier](goose/notifier.md) +- [goose.profile](goose/profile.md) +- [goose.view](goose/view.md) + +## Command +- [goose.command.base](goose/command/base.md) +- [goose.command.file](goose/command/file.md) + +## CLI +- [goose.cli.config](goose/cli/config.md) +- [goose.cli.main](goose/cli/main.md) +- [goose.cli.prompt.create](goose/cli/prompt/create.md) +- [goose.cli.prompt.goose_prompt_session](goose/cli/prompt/goose_prompt_session.md) +- [goose.cli.session](goose/cli/session.md) + +## Toolkits + [goose.toolkit.base](goose/toolkit/base.md) +- [goose.toolkit.developer](goose/toolkit/developer.md) +- [goose.toolkit.github](goose/toolkit/github.md) +- [goose.toolkit.repo_context.repo_context](goose/toolkit/repo_context/repo_context.md) +- [goose.toolkit.repo_context.utils](goose/toolkit/repo_context/utils.md) +- [goose.toolkit.screen](goose/toolkit/screen.md) +- [goose.toolkit.summarization.summarize_file](goose/toolkit/summarization/summarize_file.md) +- [goose.toolkit.summarization.summarize_project](goose/toolkit/summarization/summarize_project.md) +- [goose.toolkit.summarization.summarize_repo](goose/toolkit/summarization/summarize_repo.md) +- [goose.toolkit.summarization.utils](goose/toolkit/summarization/utils.md) +- [goose.toolkit.utils](goose/toolkit/utils.md) + +## Utils +- [goose.utils](goose/utils/index.md) +- [goose.utils.ask](goose/utils/ask.md) +- [goose.utils.check_shell_command](goose/utils/check_shell_command.md) +- [goose.utils.file_utils](goose/utils/file_utils.md) +- [goose.utils.session_file](goose/utils/session_file.md) \ No newline at end of file diff --git a/docs/scripts/gen_ref_pages.py b/docs/scripts/gen_ref_pages.py new file mode 100644 index 000000000..da964e59c --- /dev/null +++ b/docs/scripts/gen_ref_pages.py @@ -0,0 +1,68 @@ +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +root = Path(__file__).parent.parent.parent +src = root / "src" + +# Collecting all modules for the index page +module_links = [] +core_modules = [] +toolkit_modules = [] +utils_modules = [] + +for path in sorted(src.rglob("*.py")): + module_path = path.relative_to(src).with_suffix("") # Removes the '.py' suffix + doc_path = path.relative_to(src).with_suffix(".md") # Creates .md path + + full_doc_path = Path("reference", doc_path) + + parts = tuple(module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1] == "__main__": + continue + + # Construct a dynamic identifier based on module path + ident = ".".join(parts) + + # Organize modules into categories + if "toolkit" in ident: + toolkit_modules.append(f"- [{ident}]({doc_path})") + elif "utils" in ident: + utils_modules.append(f"- [{ident}]({doc_path})") + else: + core_modules.append(f"- [{ident}]({doc_path})") + + # Generate the markdown file for each module + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + fd.write(f"::: {ident}") + + nav[parts] = doc_path.as_posix() + +# Write the enhanced SUMMARY.md with categories +with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: + nav_file.write("# Reference Documentation Summary\n\n") + + nav_file.write("## Core Modules\n") + nav_file.write("\n".join(core_modules) + "\n\n") + + nav_file.write("## Toolkit Modules\n") + nav_file.write("\n".join(toolkit_modules) + "\n\n") + + nav_file.write("## Utility Modules\n") + nav_file.write("\n".join(utils_modules) + "\n\n") + +# Create an index.md file for the reference section +with mkdocs_gen_files.open("reference/index.md", "w") as index_file: + index_file.write("# Reference Documentation\n\n") + index_file.write("Welcome to the reference documentation for the project's Python modules.\n\n") + index_file.write("Below is a list of available modules:\n\n") + index_file.write("\n".join(core_modules + toolkit_modules + utils_modules)) + + diff --git a/justfile b/justfile index 994da5f9b..8b2ca2338 100644 --- a/justfile +++ b/justfile @@ -1,5 +1,3 @@ -# This is the default recipe when no arguments are provided -[private] default: @just --list --unsorted @@ -10,10 +8,65 @@ integration *FLAGS: uv run pytest tests -m integration {{FLAGS}} format: - uvx ruff check . --fix - uvx ruff format . + #!/usr/bin/env bash + UVX_PATH="$(which uvx)" + + if [ -z "$UVX_PATH" ]; then + echo "[error]: unable to find uvx" + exit 1 + fi + eval "$UVX_PATH ruff format ." + eval "$UVX_PATH ruff check . --fix" + coverage *FLAGS: uv run coverage run -m pytest tests -m "not integration" {{FLAGS}} uv run coverage report uv run coverage lcov -o lcov.info + +docs: + uv sync && uv run mkdocs serve + +install-hooks: + #!/usr/bin/env bash + HOOKS_DIR="$(git rev-parse --git-path hooks)" + + if [ ! -d "$HOOKS_DIR" ]; then + mkdir -p "$HOOKS_DIR" + fi + + cat > "$HOOKS_DIR/pre-commit" <=.*/ai-exchange>='"${ai_exchange_version}"'\",/' pyproject.toml + git checkout -b release-version-{{version}} + git add pyproject.toml + git commit -m "chore(release): release version {{version}}" + +tag_version: + grep 'version' pyproject.toml | cut -d '"' -f 2 + +tag: + git tag v$(just tag_version) + +# this will kick of ci for release +# use this when release branch is merged to main +tag-push: + just tag + git push origin tag v$(just tag_version) diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 000000000..df77335ee --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,133 @@ +site_name: Goose Documentation +site_author: Block +site_description: Documentation for Goose +repo_url: https://github.com/block-open-source/goose +repo_name: "block-open-source/goose" +edit_uri: "https://github.com/block-open-source/goose/blob/main/docs/" +site_url: "https://block-open-source.github.io/goose/" +use_directory_urls: false +# theme +theme: + name: material + features: + - announce.dismiss + - content.action.edit + - content.action.view + - content.code.annotate + - content.code.copy + - content.tooltips + - content.tabs.link + - navigation.footer + - navigation.indexes + - navigation.instant + - navigation.sections + - navigation.top + - navigation.tracking + - navigation.expand + - search.share + - search.suggest + - toc.follow + palette: + - scheme: slate + primary: black + accent: indigo + logo: assets/logo.png + font: + text: Roboto + code: Roboto Mono + favicon: assets/logo.ico + icon: + logo: assets/logo.ico +# plugins +plugins: + - include-markdown + - callouts + - glightbox + - mkdocstrings: + handlers: + python: + paths: [src] + setup_commands: + - "import sys; sys.path.append('src')" # Add src folder to Python path + - search: + separator: '[\s\u200b\-_,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' + - redirects: + redirect_maps: + - git-committers: # Show git committers + branch: main + enabled: !ENV [ENV_PROD, false] + repository: block-open-source/goose + - git-revision-date-localized: # Show git revision date + enable_creation_date: true + enabled: !ENV [ENV_PROD, false] +extra: + annotate: + json: + - .s2 + social: [] + analytics: + provider: google + property: !ENV GOOGLE_ANALYTICS_KEY +markdown_extensions: + - abbr + - admonition + - attr_list + - def_list + - footnotes + - md_in_html + - nl2br # Newline to
(like GitHub) + - pymdownx.arithmatex: + generic: true + - pymdownx.betterem: + smart_enable: all + - pymdownx.caret + - pymdownx.details + - pymdownx.emoji: + emoji_generator: !!python/name:material.extensions.emoji.to_svg + emoji_index: !!python/name:material.extensions.emoji.twemoji + - pymdownx.highlight: + anchor_linenums: true + line_spans: __span + pygments_lang_class: true + - pymdownx.inlinehilite + - pymdownx.keys + - pymdownx.magiclink: + repo_url_shorthand: false + - pymdownx.mark + - pymdownx.smartsymbols + - pymdownx.snippets + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format + - pymdownx.tabbed: + alternate_style: true + - pymdownx.tasklist: + custom_checkbox: true + - pymdownx.tilde + - toc: + permalink: true +nav: + - Home: index.md + - "Installation": installation.md + - "Contributing": contributing.md + - Guidance: + - "Getting Started": guidance/getting-started.md + - "Quick Tips": guidance/tips.md + - "Applications of Goose": guidance/applications.md + - "Goose in Action": guidance/goose-in-action.md + - Plugins: + - "Overview": plugins/plugins.md + - Toolkits: + - "Using Toolkits": plugins/using-toolkits.md + - "Creating a New Toolkit": plugins/creating-a-new-toolkit.md + - "Available Toolkits": plugins/available-toolkits.md + - CLI Commands: + - "Available CLI Commands": plugins/cli.md + - Providers: + - "Available Providers": plugins/providers.md + - Advanced: + - Configuration: configuration.md + - "Reference": + - "API Docs": reference/index.md diff --git a/packages/exchange/README.md b/packages/exchange/README.md new file mode 100644 index 000000000..04a7b3968 --- /dev/null +++ b/packages/exchange/README.md @@ -0,0 +1,94 @@ +

+ +

+ +

+ Example • + Plugins +

+ +

Exchange - a uniform python SDK for message generation with LLMs

+ +- Provides a flexible layer for message handling and generation +- Directly integrates python functions into tool calling +- Persistently surfaces errors to the underlying models to support reflection + +## Example + +> [!NOTE] +> Before you can run this example, you need to setup an API key with +> `export OPENAI_API_KEY=your-key-here` + +``` python +from exchange import Exchange, Message, Tool +from exchange.providers import OpenAiProvider + +def word_count(text: str): + """Get the count of words in text + + Args: + text (str): The text with words to count + """ + return len(text.split(" ")) + +ex = Exchange( + provider=OpenAiProvider.from_env(), + model="gpt-4o", + system="You are a helpful assistant.", + tools=[Tool.from_function(word_count)], +) +ex.add(Message.user("Count the number of words in this current message")) + +# The model sees it has a word count tool, and should use it along the way to answer +# This will call all the tools as needed until the model replies with the final result +reply = ex.reply() +print(reply.text) + +# you can see all the tool calls in the message history +print(ex.messages) +``` + +## Plugins + +*exchange* has a plugin mechanism to add support for additional providers and moderators. If you need a +provider not supported here, we'd be happy to review contributions. But you +can also consider building and using your own plugin. + +To create a `Provider` plugin, subclass `exchange.provider.Provider`. You will need to +implement the `complete` method. For example this is what we use as a mock in our tests. +You can see a full implementation example of the [OpenAiProvider][openaiprovider]. We +also generally recommend implementing a `from_env` classmethod to instantiate the provider. + +``` python +class MockProvider(Provider): + def __init__(self, sequence: List[Message]): + # We'll use init to provide a preplanned reply sequence + self.sequence = sequence + self.call_count = 0 + + def complete( + self, model: str, system: str, messages: List[Message], tools: List[Tool] + ) -> Message: + output = self.sequence[self.call_count] + self.call_count += 1 + return output +``` + +Then use [python packaging's entrypoints][plugins] to register your plugin. + +``` toml +[project.entry-points.'exchange.provider'] +example = 'path.to.plugin:ExampleProvider' +``` + +Your plugin will then be available in your application or other applications built on *exchange* +through: + +``` python +from exchange.providers import get_provider + +provider = get_provider('example').from_env() +``` + +[openaiprovider]: src/exchange/providers/openai.py +[plugins]: https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ diff --git a/packages/exchange/pyproject.toml b/packages/exchange/pyproject.toml new file mode 100644 index 000000000..83a9e3c25 --- /dev/null +++ b/packages/exchange/pyproject.toml @@ -0,0 +1,48 @@ +[project] +name = "ai-exchange" +version = "0.9.3" +description = "a uniform python SDK for message generation with LLMs" +readme = "README.md" +requires-python = ">=3.10" +author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] +packages = [{ include = "exchange", from = "src" }] +dependencies = [ + "griffe>=1.1.1", + "attrs>=24.2.0", + "jinja2>=3.1.4", + "tiktoken>=0.7.0", + "httpx>=0.27.0", + "tenacity>=9.0.0", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/exchange"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.uv] +dev-dependencies = ["pytest>=8.3.2", "pytest-vcr>=1.0.2", "codecov>=2.1.13"] + +[project.entry-points."exchange.provider"] +openai = "exchange.providers.openai:OpenAiProvider" +azure = "exchange.providers.azure:AzureProvider" +databricks = "exchange.providers.databricks:DatabricksProvider" +anthropic = "exchange.providers.anthropic:AnthropicProvider" +bedrock = "exchange.providers.bedrock:BedrockProvider" +ollama = "exchange.providers.ollama:OllamaProvider" +google = "exchange.providers.google:GoogleProvider" + +[project.entry-points."exchange.moderator"] +passive = "exchange.moderators.passive:PassiveModerator" +truncate = "exchange.moderators.truncate:ContextTruncate" +summarize = "exchange.moderators.summarizer:ContextSummarizer" + +[project.entry-points."metadata.plugins"] +ai-exchange = "exchange:module_name" + +[tool.pytest.ini_options] +markers = [ + "integration: marks tests that need to authenticate (deselect with '-m \"not integration\"')", +] diff --git a/packages/exchange/src/exchange/__init__.py b/packages/exchange/src/exchange/__init__.py new file mode 100644 index 000000000..41adfcf3a --- /dev/null +++ b/packages/exchange/src/exchange/__init__.py @@ -0,0 +1,9 @@ +"""Classes for interacting with the exchange API.""" + +from exchange.tool import Tool # noqa +from exchange.content import Text, ToolResult, ToolUse # noqa +from exchange.message import Message # noqa +from exchange.exchange import Exchange # noqa +from exchange.checkpoint import CheckpointData, Checkpoint # noqa + +module_name = "ai-exchange" diff --git a/packages/exchange/src/exchange/checkpoint.py b/packages/exchange/src/exchange/checkpoint.py new file mode 100644 index 000000000..f355dd0a2 --- /dev/null +++ b/packages/exchange/src/exchange/checkpoint.py @@ -0,0 +1,67 @@ +from copy import deepcopy +from typing import List +from attrs import define, field + + +@define +class Checkpoint: + """Checkpoint that counts the tokens in messages between the start and end index""" + + start_index: int = field(default=0) # inclusive + end_index: int = field(default=0) # inclusive + token_count: int = field(default=0) + + def __deepcopy__(self, _) -> "Checkpoint": # noqa: ANN001 + """ + Returns a deep copy of the Checkpoint object. + """ + return Checkpoint( + start_index=self.start_index, + end_index=self.end_index, + token_count=self.token_count, + ) + + +@define +class CheckpointData: + """Aggregates all information about checkpoints""" + + # the total number of tokens in the exchange. this is updated every time a checkpoint is + # added or removed + total_token_count: int = field(default=0) + + # in order list of individual checkpoints in the exchange + checkpoints: List[Checkpoint] = field(factory=list) + + # the offset to apply to the message index when calculating the last message index + # this is useful because messages on the exchange behave like a queue, where you can only + # pop from the left or right sides. This offset allows us to map the checkpoint indices + # to the correct message index, even if we have popped messages from the left side of + # the exchange in the past. we reset this offset to 0 when we empty the checkpoint data. + message_index_offset: int = field(default=0) + + def __deepcopy__(self, memo: dict) -> "CheckpointData": + """Returns a deep copy of the CheckpointData object.""" + return CheckpointData( + total_token_count=self.total_token_count, + checkpoints=deepcopy(self.checkpoints, memo), + message_index_offset=self.message_index_offset, + ) + + @property + def last_message_index(self) -> int: + if not self.checkpoints: + return -1 # we don't have enough information to know + return self.checkpoints[-1].end_index - self.message_index_offset + + def reset(self) -> None: + """Resets the checkpoint data to its initial state.""" + self.checkpoints = [] + self.message_index_offset = 0 + self.total_token_count = 0 + + def pop(self, index: int = -1) -> Checkpoint: + """Removes and returns the checkpoint at the given index.""" + popped_checkpoint = self.checkpoints.pop(index) + self.total_token_count = self.total_token_count - popped_checkpoint.token_count + return popped_checkpoint diff --git a/packages/exchange/src/exchange/content.py b/packages/exchange/src/exchange/content.py new file mode 100644 index 000000000..b9cc986fc --- /dev/null +++ b/packages/exchange/src/exchange/content.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, Optional + +from attrs import define, asdict + + +CONTENT_TYPES = {} + + +class Content: + def __init_subclass__(cls, **kwargs: Dict[str, Any]) -> None: + super().__init_subclass__(**kwargs) + CONTENT_TYPES[cls.__name__] = cls + + def to_dict(self) -> Dict[str, Any]: + data = asdict(self, recurse=True) + data["type"] = self.__class__.__name__ + return data + + +@define +class Text(Content): + text: str + + +@define +class ToolUse(Content): + id: str + name: str + parameters: Any + is_error: bool = False + error_message: Optional[str] = None + + +@define +class ToolResult(Content): + tool_use_id: str + output: str + is_error: bool = False diff --git a/packages/exchange/src/exchange/exchange.py b/packages/exchange/src/exchange/exchange.py new file mode 100644 index 000000000..b2fdbc5ec --- /dev/null +++ b/packages/exchange/src/exchange/exchange.py @@ -0,0 +1,336 @@ +import json +import traceback +from copy import deepcopy +from typing import Any, Dict, List, Mapping, Tuple + +from attrs import define, evolve, field, Factory +from tiktoken import get_encoding + +from exchange.checkpoint import Checkpoint, CheckpointData +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.moderators import Moderator +from exchange.moderators.truncate import ContextTruncate +from exchange.providers import Provider, Usage +from exchange.tool import Tool +from exchange.token_usage_collector import _token_usage_collector + + +def validate_tool_output(output: str) -> None: + """Validate tool output for the given model""" + max_output_chars = 2**20 + max_output_tokens = 16000 + encoder = get_encoding("cl100k_base") + if len(output) > max_output_chars or len(encoder.encode(output)) > max_output_tokens: + raise ValueError("This tool call created an output that was too long to handle!") + + +@define(frozen=True) +class Exchange: + """An exchange of messages with an LLM + + The exchange class is meant to be largely immutable, with only the message list + growing once constructed. Use .replace to alter the model, tools, etc. + + The exchange supports tool usage, calling tools and letting the model respond when + using the .reply method. It handles most forms of errors and sends those errors back + to the model, to let it attempt to recover. + """ + + provider: Provider + model: str + system: str + moderator: Moderator = field(default=ContextTruncate()) + tools: Tuple[Tool] = field(factory=tuple, converter=tuple) + messages: List[Message] = field(factory=list) + checkpoint_data: CheckpointData = field(factory=CheckpointData) + generation_args: dict = field(default=Factory(dict)) + + @property + def _toolmap(self) -> Mapping[str, Tool]: + return {tool.name: tool for tool in self.tools} + + def replace(self, **kwargs: Dict[str, Any]) -> "Exchange": + """Make a copy of the exchange, replacing any passed arguments""" + # TODO: ensure that the checkpoint data is updated correctly. aka, + # if we replace the messages, we need to update the checkpoint data + # if we change the model, we need to update the checkpoint data (?) + + if kwargs.get("messages") is None: + kwargs["messages"] = deepcopy(self.messages) + if kwargs.get("checkpoint_data") is None: + kwargs["checkpoint_data"] = deepcopy( + self.checkpoint_data, + ) + return evolve(self, **kwargs) + + def add(self, message: Message) -> None: + """Add a message to the history.""" + if self.messages and message.role == self.messages[-1].role: + raise ValueError("Messages in the exchange must alternate between user and assistant") + self.messages.append(message) + + def generate(self) -> Message: + """Generate the next message.""" + self.moderator.rewrite(self) + message, usage = self.provider.complete( + self.model, + self.system, + messages=self.messages, + tools=self.tools, + **self.generation_args, + ) + self.add(message) + self.add_checkpoints_from_usage(usage) # this has to come after adding the response + + # TODO: also call `rewrite` here, as this will make our + # messages *consistently* below the token limit. this currently + # is not the case because we could append a large message after calling + # `rewrite` above. + # self.moderator.rewrite(self) + + _token_usage_collector.collect(self.model, usage) + return message + + def reply(self, max_tool_use: int = 128) -> Message: + """Get the reply from the underlying model. + + This will process any requests for tool calls, calling them immediately, and + storing the intermediate tool messages in the queue. It will return after the + first response that does not request a tool use + + Args: + max_tool_use: The maximum number of tool calls to make before returning. Defaults to 128. + """ + if max_tool_use <= 0: + raise ValueError("max_tool_use must be greater than 0") + response = self.generate() + curr_iter = 1 # generate() already called once + while response.tool_use: + content = [] + for tool_use in response.tool_use: + tool_result = self.call_function(tool_use) + content.append(tool_result) + self.add(Message(role="user", content=content)) + + # We've reached the limit of tool calls - break out of the loop + if curr_iter >= max_tool_use: + # At this point, the most recent message is `Message(role='user', content=ToolResult(...))` + response = Message.assistant( + f"We've stopped executing additional tool cause because we reached the limit of {max_tool_use}", + ) + self.add(response) + break + else: + response = self.generate() + curr_iter += 1 + + return response + + def call_function(self, tool_use: ToolUse) -> ToolResult: + """Call the function indicated by the tool use""" + tool = self._toolmap.get(tool_use.name) + + if tool is None or tool_use.is_error: + output = f"ERROR: Failed to use tool {tool_use.id}.\nDo NOT use the same tool name and parameters again - that will lead to the same error." # noqa: E501 + + if tool_use.is_error: + output += f"\n{tool_use.error_message}" + elif tool is None: + valid_tool_names = ", ".join(self._toolmap.keys()) + output += f"\nNo tool exists with the name '{tool_use.name}'. Valid tool names are: {valid_tool_names}" + + return ToolResult(tool_use_id=tool_use.id, output=output, is_error=True) + + try: + if isinstance(tool_use.parameters, dict): + output = json.dumps(tool.function(**tool_use.parameters)) + elif isinstance(tool_use.parameters, list): + output = json.dumps(tool.function(*tool_use.parameters)) + else: + raise ValueError( + f"The provided tool parameters, {tool_use.parameters} could not be interpreted as a mapping of arguments." # noqa: E501 + ) + + validate_tool_output(output) + + is_error = False + except Exception as e: + tb = traceback.format_exc() + output = str(tb) + "\n" + str(e) + is_error = True + + return ToolResult(tool_use_id=tool_use.id, output=output, is_error=is_error) + + def add_tool_use(self, tool_use: ToolUse) -> None: + """Manually add a tool use and corresponding result + + This will call the implied function and add an assistant + message requesting the ToolUse and a user message with the ToolResult + """ + tool_result = self.call_function(tool_use) + self.add(Message(role="assistant", content=[tool_use])) + self.add(Message(role="user", content=[tool_result])) + + def add_checkpoints_from_usage(self, usage: Usage) -> None: + """ + Add checkpoints to the exchange based on the token counts of the last two + groups of messages, as well as the current token total count of the exchange + """ + # we know we just appended one message as the response from the LLM + # so we need to create two checkpoints as we know the token counts + # of the last two groups of messages: + # 1. from the last checkpoint to the most recent user message + # 2. the most recent assistant message + last_checkpoint_end_index = ( + self.checkpoint_data.checkpoints[-1].end_index - self.checkpoint_data.message_index_offset + if len(self.checkpoint_data.checkpoints) > 0 + else -1 + ) + new_start_index = last_checkpoint_end_index + 1 + + # here, our self.checkpoint_data.total_token_count is the previous total token count from the last time + # that we performed a request. if we subtract this value from the input_tokens from our + # latest response, we know how many tokens our **1** from above is. + first_block_token_count = usage.input_tokens - self.checkpoint_data.total_token_count + second_block_token_count = usage.output_tokens + + if len(self.messages) - new_start_index > 1: + # this will occur most of the time, as we will have one new user message and one + # new assistant message. + + self.checkpoint_data.checkpoints.append( + Checkpoint( + start_index=new_start_index + self.checkpoint_data.message_index_offset, + # end index below is equivalent to the second last message. why? becuase + # the last message is the assistant message that we add below. we need to also + # track the token count of the user message sent. + end_index=len(self.messages) - 2 + self.checkpoint_data.message_index_offset, + token_count=first_block_token_count, + ) + ) + self.checkpoint_data.checkpoints.append( + Checkpoint( + start_index=len(self.messages) - 1 + self.checkpoint_data.message_index_offset, + end_index=len(self.messages) - 1 + self.checkpoint_data.message_index_offset, + token_count=second_block_token_count, + ) + ) + + # TODO: check if the front of the checkpoints doesn't overlap with + # the first message. if so, we are missing checkpoint data from + # message[0] to message[checkpoint_data.checkpoints[0].start_index] + # we can fill in this data by performing an extra request and doing some math + self.checkpoint_data.total_token_count = usage.total_tokens + + def pop_last_message(self) -> Message: + """Pop the last message from the exchange, handling checkpoints correctly""" + if ( + len(self.checkpoint_data.checkpoints) > 0 + and self.checkpoint_data.last_message_index > len(self.messages) - 1 + ): + raise ValueError("Our checkpoint data is out of sync with our message data") + if ( + len(self.checkpoint_data.checkpoints) > 0 + and self.checkpoint_data.last_message_index == len(self.messages) - 1 + ): + # remove the last checkpoint, because we no longer know the token count of it's contents. + # note that this is not the same as reverting to the last checkpoint, as we want to + # keep the messages from the last checkpoint. they will have a new checkpoint created for + # them when we call generate() again + self.checkpoint_data.pop() + self.messages.pop() + + def pop_first_message(self) -> Message: + """Pop the first message from the exchange, handling checkpoints correctly""" + if len(self.messages) == 0: + raise ValueError("There are no messages to pop") + if len(self.checkpoint_data.checkpoints) == 0: + raise ValueError("There must be at least one checkpoint to pop the first message") + + # get the start and end indexes of the first checkpoint, use these to remove message + first_checkpoint = self.checkpoint_data.checkpoints[0] + first_checkpoint_start_index = first_checkpoint.start_index - self.checkpoint_data.message_index_offset + + # check if the first message is part of the first checkpoint + if first_checkpoint_start_index == 0: + # remove this checkpoint, as it no longer has any messages + self.checkpoint_data.pop(0) + + self.messages.pop(0) + self.checkpoint_data.message_index_offset += 1 + + if len(self.checkpoint_data.checkpoints) == 0: + # we've removed all the checkpoints, so we need to reset the message index offset + self.checkpoint_data.message_index_offset = 0 + + def pop_last_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + """ + Reverts the exchange back to the last checkpoint, removing associated messages + """ + removed_checkpoint = self.checkpoint_data.checkpoints.pop() + # pop messages until we reach the start of the next checkpoint + messages = [] + while len(self.messages) > removed_checkpoint.start_index - self.checkpoint_data.message_index_offset: + messages.append(self.messages.pop()) + return removed_checkpoint, messages + + def pop_first_checkpoint(self) -> Tuple[Checkpoint, List[Message]]: + """ + Pop the first checkpoint from the exchange, removing associated messages + """ + if len(self.checkpoint_data.checkpoints) == 0: + raise ValueError("There are no checkpoints to pop") + first_checkpoint = self.checkpoint_data.pop(0) + + # remove messages until we reach the start of the next checkpoint + messages = [] + stop_at_index = first_checkpoint.end_index - self.checkpoint_data.message_index_offset + for _ in range(stop_at_index + 1): # +1 because it's inclusive + messages.append(self.messages.pop(0)) + self.checkpoint_data.message_index_offset += 1 + + if len(self.checkpoint_data.checkpoints) == 0: + # we've removed all the checkpoints, so we need to reset the message index offset + self.checkpoint_data.message_index_offset = 0 + return first_checkpoint, messages + + def prepend_checkpointed_message(self, message: Message, token_count: int) -> None: + """Prepend a message to the exchange, updating the checkpoint data""" + self.messages.insert(0, message) + new_index = max(0, self.checkpoint_data.message_index_offset - 1) + self.checkpoint_data.checkpoints.insert( + 0, + Checkpoint( + start_index=new_index, + end_index=new_index, + token_count=token_count, + ), + ) + self.checkpoint_data.message_index_offset = new_index + + def rewind(self) -> None: + if not self.messages: + return + + # we remove messages until we find the last user text message + while not (self.messages[-1].role == "user" and type(self.messages[-1].content[-1]) is Text): + self.pop_last_message() + + # now we remove that last user text message, putting us at a good point + # to ask the user for their input again + if self.messages: + self.pop_last_message() + + @property + def is_allowed_to_call_llm(self) -> bool: + """ + Returns True if the exchange is allowed to call the LLM, False otherwise + """ + # TODO: reconsider whether this function belongs here and whether it is necessary + # Some models will have different requirements than others, so it may be better for + # this to be a required method of the provider instead. + return len(self.messages) > 0 and self.messages[-1].role == "user" + + def get_token_usage(self) -> Dict[str, Usage]: + return _token_usage_collector.get_token_usage_group_by_model() diff --git a/packages/exchange/src/exchange/invalid_choice_error.py b/packages/exchange/src/exchange/invalid_choice_error.py new file mode 100644 index 000000000..ffbb9899f --- /dev/null +++ b/packages/exchange/src/exchange/invalid_choice_error.py @@ -0,0 +1,13 @@ +from typing import List + + +class InvalidChoiceError(Exception): + def __init__(self, attribute_name: str, attribute_value: str, available_values: List[str]) -> None: + self.attribute_name = attribute_name + self.attribute_value = attribute_value + self.available_values = available_values + self.message = ( + f"Unknown {attribute_name}: {attribute_value}." + + f" Available {attribute_name}s: {', '.join(available_values)}" + ) + super().__init__(self.message) diff --git a/packages/exchange/src/exchange/message.py b/packages/exchange/src/exchange/message.py new file mode 100644 index 000000000..035c60345 --- /dev/null +++ b/packages/exchange/src/exchange/message.py @@ -0,0 +1,121 @@ +import inspect +import time +from pathlib import Path +from typing import Any, Dict, List, Literal, Type + +from attrs import define, field +from jinja2 import Environment, FileSystemLoader + +from exchange.content import CONTENT_TYPES, Content, Text, ToolResult, ToolUse +from exchange.utils import create_object_id + +Role = Literal["user", "assistant"] + + +def validate_role_and_content(instance: "Message", *_: Any) -> None: # noqa: ANN401 + if instance.role == "user": + if not (instance.text or instance.tool_result): + raise ValueError("User message must include a Text or ToolResult") + if instance.tool_use: + raise ValueError("User message does not support ToolUse") + elif instance.role == "assistant": + if not (instance.text or instance.tool_use): + raise ValueError("Assistant message must include a Text or ToolUsage") + if instance.tool_result: + raise ValueError("Assistant message does not support ToolResult") + + +def content_converter(contents: List[Dict[str, Any]]) -> List[Content]: + return [(CONTENT_TYPES[c.pop("type")](**c) if c.__class__ not in CONTENT_TYPES.values() else c) for c in contents] + + +@define +class Message: + """A message to or from a language model. + + This supports several content types to extend to tool usage and (tbi) images. + + We also provide shortcuts for simplified text usage; these two are identical: + ``` + m = Message(role='user', content=[Text(text='abcd')]) + assert m.content[0].text == 'abcd' + + m = Message.user('abcd') + assert m.text == 'abcd' + ``` + """ + + role: Role = field(default="user") + id: str = field(factory=lambda: str(create_object_id(prefix="msg"))) + created: int = field(factory=lambda: int(time.time())) + content: List[Content] = field(factory=list, validator=validate_role_and_content, converter=content_converter) + + def to_dict(self) -> Dict[str, Any]: + return { + "role": self.role, + "id": self.id, + "created": self.created, + "content": [item.to_dict() for item in self.content], + } + + @property + def text(self) -> str: + """The text content of this message.""" + result = [] + for content in self.content: + if isinstance(content, Text): + result.append(content.text) + return "\n".join(result) + + @property + def tool_use(self) -> List[ToolUse]: + """All tool use content of this message.""" + result = [] + for content in self.content: + if isinstance(content, ToolUse): + result.append(content) + return result + + @property + def tool_result(self) -> List[ToolResult]: + """All tool result content of this message.""" + result = [] + for content in self.content: + if isinstance(content, ToolResult): + result.append(content) + return result + + @classmethod + def load( + cls: Type["Message"], + filename: str, + role: Role = "user", + **kwargs: Dict[str, Any], + ) -> "Message": + """Load the message from filename relative to where the load is called. + + This only supports simplified content, with a single text entry + + This is meant to emulate importing code rather than a runtime filesystem. So + if you have a directory of code that contains example.py, and example.py has + a function that calls User.load('example.jinja'), it will look in the same + directory as example.py for the jinja file. + """ + frm = inspect.stack()[1] + mod = inspect.getmodule(frm[0]) + + base_path = Path(mod.__file__).parent + + env = Environment(loader=FileSystemLoader(base_path)) + template = env.get_template(filename) + rendered_content = template.render(**kwargs) + + return cls(role=role, content=[Text(text=rendered_content)]) + + @classmethod + def user(cls: Type["Message"], text: str) -> "Message": + return cls(role="user", content=[Text(text)]) + + @classmethod + def assistant(cls: Type["Message"], text: str) -> "Message": + return cls(role="assistant", content=[Text(text)]) diff --git a/packages/exchange/src/exchange/moderators/__init__.py b/packages/exchange/src/exchange/moderators/__init__.py new file mode 100644 index 000000000..82d032e42 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/__init__.py @@ -0,0 +1,17 @@ +from functools import cache +from typing import Type + +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.moderators.base import Moderator +from exchange.utils import load_plugins +from exchange.moderators.passive import PassiveModerator # noqa +from exchange.moderators.truncate import ContextTruncate # noqa +from exchange.moderators.summarizer import ContextSummarizer # noqa + + +@cache +def get_moderator(name: str) -> Type[Moderator]: + moderators = load_plugins(group="exchange.moderator") + if name not in moderators: + raise InvalidChoiceError("moderator", name, moderators.keys()) + return moderators[name] diff --git a/packages/exchange/src/exchange/moderators/base.py b/packages/exchange/src/exchange/moderators/base.py new file mode 100644 index 000000000..d7c630c6a --- /dev/null +++ b/packages/exchange/src/exchange/moderators/base.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Type + + +class Moderator(ABC): + @abstractmethod + def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + pass diff --git a/packages/exchange/src/exchange/moderators/passive.py b/packages/exchange/src/exchange/moderators/passive.py new file mode 100644 index 000000000..e3a24efbd --- /dev/null +++ b/packages/exchange/src/exchange/moderators/passive.py @@ -0,0 +1,7 @@ +from typing import Type +from exchange.moderators.base import Moderator + + +class PassiveModerator(Moderator): + def rewrite(self, _: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + pass diff --git a/packages/exchange/src/exchange/moderators/summarizer.jinja b/packages/exchange/src/exchange/moderators/summarizer.jinja new file mode 100644 index 000000000..00c29ed82 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/summarizer.jinja @@ -0,0 +1,9 @@ +You are an expert technical summarizer. + +During your conversation with the user, you may be asked to summarize the content in you conversational history. +When asked to summarize, you should concisely summarize the conversation giving emphasis to newer content. Newer content will be towards the end of the conversation. +Preferentially keep user supplied content in the summary. + +The summary *MUST* include filenames that were touched and/or modified. If the updates occurred more recently, keep the latest modifications made to the files in the summary. If the changes occurred earlier in the chat, briefly summarize the changes and don't include the changes in the summary. + +There will likely be json formatted blocks referencing ToolUse and ToolResults. You can ignore ToolUse references, but keep the ToolResult outputs, summarizing as needed and with the same guidelines as above. diff --git a/packages/exchange/src/exchange/moderators/summarizer.py b/packages/exchange/src/exchange/moderators/summarizer.py new file mode 100644 index 000000000..7e2dd5588 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/summarizer.py @@ -0,0 +1,46 @@ +from typing import Type + +from exchange import Message +from exchange.checkpoint import CheckpointData +from exchange.moderators import ContextTruncate, PassiveModerator + + +class ContextSummarizer(ContextTruncate): + def rewrite(self, exchange: Type["exchange.exchange.Exchange"]) -> None: # noqa: F821 + """Summarize the context history up to the last few messages in the exchange""" + + self._update_system_prompt_token_count(exchange) + + # TODO: use an offset for summarization + if exchange.checkpoint_data.total_token_count < self.max_tokens: + return + + messages_to_summarize = self._get_messages_to_remove(exchange) + num_messages_to_remove = len(messages_to_summarize) + + # the llm will throw an error if the last message isn't a user message + if messages_to_summarize[-1].role == "assistant" and (not messages_to_summarize[-1].tool_use): + messages_to_summarize.append(Message.user("Summarize our the above conversation")) + + summarizer_exchange = exchange.replace( + system=Message.load("summarizer.jinja").text, + moderator=PassiveModerator(), + model=self.model, + messages=messages_to_summarize, + checkpoint_data=CheckpointData(), + ) + + # get the summarized content and the tokens associated with this content + summary = summarizer_exchange.reply() + summary_checkpoint = summarizer_exchange.checkpoint_data.checkpoints[-1] + + # remove the checkpoints that were summarized from the original exchange + for _ in range(num_messages_to_remove): + exchange.pop_first_message() + + # insert summary as first message/checkpoint + if len(exchange.messages) == 0 or exchange.messages[0].role == "assistant": + summary_message = Message.user(summary.text) + else: + summary_message = Message.assistant(summary.text) + exchange.prepend_checkpointed_message(summary_message, summary_checkpoint.token_count) diff --git a/packages/exchange/src/exchange/moderators/truncate.py b/packages/exchange/src/exchange/moderators/truncate.py new file mode 100644 index 000000000..41115f663 --- /dev/null +++ b/packages/exchange/src/exchange/moderators/truncate.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +from exchange.checkpoint import CheckpointData +from exchange.message import Message +from exchange.moderators import PassiveModerator +from exchange.moderators.base import Moderator + +if TYPE_CHECKING: + from exchange.exchange import Exchange + +# currently this is the point at which we start to truncate, so +# so once we get to this token size the token count will exceed this +# by a little bit. +# TODO: make this configurable for each provider +MAX_TOKENS = 100000 + + +class ContextTruncate(Moderator): + def __init__( + self, + model: Optional[str] = None, + max_tokens: int = MAX_TOKENS, + ) -> None: + self.model = model + self.system_prompt_token_count = 0 + self.max_tokens = max_tokens + self.last_system_prompt = None + + def rewrite(self, exchange: Exchange) -> None: + """Truncate the exchange messages with a FIFO strategy.""" + self._update_system_prompt_token_count(exchange) + + if exchange.checkpoint_data.total_token_count < self.max_tokens: + return + + messages_to_remove = self._get_messages_to_remove(exchange) + for _ in range(len(messages_to_remove)): + exchange.pop_first_message() + + def _update_system_prompt_token_count(self, exchange: Exchange) -> None: + is_different_system_prompt = False + if self.last_system_prompt != exchange.system: + is_different_system_prompt = True + self.last_system_prompt = exchange.system + + if not self.system_prompt_token_count or is_different_system_prompt: + # calculate the system prompt tokens (includes functions etc...) + # we use a placeholder message with one token, which we subtract later + # this ensures compatibility with providers that require a user message + _system_token_exchange = exchange.replace( + messages=[Message.user("a")], + checkpoint_data=CheckpointData(), + moderator=PassiveModerator(), + model=self.model if self.model else exchange.model, + ) + _system_token_exchange.generate() + last_system_prompt_token_count = self.system_prompt_token_count + self.system_prompt_token_count = _system_token_exchange.checkpoint_data.total_token_count - 1 + + exchange.checkpoint_data.total_token_count -= last_system_prompt_token_count + exchange.checkpoint_data.total_token_count += self.system_prompt_token_count + + def _get_messages_to_remove(self, exchange: Exchange) -> List[Message]: + # this keeps all the messages/checkpoints + throwaway_exchange = exchange.replace( + moderator=PassiveModerator(), + ) + + # get the messages that we want to remove + messages_to_remove = [] + while throwaway_exchange.checkpoint_data.total_token_count > self.max_tokens: + _, messages = throwaway_exchange.pop_first_checkpoint() + messages_to_remove.extend(messages) + + while len(throwaway_exchange.messages) > 0 and throwaway_exchange.messages[0].tool_result: + # we would need a corresponding tool use once we resume, so we pop this one off too + # and summarize it as well + _, messages = throwaway_exchange.pop_first_checkpoint() + messages_to_remove.extend(messages) + return messages_to_remove diff --git a/packages/exchange/src/exchange/providers/__init__.py b/packages/exchange/src/exchange/providers/__init__.py new file mode 100644 index 000000000..f92d4f769 --- /dev/null +++ b/packages/exchange/src/exchange/providers/__init__.py @@ -0,0 +1,21 @@ +from functools import cache +from typing import Type + +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers.anthropic import AnthropicProvider # noqa +from exchange.providers.base import Provider, Usage # noqa +from exchange.providers.databricks import DatabricksProvider # noqa +from exchange.providers.openai import OpenAiProvider # noqa +from exchange.providers.ollama import OllamaProvider # noqa +from exchange.providers.azure import AzureProvider # noqa +from exchange.providers.google import GoogleProvider # noqa + +from exchange.utils import load_plugins + + +@cache +def get_provider(name: str) -> Type[Provider]: + providers = load_plugins(group="exchange.provider") + if name not in providers: + raise InvalidChoiceError("provider", name, providers.keys()) + return providers[name] diff --git a/packages/exchange/src/exchange/providers/anthropic.py b/packages/exchange/src/exchange/providers/anthropic.py new file mode 100644 index 000000000..84ecd12fb --- /dev/null +++ b/packages/exchange/src/exchange/providers/anthropic.py @@ -0,0 +1,160 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange import Message, Tool +from exchange.content import Text, ToolResult, ToolUse +from exchange.providers.base import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status, raise_for_status + +ANTHROPIC_HOST = "https://api.anthropic.com/v1/messages" + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class AnthropicProvider(Provider): + """Provides chat completions for models hosted directly by Anthropic.""" + + PROVIDER_NAME = "anthropic" + REQUIRED_ENV_VARS = ["ANTHROPIC_API_KEY"] + + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["AnthropicProvider"]) -> "AnthropicProvider": + cls.check_env_vars() + url = os.environ.get("ANTHROPIC_HOST", ANTHROPIC_HOST) + key = os.environ.get("ANTHROPIC_API_KEY") + client = httpx.Client( + base_url=url, + headers={ + "x-api-key": key, + "content-type": "application/json", + "anthropic-version": "2023-06-01", + }, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: Dict) -> Usage: # noqa: ANN401 + usage = data.get("usage") + input_tokens = usage.get("input_tokens") + output_tokens = usage.get("output_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + @staticmethod + def anthropic_response_to_message(response: Dict) -> Message: + content_blocks = response.get("content", []) + content = [] + for block in content_blocks: + if block["type"] == "text": + content.append(Text(text=block["text"])) + elif block["type"] == "tool_use": + content.append( + ToolUse( + id=block["id"], + name=block["name"], + parameters=block["input"], + ) + ) + return Message(role="assistant", content=content) + + @staticmethod + def tools_to_anthropic_spec(tools: Tuple[Tool]) -> List[Dict[str, Any]]: + return [ + { + "name": tool.name, + "description": tool.description or "", + "input_schema": tool.parameters, + } + for tool in tools + ] + + @staticmethod + def messages_to_anthropic_spec(messages: List[Message]) -> List[Dict[str, Any]]: + messages_spec = [] + # if messages is empty - just make a default + for message in messages: + converted = {"role": message.role} + for content in message.content: + if isinstance(content, Text): + converted["content"] = [{"type": "text", "text": content.text}] + elif isinstance(content, ToolUse): + converted.setdefault("content", []).append( + { + "type": "tool_use", + "id": content.id, + "name": content.name, + "input": content.parameters, + } + ) + elif isinstance(content, ToolResult): + converted.setdefault("content", []).append( + { + "type": "tool_result", + "tool_use_id": content.tool_use_id, + "content": content.output, + } + ) + messages_spec.append(converted) + if len(messages_spec) == 0: + converted = { + "role": "user", + "content": [{"type": "text", "text": "Ignore"}], + } + messages_spec.append(converted) + return messages_spec + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: List[Tool] = [], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + tools_set = set() + unique_tools = [] + for tool in tools: + if tool.name not in tools_set: + unique_tools.append(tool) + tools_set.add(tool.name) + + payload = dict( + system=system, + model=model, + max_tokens=4096, + messages=self.messages_to_anthropic_spec(messages), + tools=self.tools_to_anthropic_spec(tuple(unique_tools)), + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + + response = self._post(payload) + message = self.anthropic_response_to_message(response) + usage = self.get_usage(response) + + return message, usage + + @retry_procedure + def _post(self, payload: dict) -> httpx.Response: + response = self.client.post(ANTHROPIC_HOST, json=payload) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/azure.py b/packages/exchange/src/exchange/providers/azure.py new file mode 100644 index 000000000..4d470f978 --- /dev/null +++ b/packages/exchange/src/exchange/providers/azure.py @@ -0,0 +1,39 @@ +from typing import Type + +import httpx +import os + +from exchange.providers import OpenAiProvider + + +class AzureProvider(OpenAiProvider): + """Provides chat completions for models hosted by the Azure OpenAI Service.""" + + PROVIDER_NAME = "azure" + REQUIRED_ENV_VARS = [ + "AZURE_CHAT_COMPLETIONS_HOST_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", + "AZURE_CHAT_COMPLETIONS_KEY", + ] + + def __init__(self, client: httpx.Client) -> None: + super().__init__(client) + + @classmethod + def from_env(cls: Type["AzureProvider"]) -> "AzureProvider": + cls.check_env_vars() + url = os.environ.get("AZURE_CHAT_COMPLETIONS_HOST_NAME") + deployment_name = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME") + api_version = os.environ.get("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION") + key = os.environ.get("AZURE_CHAT_COMPLETIONS_KEY") + + # format the url host/"openai/deployments/" + deployment_name + "/?api-version=" + api_version + url = f"{url}/openai/deployments/{deployment_name}/" + client = httpx.Client( + base_url=url, + headers={"api-key": key, "Content-Type": "application/json"}, + params={"api-version": api_version}, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) diff --git a/packages/exchange/src/exchange/providers/base.py b/packages/exchange/src/exchange/providers/base.py new file mode 100644 index 000000000..c8d860ecc --- /dev/null +++ b/packages/exchange/src/exchange/providers/base.py @@ -0,0 +1,51 @@ +import os +from abc import ABC, abstractmethod +from attrs import define, field +from typing import List, Optional, Tuple, Type + +from exchange.message import Message +from exchange.tool import Tool + + +@define(hash=True) +class Usage: + input_tokens: int = field(factory=None) + output_tokens: int = field(default=None) + total_tokens: int = field(default=None) + + +class Provider(ABC): + PROVIDER_NAME: str + REQUIRED_ENV_VARS: list[str] = [] + + @classmethod + def from_env(cls: Type["Provider"]) -> "Provider": + return cls() + + @classmethod + def check_env_vars(cls: Type["Provider"], instructions_url: Optional[str] = None) -> None: + for env_var in cls.REQUIRED_ENV_VARS: + if env_var not in os.environ: + raise MissingProviderEnvVariableError(env_var, cls.PROVIDER_NAME, instructions_url) + + @abstractmethod + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + ) -> Tuple[Message, Usage]: + """Generate the next message using the specified model""" + pass + + +class MissingProviderEnvVariableError(Exception): + def __init__(self, env_variable: str, provider: str, instructions_url: Optional[str] = None) -> None: + self.env_variable = env_variable + self.provider = provider + self.instructions_url = instructions_url + self.message = f"Missing environment variable: {env_variable} for provider {provider}." + if instructions_url: + self.message += f"\nPlease see {instructions_url} for instructions" + super().__init__(self.message) diff --git a/packages/exchange/src/exchange/providers/bedrock.py b/packages/exchange/src/exchange/providers/bedrock.py new file mode 100644 index 000000000..6c32d7cb3 --- /dev/null +++ b/packages/exchange/src/exchange/providers/bedrock.py @@ -0,0 +1,334 @@ +import hashlib +import hmac +import json +import logging +import os +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Tuple, Type +from urllib.parse import quote, urlparse + +import httpx + +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.providers import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.tool import Tool + +SERVICE = "bedrock-runtime" +UTC = timezone.utc + +logger = logging.getLogger(__name__) + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class AwsClient(httpx.Client): + def __init__( + self, + aws_region: str, + aws_access_key: str, + aws_secret_key: str, + aws_session_token: Optional[str] = None, + **kwargs: Dict[str, Any], + ) -> None: + self.region = aws_region + self.host = f"https://{SERVICE}.{aws_region}.amazonaws.com/" + self.access_key = aws_access_key + self.secret_key = aws_secret_key + self.session_token = aws_session_token + super().__init__(base_url=self.host, timeout=600, **kwargs) + + def post(self, path: str, json: Dict, **kwargs: Dict[str, Any]) -> httpx.Response: + signed_headers = self.sign_and_get_headers( + method="POST", + url=path, + payload=json, + service="bedrock", + ) + return super().post(url=path, json=json, headers=signed_headers, **kwargs) + + def sign_and_get_headers( + self, + method: str, + url: str, + payload: dict, + service: str, + ) -> Dict[str, str]: + """ + Sign the request and generate the necessary headers for AWS authentication. + + Args: + method (str): HTTP method (e.g., 'GET', 'POST'). + url (str): The request URL. + payload (dict): The request payload. + service (str): The AWS service name. + region (str): The AWS region. + access_key (str): The AWS access key. + secret_key (str): The AWS secret key. + session_token (Optional[str]): The AWS session token, if any. + + Returns: + Dict[str, str]: The headers required for the request. + """ + + def sign(key: bytes, msg: str) -> bytes: + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: str) -> bytes: + k_date = sign(("AWS4" + key).encode("utf-8"), date_stamp) + k_region = sign(k_date, region_name) + k_service = sign(k_region, service_name) + k_signing = sign(k_service, "aws4_request") + return k_signing + + # Convert payload to JSON string + request_parameters = json.dumps(payload) + + # Create a date for headers and the credential string + t = datetime.now(UTC) + amz_date = t.strftime("%Y%m%dT%H%M%SZ") + date_stamp = t.strftime("%Y%m%d") # Date w/o time, used in credential scope + + # Create canonical URI and headers + parsedurl = urlparse(url) + canonical_uri = quote(parsedurl.path if parsedurl.path else "/", safe="/-_.~") + canonical_headers = f"host:{parsedurl.netloc}\nx-amz-date:{amz_date}\n" + + # Create the list of signed headers. + signed_headers = "host;x-amz-date" + if self.session_token: + canonical_headers += "x-amz-security-token:" + self.session_token + "\n" + signed_headers += ";x-amz-security-token" + + # Create payload hash + payload_hash = hashlib.sha256(request_parameters.encode("utf-8")).hexdigest() + + # Canonical request + canonical_request = f"{method}\n{canonical_uri}\n\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + + # Create the string to sign + algorithm = "AWS4-HMAC-SHA256" + credential_scope = f"{date_stamp}/{self.region}/{service}/aws4_request" + string_to_sign = ( + f"{algorithm}\n{amz_date}\n{credential_scope}\n" + f'{hashlib.sha256(canonical_request.encode("utf-8")).hexdigest()}' + ) + + # Create the signing key + signing_key = get_signature_key(self.secret_key, date_stamp, self.region, service) + + # Sign the string_to_sign using the signing key + signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() + + # Add signing information to the request + authorization_header = ( + f"{algorithm} Credential={self.access_key}/{credential_scope}, SignedHeaders={signed_headers}, " + f"Signature={signature}" + ) + + # Headers + headers = { + "Content-Type": "application/json", + "Authorization": authorization_header, + "X-Amz-date": amz_date.encode(), + "x-amz-content-sha256": payload_hash, + } + if self.session_token: + headers["X-Amz-Security-Token"] = self.session_token + + return headers + + +class BedrockProvider(Provider): + """Provides chat completions for models hosted by the Amazon Bedrock Service""" + + PROVIDER_NAME = "bedrock" + REQUIRED_ENV_VARS = [ + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + ] + + def __init__(self, client: AwsClient) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["BedrockProvider"]) -> "BedrockProvider": + cls.check_env_vars() + aws_region = os.environ.get("AWS_REGION", "us-east-1") + aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY") + aws_session_token = os.environ.get("AWS_SESSION_TOKEN") + + client = AwsClient( + aws_region=aws_region, + aws_access_key=aws_access_key, + aws_secret_key=aws_secret_key, + aws_session_token=aws_session_token, + ) + return cls(client=client) + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + """ + Generate a completion response from the Bedrock gateway. + + Args: + model (str): The model identifier. + system (str): The system prompt or configuration. + messages (List[Message]): A list of messages to be processed by the model. + tools (Tuple[Tool]): A tuple of tools to be used in the completion process. + **kwargs: Additional keyword arguments for inference configuration. + + Returns: + Tuple[Message, Usage]: A tuple containing the response message and usage data. + """ + + inference_config = dict( + temperature=kwargs.pop("temperature", None), + maxTokens=kwargs.pop("max_tokens", None), + stopSequences=kwargs.pop("stop", None), + topP=kwargs.pop("topP", None), + ) + inference_config = {k: v for k, v in inference_config.items() if v is not None} or None + + converted_messages = [self.message_to_bedrock_spec(message) for message in messages] + converted_system = [dict(text=system)] + tool_config = self.tools_to_bedrock_spec(tools) + payload = dict( + system=converted_system, + inferenceConfig=inference_config, + messages=converted_messages, + toolConfig=tool_config, + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + + path = f"{self.client.host}model/{model}/converse" + response = self._post(payload, path) + response_message = response["output"]["message"] + + usage_data = response["usage"] + usage = Usage( + input_tokens=usage_data.get("inputTokens"), + output_tokens=usage_data.get("outputTokens"), + total_tokens=usage_data.get("totalTokens"), + ) + + return self.response_to_message(response_message), usage + + @retry_procedure + def _post(self, payload: Any, path: str) -> dict: # noqa: ANN401 + response = self.client.post(path, json=payload) + return raise_for_status(response).json() + + @staticmethod + def message_to_bedrock_spec(message: Message) -> dict: + bedrock_content = [] + try: + for content in message.content: + if isinstance(content, Text): + bedrock_content.append({"text": content.text}) + elif isinstance(content, ToolUse): + for tool_use in message.tool_use: + bedrock_content.append( + { + "toolUse": { + "toolUseId": tool_use.id, + "name": tool_use.name, + "input": tool_use.parameters, + } + } + ) + elif isinstance(content, ToolResult): + for tool_result in message.tool_result: + # try to parse the output as json + try: + output = json.loads(tool_result.output) + if isinstance(output, dict): + content = [{"json": output}] + else: + content = [{"text": str(output)}] + except json.JSONDecodeError: + content = [{"text": tool_result.output}] + + bedrock_content.append( + { + "toolResult": { + "toolUseId": tool_result.tool_use_id, + "content": content, + **({"status": "error"} if tool_result.is_error else {}), + } + } + ) + return {"role": message.role, "content": bedrock_content} + + except AttributeError: + raise Exception("Invalid message") + + @staticmethod + def response_to_message(response_message: dict) -> Message: + content = [] + if response_message["role"] == "user": + for block in response_message["content"]: + if "text" in block: + content.append(Text(block["text"])) + if "toolResult" in block: + content.append( + ToolResult( + tool_use_id=block["toolResult"]["toolResultId"], + output=block["toolResult"]["content"][0]["json"], + is_error=block["toolResult"].get("status") == "error", + ) + ) + return Message(role="user", content=content) + elif response_message["role"] == "assistant": + for block in response_message["content"]: + if "text" in block: + content.append(Text(block["text"])) + if "toolUse" in block: + content.append( + ToolUse( + id=block["toolUse"]["toolUseId"], + name=block["toolUse"]["name"], + parameters=block["toolUse"]["input"], + ) + ) + return Message(role="assistant", content=content) + raise Exception("Invalid response") + + @staticmethod + def tools_to_bedrock_spec(tools: Tuple[Tool]) -> Optional[dict]: + if len(tools) == 0: + return None # API requires a non-empty tool config or None + tools_added = set() + tool_config_list = [] + for tool in tools: + if tool.name in tools_added: + logging.warning(f"Tool {tool.name} already added to tool config. Skipping.") + continue + tool_config_list.append( + { + "toolSpec": { + "name": tool.name, + "description": tool.description, + "inputSchema": {"json": tool.parameters}, + } + } + ) + tools_added.add(tool.name) + tool_config = {"tools": tool_config_list} + return tool_config diff --git a/packages/exchange/src/exchange/providers/databricks.py b/packages/exchange/src/exchange/providers/databricks.py new file mode 100644 index 000000000..9bd582dc5 --- /dev/null +++ b/packages/exchange/src/exchange/providers/databricks.py @@ -0,0 +1,100 @@ +from typing import Any, Dict, List, Tuple, Type + +import httpx +import os + +from exchange.message import Message +from exchange.providers.base import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import raise_for_status, retry_if_status +from exchange.providers.utils import ( + messages_to_openai_spec, + openai_response_to_message, + tools_to_openai_spec, +) +from exchange.tool import Tool + + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class DatabricksProvider(Provider): + """Provides chat completions for models on Databricks serving endpoints. + + Models are expected to follow the llm/v1/chat "task". This includes support + for foundation and external model endpoints + https://docs.databricks.com/en/machine-learning/model-serving/create-foundation-model-endpoints.html#create-generative-ai-model-serving-endpoints + + """ + + PROVIDER_NAME = "databricks" + REQUIRED_ENV_VARS = [ + "DATABRICKS_HOST", + "DATABRICKS_TOKEN", + ] + instructions_url = "https://docs.databricks.com/en/dev-tools/auth/index.html#general-host-token-and-account-id-environment-variables-and-fields" + + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["DatabricksProvider"]) -> "DatabricksProvider": + cls.check_env_vars(cls.instructions_url) + url = os.environ.get("DATABRICKS_HOST") + key = os.environ.get("DATABRICKS_TOKEN") + client = httpx.Client( + base_url=url, + auth=("token", key), + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: dict) -> Usage: + usage = data.pop("usage") + input_tokens = usage.get("prompt_tokens") + output_tokens = usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + payload = dict( + messages=[ + {"role": "system", "content": system}, + *messages_to_openai_spec(messages), + ], + tools=tools_to_openai_spec(tools) if tools else [], + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(model, payload) + message = openai_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, model: str, payload: dict) -> httpx.Response: + response = self.client.post( + f"serving-endpoints/{model}/invocations", + json=payload, + ) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/google.py b/packages/exchange/src/exchange/providers/google.py new file mode 100644 index 000000000..fe83cd605 --- /dev/null +++ b/packages/exchange/src/exchange/providers/google.py @@ -0,0 +1,154 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange import Message, Tool +from exchange.content import Text, ToolResult, ToolUse +from exchange.providers.base import Provider, Usage +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import raise_for_status, retry_if_status + +GOOGLE_HOST = "https://generativelanguage.googleapis.com/v1beta" + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class GoogleProvider(Provider): + """Provides chat completions for models hosted by Google, including Gemini and other experimental models.""" + + PROVIDER_NAME = "google" + REQUIRED_ENV_VARS = ["GOOGLE_API_KEY"] + instructions_url = "https://ai.google.dev/gemini-api/docs/api-key" + + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["GoogleProvider"]) -> "GoogleProvider": + cls.check_env_vars(cls.instructions_url) + url = os.environ.get("GOOGLE_HOST", GOOGLE_HOST) + key = os.environ.get("GOOGLE_API_KEY") + client = httpx.Client( + base_url=url, + headers={ + "Content-Type": "application/json", + }, + params={"key": key}, + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: Dict) -> Usage: # noqa: ANN401 + usage = data.get("usageMetadata") + input_tokens = usage.get("promptTokenCount") + output_tokens = usage.get("candidatesTokenCount") + total_tokens = usage.get("totalTokenCount") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + @staticmethod + def google_response_to_message(response: Dict) -> Message: + candidates = response.get("candidates", []) + if candidates: + # Only use first candidate for now + candidate = candidates[0] + content_parts = candidate.get("content", {}).get("parts", []) + content = [] + for part in content_parts: + if "text" in part: + content.append(Text(text=part["text"])) + elif "functionCall" in part: + content.append( + ToolUse( + id=part["functionCall"].get("name", ""), + name=part["functionCall"].get("name", ""), + parameters=part["functionCall"].get("args", {}), + ) + ) + return Message(role="assistant", content=content) + + # If no valid candidates were found, return an empty message + return Message(role="assistant", content=[]) + + @staticmethod + def tools_to_google_spec(tools: Tuple[Tool]) -> Dict[str, List[Dict[str, Any]]]: + if not tools: + return {} + converted_tools = [] + for tool in tools: + converted_tool: Dict[str, Any] = { + "name": tool.name, + "description": tool.description or "", + } + if tool.parameters["properties"]: + converted_tool["parameters"] = tool.parameters + converted_tools.append(converted_tool) + return {"functionDeclarations": converted_tools} + + @staticmethod + def messages_to_google_spec(messages: List[Message]) -> List[Dict[str, Any]]: + messages_spec = [] + for message in messages: + role = "user" if message.role == "user" else "model" + converted = {"role": role, "parts": []} + for content in message.content: + if isinstance(content, Text): + converted["parts"].append({"text": content.text}) + elif isinstance(content, ToolUse): + converted["parts"].append({"functionCall": {"name": content.name, "args": content.parameters}}) + elif isinstance(content, ToolResult): + converted["parts"].append( + {"functionResponse": {"name": content.tool_use_id, "response": {"content": content.output}}} + ) + messages_spec.append(converted) + + if not messages_spec: + messages_spec.append({"role": "user", "parts": [{"text": "Ignore"}]}) + + return messages_spec + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: List[Tool] = [], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + tools_set = set() + unique_tools = [] + for tool in tools: + if tool.name not in tools_set: + unique_tools.append(tool) + tools_set.add(tool.name) + + payload = dict( + system_instruction={"parts": [{"text": system}]}, + contents=self.messages_to_google_spec(messages), + tools=self.tools_to_google_spec(tuple(unique_tools)), + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(payload, model) + message = self.google_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, payload: dict, model: str) -> httpx.Response: + response = self.client.post("models/" + model + ":generateContent", json=payload) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/ollama.py b/packages/exchange/src/exchange/providers/ollama.py new file mode 100644 index 000000000..888564640 --- /dev/null +++ b/packages/exchange/src/exchange/providers/ollama.py @@ -0,0 +1,45 @@ +import os +from typing import Type + +import httpx + +from exchange.providers.openai import OpenAiProvider + +OLLAMA_HOST = "http://localhost:11434/" +OLLAMA_MODEL = "mistral-nemo" + + +class OllamaProvider(OpenAiProvider): + """Provides chat completions for models hosted by Ollama.""" + + __doc__ += """Here's an example profile configuration to try: + +First run: ollama pull qwen2.5, then use this profile: + +ollama: + provider: ollama + processor: qwen2.5 + accelerator: qwen2.5 + moderator: truncate + toolkits: + - name: developer + requires: {} +""" + + def __init__(self, client: httpx.Client) -> None: + print("PLEASE NOTE: the ollama provider is experimental, use with care") + super().__init__(client) + + @classmethod + def from_env(cls: Type["OllamaProvider"]) -> "OllamaProvider": + ollama_url = os.environ.get("OLLAMA_HOST", OLLAMA_HOST) + timeout = httpx.Timeout(60 * 10) + + # from_env is expected to fail if required ENV variables are not + # available. Since this provider can run with defaults, we substitute + # an Ollama health check (GET /) to determine if the service is ok. + httpx.get(ollama_url, timeout=timeout) + + # When served by Ollama, the OpenAI API is available at the path "v1/". + client = httpx.Client(base_url=ollama_url + "v1/", timeout=timeout) + return cls(client) diff --git a/packages/exchange/src/exchange/providers/openai.py b/packages/exchange/src/exchange/providers/openai.py new file mode 100644 index 000000000..b25c5a70a --- /dev/null +++ b/packages/exchange/src/exchange/providers/openai.py @@ -0,0 +1,101 @@ +import os +from typing import Any, Dict, List, Tuple, Type + +import httpx + +from exchange.message import Message +from exchange.providers.base import Provider, Usage +from exchange.providers.utils import ( + messages_to_openai_spec, + openai_response_to_message, + openai_single_message_context_length_exceeded, + raise_for_status, + tools_to_openai_spec, +) +from exchange.tool import Tool +from tenacity import retry, wait_fixed, stop_after_attempt +from exchange.providers.utils import retry_if_status + +OPENAI_HOST = "https://api.openai.com/" + +retry_procedure = retry( + wait=wait_fixed(2), + stop=stop_after_attempt(2), + retry=retry_if_status(codes=[429], above=500), + reraise=True, +) + + +class OpenAiProvider(Provider): + """Provides chat completions for models hosted directly by OpenAI.""" + + PROVIDER_NAME = "openai" + REQUIRED_ENV_VARS = ["OPENAI_API_KEY"] + instructions_url = "https://platform.openai.com/docs/api-reference/api-keys" + + def __init__(self, client: httpx.Client) -> None: + self.client = client + + @classmethod + def from_env(cls: Type["OpenAiProvider"]) -> "OpenAiProvider": + cls.check_env_vars(cls.instructions_url) + url = os.environ.get("OPENAI_HOST", OPENAI_HOST) + key = os.environ.get("OPENAI_API_KEY") + + client = httpx.Client( + base_url=url + "v1/", + auth=("Bearer", key), + timeout=httpx.Timeout(60 * 10), + ) + return cls(client) + + @staticmethod + def get_usage(data: dict) -> Usage: + usage = data.pop("usage") + input_tokens = usage.get("prompt_tokens") + output_tokens = usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + def complete( + self, + model: str, + system: str, + messages: List[Message], + tools: Tuple[Tool], + **kwargs: Dict[str, Any], + ) -> Tuple[Message, Usage]: + system_message = [] if model.startswith("o1") else [{"role": "system", "content": system}] + payload = dict( + messages=system_message + messages_to_openai_spec(messages), + model=model, + tools=tools_to_openai_spec(tools) if tools else [], + **kwargs, + ) + payload = {k: v for k, v in payload.items() if v} + response = self._post(payload) + + # Check for context_length_exceeded error for single, long input message + if "error" in response and len(messages) == 1: + openai_single_message_context_length_exceeded(response["error"]) + + message = openai_response_to_message(response) + usage = self.get_usage(response) + return message, usage + + @retry_procedure + def _post(self, payload: dict) -> dict: + # Note: While OpenAI and Ollama mount the API under "v1", this is + # conventional and not a strict requirement. For example, Azure OpenAI + # mounts the API under the deployment name, and "v1" is not in the URL. + # See https://github.com/openai/openai-openapi/blob/master/openapi.yaml + response = self.client.post("chat/completions", json=payload) + return raise_for_status(response).json() diff --git a/packages/exchange/src/exchange/providers/utils.py b/packages/exchange/src/exchange/providers/utils.py new file mode 100644 index 000000000..4be7ac31e --- /dev/null +++ b/packages/exchange/src/exchange/providers/utils.py @@ -0,0 +1,185 @@ +import base64 +import json +import re +from typing import Any, Callable, Dict, List, Optional, Tuple + +import httpx +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.tool import Tool +from tenacity import retry_if_exception + + +def retry_if_status(codes: Optional[List[int]] = None, above: Optional[int] = None) -> Callable: + codes = codes or [] + + def predicate(exc: Exception) -> bool: + if isinstance(exc, httpx.HTTPStatusError): + if exc.response.status_code in codes: + return True + if above and exc.response.status_code >= above: + return True + return False + + return retry_if_exception(predicate) + + +def raise_for_status(response: httpx.Response) -> httpx.Response: + """Raise with reason text.""" + try: + response.raise_for_status() + return response + except httpx.HTTPStatusError as e: + response.read() + if response.text: + raise httpx.HTTPStatusError(f"{e}\n{response.text}", request=e.request, response=e.response) + else: + raise e + + +def encode_image(image_path: str) -> str: + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def messages_to_openai_spec(messages: List[Message]) -> List[Dict[str, Any]]: + messages_spec = [] + for message in messages: + converted = {"role": message.role} + output = [] + for content in message.content: + if isinstance(content, Text): + converted["content"] = content.text + elif isinstance(content, ToolUse): + sanitized_name = re.sub(r"[^a-zA-Z0-9_-]", "_", content.name) + converted.setdefault("tool_calls", []).append( + { + "id": content.id, + "type": "function", + "function": { + "name": sanitized_name, + "arguments": json.dumps(content.parameters), + }, + } + ) + elif isinstance(content, ToolResult): + if content.output.startswith('"image:'): + image_path = content.output.replace('"image:', "").replace('"', "") + output.append( + { + "role": "tool", + "content": [ + { + "type": "text", + "text": "This tool result included an image that is uploaded in the next message.", + }, + ], + "tool_call_id": content.tool_use_id, + } + ) + # Note: it is possible to only do this when message == messages[-1] + # but it doesn't seem to hurt too much with tokens to keep this. + output.append( + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{encode_image(image_path)}"}, + } + ], + } + ) + + else: + output.append( + { + "role": "tool", + "content": content.output, + "tool_call_id": content.tool_use_id, + } + ) + + if "content" in converted or "tool_calls" in converted: + output = [converted] + output + messages_spec.extend(output) + return messages_spec + + +def tools_to_openai_spec(tools: Tuple[Tool]) -> Dict[str, Any]: + tools_names = set() + result = [] + for tool in tools: + if tool.name in tools_names: + # we should never allow duplicate tools + raise ValueError(f"Duplicate tool name: {tool.name}") + result.append( + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + ) + tools_names.add(tool.name) + return result + + +def openai_response_to_message(response: dict) -> Message: + original = response["choices"][0]["message"] + content = [] + text = original.get("content") + if text: + content.append(Text(text=text)) + + tool_calls = original.get("tool_calls") + if tool_calls: + for tool_call in tool_calls: + try: + function_name = tool_call["function"]["name"] + # We occasionally see the model generate an invalid function name + # sending this back to openai raises a validation error + if not re.match(r"^[a-zA-Z0-9_-]+$", function_name): + content.append( + ToolUse( + id=tool_call["id"], + name=function_name, + parameters=tool_call["function"]["arguments"], + is_error=True, + error_message=f"The provided function name '{function_name}' had invalid characters, it must match this regex [a-zA-Z0-9_-]+", # noqa: E501 + ) + ) + else: + content.append( + ToolUse( + id=tool_call["id"], + name=function_name, + parameters=json.loads(tool_call["function"]["arguments"]), + ) + ) + except json.JSONDecodeError: + content.append( + ToolUse( + id=tool_call["id"], + name=tool_call["function"]["name"], + parameters=tool_call["function"]["arguments"], + is_error=True, + error_message=f"Could not interpret tool use parameters for id {tool_call['id']}: {tool_call['function']['arguments']}", # noqa: E501 + ) + ) + + return Message(role="assistant", content=content) + + +def openai_single_message_context_length_exceeded(error_dict: dict) -> None: + code = error_dict.get("code") + if code == "context_length_exceeded" or code == "string_above_max_length": + raise InitialMessageTooLargeError(f"Input message too long. Message: {error_dict.get('message')}") + + +class InitialMessageTooLargeError(Exception): + """Custom error raised when the first input message in an exchange is too large.""" + + pass diff --git a/packages/exchange/src/exchange/token_usage_collector.py b/packages/exchange/src/exchange/token_usage_collector.py new file mode 100644 index 000000000..8f0801062 --- /dev/null +++ b/packages/exchange/src/exchange/token_usage_collector.py @@ -0,0 +1,27 @@ +from collections import defaultdict +from typing import Dict + +from exchange.providers.base import Usage + + +class _TokenUsageCollector: + def __init__(self) -> None: + self.usage_data = [] + + def collect(self, model: str, usage: Usage) -> None: + self.usage_data.append((model, usage)) + + def get_token_usage_group_by_model(self) -> Dict[str, Usage]: + usage_group_by_model = defaultdict(lambda: Usage(0, 0, 0)) + for model, usage in self.usage_data: + usage_by_model = usage_group_by_model[model] + if usage is not None and usage.input_tokens is not None: + usage_by_model.input_tokens += usage.input_tokens + if usage is not None and usage.output_tokens is not None: + usage_by_model.output_tokens += usage.output_tokens + if usage is not None and usage.total_tokens is not None: + usage_by_model.total_tokens += usage.total_tokens + return usage_group_by_model + + +_token_usage_collector = _TokenUsageCollector() diff --git a/packages/exchange/src/exchange/tool.py b/packages/exchange/src/exchange/tool.py new file mode 100644 index 000000000..4ce9e7c50 --- /dev/null +++ b/packages/exchange/src/exchange/tool.py @@ -0,0 +1,55 @@ +import inspect +from typing import Any, Callable, Type + +from attrs import define + +from exchange.utils import json_schema, parse_docstring + + +@define +class Tool: + """A tool that can be used by a model. + + Attributes: + name (str): The name of the tool + description (str): A description of what the tool does + parameters dict[str, Any]: A json schema of the function signature + function (Callable): The python function that powers the tool + """ + + name: str + description: str + parameters: dict[str, Any] + function: Callable + + @classmethod + def from_function(cls: Type["Tool"], func: Any) -> "Tool": # noqa: ANN401 + """Create a tool instance from a function and its docstring + + The function must have a docstring - we require it to load the description + and parameter descriptions. This also supports a class instance with a __call__ + method. + """ + if inspect.isfunction(func) or inspect.ismethod(func): + name = func.__name__ + else: + name = func.__class__.__name__.lower() + func = func.__call__ + + description, param_descriptions = parse_docstring(func) + schema = json_schema(func) + + # Set the 'description' field of the schema to the arg's docstring description + for arg in param_descriptions: + arg_name, arg_description = arg["name"], arg["description"] + + if arg_name not in schema["properties"]: + raise ValueError(f"Argument {arg_name} found in docstring but not in schema") + schema["properties"][arg_name]["description"] = arg_description + + return cls( + name=name, + description=description, + parameters=schema, + function=func, + ) diff --git a/packages/exchange/src/exchange/utils.py b/packages/exchange/src/exchange/utils.py new file mode 100644 index 000000000..04d5ffa18 --- /dev/null +++ b/packages/exchange/src/exchange/utils.py @@ -0,0 +1,155 @@ +import inspect +import uuid +from importlib.metadata import entry_points +from typing import Any, Callable, Dict, List, Type, get_args, get_origin + +from griffe import ( + Docstring, + DocstringSection, + DocstringSectionParameters, + DocstringSectionText, +) + + +def create_object_id(prefix: str) -> str: + return f"{prefix}_{uuid.uuid4().hex[:24]}" + + +def compact(content: str) -> str: + """Replace any amount of whitespace with a single space""" + return " ".join(content.split()) + + +def parse_docstring(func: Callable) -> tuple[str, List[Dict]]: + """Get description and parameters from function docstring""" + function_args = list(inspect.signature(func).parameters.keys()) + text = str(func.__doc__) + docstring = Docstring(text) + + for style in ["google", "numpy", "sphinx"]: + parsed = docstring.parse(style) + + if not _check_section_is_present(parsed, DocstringSectionText): + continue + + if function_args and not _check_section_is_present(parsed, DocstringSectionParameters): + continue + break + else: # if we did not find a valid style in the for loop + raise ValueError( + f"Attempted to load from a function {func.__name__} with an invalid docstring. Parameter docs are required if the function has parameters. https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings" # noqa: E501 + ) + + description = None + parameters = [] + + for section in parsed: + if isinstance(section, DocstringSectionText): + description = compact(section.value) + elif isinstance(section, DocstringSectionParameters): + parameters = [arg.as_dict() for arg in section.value] + + docstring_args = [d["name"] for d in parameters] + if description is None: + raise ValueError("Docstring must include a description.") + + if not docstring_args == function_args: + extra_docstring_args = ", ".join(sorted(set(docstring_args) - set(function_args))) + extra_function_args = ", ".join(sorted(set(function_args) - set(docstring_args))) + if extra_docstring_args and extra_function_args: + raise ValueError( + f"Docstring args must match function args: docstring had extra {extra_docstring_args}; function had extra {extra_function_args}" # noqa: E501 + ) + elif extra_function_args: + raise ValueError(f"Docstring args must match function args: function had extra {extra_function_args}") + elif extra_docstring_args: + raise ValueError(f"Docstring args must match function args: docstring had extra {extra_docstring_args}") + else: + raise ValueError("Docstring args must match function args") + + return description, parameters + + +def _check_section_is_present( + parsed_docstring: List[DocstringSection], section_type: Type[DocstringSectionText] +) -> bool: + for section in parsed_docstring: + if isinstance(section, section_type): + return True + return False + + +def json_schema(func: Any) -> dict[str, Any]: # noqa: ANN401 + """Get the json schema for a function""" + signature = inspect.signature(func) + parameters = signature.parameters + + schema = { + "type": "object", + "properties": {}, + "required": [], + } + + for param_name, param in parameters.items(): + param_schema = {} + + if param.annotation is not inspect.Parameter.empty: + param_schema = _map_type_to_schema(param.annotation) + + if param.default is not inspect.Parameter.empty: + param_schema["default"] = param.default + + schema["properties"][param_name] = param_schema + + if param.default is inspect.Parameter.empty: + schema["required"].append(param_name) + + return schema + + +def _map_type_to_schema(py_type: Type) -> Dict[str, Any]: # noqa: ANN401 + origin = get_origin(py_type) + args = get_args(py_type) + + if origin is list or origin is tuple: + return {"type": "array", "items": _map_type_to_schema(args[0] if args else Any)} + elif origin is dict: + return { + "type": "object", + "additionalProperties": _map_type_to_schema(args[1] if len(args) > 1 else Any), + } + elif py_type is int: + return {"type": "integer"} + elif py_type is bool: + return {"type": "boolean"} + elif py_type is float: + return {"type": "number"} + elif py_type is str: + return {"type": "string"} + else: + return {"type": "string"} + + +def load_plugins(group: str) -> dict: + """ + Load plugins based on a specified entry point group. + + This function iterates through all entry points registered under a specified group + + Args: + group (str): The entry point group to load plugins from. This should match the group specified + in the package setup where plugins are defined. + + Returns: + dict: A dictionary where each key is the entry point name, and the value is the loaded plugin object. + + Raises: + Exception: Propagates exceptions raised by entry point loading, which might occur if a plugin + is not found or if there are issues with the plugin's code. + """ + plugins = {} + # Access all entry points for the specified group and load each. + for entrypoint in entry_points(group=group): + plugin = entrypoint.load() # Load the plugin. + plugins[entrypoint.name] = plugin # Store the loaded plugin in the dictionary. + return plugins diff --git a/packages/exchange/tests/.ruff.toml b/packages/exchange/tests/.ruff.toml new file mode 100644 index 000000000..cddf42337 --- /dev/null +++ b/packages/exchange/tests/.ruff.toml @@ -0,0 +1,2 @@ +lint.select = ["E", "W", "F", "N"] +line-length = 120 \ No newline at end of file diff --git a/packages/exchange/tests/__init__.py b/packages/exchange/tests/__init__.py new file mode 100644 index 000000000..c2b89ac6d --- /dev/null +++ b/packages/exchange/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for exchange.""" diff --git a/packages/exchange/tests/conftest.py b/packages/exchange/tests/conftest.py new file mode 100644 index 000000000..684a446d7 --- /dev/null +++ b/packages/exchange/tests/conftest.py @@ -0,0 +1,36 @@ +import pytest + +from exchange.providers.base import Usage + + +@pytest.fixture +def dummy_tool(): + def _dummy_tool() -> str: + """An example tool""" + return "dummy response" + + return _dummy_tool + + +@pytest.fixture +def usage_factory(): + def _create_usage(input_tokens=100, output_tokens=200, total_tokens=300): + return Usage(input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=total_tokens) + + return _create_usage + + +def read_file(filename: str) -> str: + """ + Read the contents of the file. + + Args: + filename (str): The path to the file, which can be relative or + absolute. If it is a plain filename, it is assumed to be in the + current working directory. + + Returns: + str: The contents of the file. + """ + assert filename == "test.txt" + return "hello exchange" diff --git a/packages/exchange/tests/providers/__init__.py b/packages/exchange/tests/providers/__init__.py new file mode 100644 index 000000000..4e13a800d --- /dev/null +++ b/packages/exchange/tests/providers/__init__.py @@ -0,0 +1 @@ +"""Tests for chat completion providers.""" diff --git a/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml b/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml new file mode 100644 index 000000000..3ac8a4fc0 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_azure_complete.yaml @@ -0,0 +1,68 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + api-key: + - test_azure_api_key + connection: + - keep-alive + content-length: + - '139' + content-type: + - application/json + host: + - test.openai.azure.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview + response: + body: + string: '{"choices":[{"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}},"finish_reason":"stop","index":0,"logprobs":null,"message":{"content":"Hello! + How can I assist you today?","role":"assistant"}}],"created":1727230065,"id":"chatcmpl-ABBjN3AoYlxkP7Vg2lBvUhYeA6j5K","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":9,"prompt_tokens":18,"total_tokens":27}} + + ' + headers: + Cache-Control: + - no-cache, must-revalidate + Content-Length: + - '825' + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 02:07:45 GMT + Set-Cookie: test_set_cookie + Strict-Transport-Security: + - max-age=31536000; includeSubDomains; preload + access-control-allow-origin: + - '*' + apim-request-id: + - 82e66ef8-ac07-4a43-b60f-9aecec1d8c81 + azureml-model-session: + - d145-20240919052126 + openai-organization: test_openai_org_key + x-accel-buffering: + - 'no' + x-content-type-options: + - nosniff + x-ms-client-request-id: + - 82e66ef8-ac07-4a43-b60f-9aecec1d8c81 + x-ms-rai-invoked: + - 'true' + x-ms-region: + - Switzerland North + x-ratelimit-remaining-requests: + - '79' + x-ratelimit-remaining-tokens: + - '79984' + x-request-id: + - 38db9001-8b16-4efe-84c9-620e10f18c3c + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml b/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml new file mode 100644 index 000000000..9da479790 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_azure_tools.yaml @@ -0,0 +1,74 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. + Expect to need to read a file using read_file."}, {"role": "user", "content": + "What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools": + [{"type": "function", "function": {"name": "read_file", "description": "Read + the contents of the file.", "parameters": {"type": "object", "properties": {"filename": + {"type": "string", "description": "The path to the file, which can be relative + or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent + working directory."}}, "required": ["filename"]}}}]}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + api-key: + - test_azure_api_key + connection: + - keep-alive + content-length: + - '608' + content-type: + - application/json + host: + - test.openai.azure.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://test.openai.azure.com/openai/deployments/test-azure-deployment/chat/completions?api-version=2024-05-01-preview + response: + body: + string: '{"choices":[{"content_filter_results":{},"finish_reason":"tool_calls","index":0,"logprobs":null,"message":{"content":null,"role":"assistant","tool_calls":[{"function":{"arguments":"{\n \"filename\": + \"test.txt\"\n}","name":"read_file"},"id":"call_a47abadDxlGKIWjvYYvGVAHa","type":"function"}]}}],"created":1727256650,"id":"chatcmpl-ABIeABbq5WVCq0e0AriGFaYDSih3P","model":"gpt-4-32k","object":"chat.completion","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"system_fingerprint":null,"usage":{"completion_tokens":16,"prompt_tokens":109,"total_tokens":125}} + + ' + headers: + Cache-Control: + - no-cache, must-revalidate + Content-Length: + - '769' + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 09:30:50 GMT + Set-Cookie: test_set_cookie + Strict-Transport-Security: + - max-age=31536000; includeSubDomains; preload + access-control-allow-origin: + - '*' + apim-request-id: + - 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339 + azureml-model-session: + - d145-20240919052126 + openai-organization: test_openai_org_key + x-accel-buffering: + - 'no' + x-content-type-options: + - nosniff + x-ms-client-request-id: + - 8c0e3372-8ffd-4ff5-a5d1-0b962c4ea339 + x-ms-rai-invoked: + - 'true' + x-ms-region: + - Switzerland North + x-ratelimit-remaining-requests: + - '79' + x-ratelimit-remaining-tokens: + - '79824' + x-request-id: + - 401bd803-b790-47b7-b098-98708d44f060 + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml b/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml new file mode 100644 index 000000000..88bc206ff --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_ollama_complete.yaml @@ -0,0 +1,68 @@ +interactions: +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: GET + uri: http://localhost:11434/ + response: + body: + string: Ollama is running + headers: + Content-Length: + - '17' + Content-Type: + - text/plain; charset=utf-8 + Date: + - Sun, 22 Sep 2024 23:40:13 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}], "model": "mistral-nemo"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '140' + content-type: + - application/json + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: POST + uri: http://localhost:11434/v1/chat/completions + response: + body: + string: "{\"id\":\"chatcmpl-429\",\"object\":\"chat.completion\",\"created\":1727048416,\"model\":\"mistral-nemo\",\"system_fingerprint\":\"fp_ollama\",\"choices\":[{\"index\":0,\"message\":{\"role\":\"assistant\",\"content\":\"Hello! + I'm here to help. How can I assist you today? Let's chat. \U0001F60A\"},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":23,\"total_tokens\":33}}\n" + headers: + Content-Length: + - '356' + Content-Type: + - application/json + Date: + - Sun, 22 Sep 2024 23:40:16 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml b/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml new file mode 100644 index 000000000..7271bf227 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_ollama_tools.yaml @@ -0,0 +1,75 @@ +interactions: +- request: + body: '' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: GET + uri: http://localhost:11434/ + response: + body: + string: Ollama is running + headers: + Content-Length: + - '17' + Content-Type: + - text/plain; charset=utf-8 + Date: + - Wed, 25 Sep 2024 09:23:08 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. + Expect to need to read a file using read_file."}, {"role": "user", "content": + "What are the contents of this file? test.txt"}], "model": "mistral-nemo", "tools": + [{"type": "function", "function": {"name": "read_file", "description": "Read + the contents of the file.", "parameters": {"type": "object", "properties": {"filename": + {"type": "string", "description": "The path to the file, which can be relative + or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent + working directory."}}, "required": ["filename"]}}}]}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '609' + content-type: + - application/json + host: + - localhost:11434 + user-agent: + - python-httpx/0.27.2 + method: POST + uri: http://localhost:11434/v1/chat/completions + response: + body: + string: '{"id":"chatcmpl-245","object":"chat.completion","created":1727256190,"model":"mistral-nemo","system_fingerprint":"fp_ollama","choices":[{"index":0,"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_z6fgu3z3","type":"function","function":{"name":"read_file","arguments":"{\"filename\":\"test.txt\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":112,"completion_tokens":21,"total_tokens":133}} + + ' + headers: + Content-Length: + - '425' + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 09:23:10 GMT + Set-Cookie: test_set_cookie + openai-organization: test_openai_org_key + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml b/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml new file mode 100644 index 000000000..1a92eb36b --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_openai_complete.yaml @@ -0,0 +1,80 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}], "model": "gpt-4o-mini"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + authorization: + - Bearer test_openai_api_key + connection: + - keep-alive + content-length: + - '139' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"id\": \"chatcmpl-AAQTYi3DXJnltAfd5sUH1Wnzh69t3\",\n \"object\": + \"chat.completion\",\n \"created\": 1727048416,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": \"Hello! How can I assist you today?\",\n + \ \"refusal\": null\n },\n \"logprobs\": null,\n \"finish_reason\": + \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": 18,\n \"completion_tokens\": + 9,\n \"total_tokens\": 27,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": + 0\n }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8c762399feb55739-SYD + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Sun, 22 Sep 2024 23:40:17 GMT + Server: + - cloudflare + Set-Cookie: test_set_cookie + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + content-length: + - '593' + openai-organization: test_openai_org_key + openai-processing-ms: + - '560' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=15552000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '200000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '199973' + x-ratelimit-reset-requests: + - 8.64s + x-ratelimit-reset-tokens: + - 8ms + x-request-id: + - req_22e26c840219cde3152eaba1ce89483b + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml b/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml new file mode 100644 index 000000000..30496fcb8 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_openai_tools.yaml @@ -0,0 +1,90 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant. + Expect to need to read a file using read_file."}, {"role": "user", "content": + "What are the contents of this file? test.txt"}], "model": "gpt-4o-mini", "tools": + [{"type": "function", "function": {"name": "read_file", "description": "Read + the contents of the file.", "parameters": {"type": "object", "properties": {"filename": + {"type": "string", "description": "The path to the file, which can be relative + or\nabsolute. If it is a plain filename, it is assumed to be in the\ncurrent + working directory."}}, "required": ["filename"]}}}]}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + authorization: + - Bearer test_openai_api_key + connection: + - keep-alive + content-length: + - '608' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"id\": \"chatcmpl-ABIV2aZWVKQ774RAQ8KHYdNwkI5N7\",\n \"object\": + \"chat.completion\",\n \"created\": 1727256084,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n + \ \"id\": \"call_xXYlw4A7Ud1qtCopuK5gEJrP\",\n \"type\": + \"function\",\n \"function\": {\n \"name\": \"read_file\",\n + \ \"arguments\": \"{\\\"filename\\\":\\\"test.txt\\\"}\"\n }\n + \ }\n ],\n \"refusal\": null\n },\n \"logprobs\": + null,\n \"finish_reason\": \"tool_calls\"\n }\n ],\n \"usage\": + {\n \"prompt_tokens\": 107,\n \"completion_tokens\": 15,\n \"total_tokens\": + 122,\n \"completion_tokens_details\": {\n \"reasoning_tokens\": 0\n + \ }\n },\n \"system_fingerprint\": \"fp_1bb46167f9\"\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8c89f19fed997e43-SYD + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 09:21:25 GMT + Server: + - cloudflare + Set-Cookie: test_set_cookie + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + content-length: + - '844' + openai-organization: test_openai_org_key + openai-processing-ms: + - '266' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '200000' + x-ratelimit-remaining-requests: + - '9991' + x-ratelimit-remaining-tokens: + - '199952' + x-ratelimit-reset-requests: + - 1m9.486s + x-ratelimit-reset-tokens: + - 14ms + x-request-id: + - req_ff6b5d65c24f40e1faaf049c175e718d + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml b/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml new file mode 100644 index 000000000..1b9691d29 --- /dev/null +++ b/packages/exchange/tests/providers/cassettes/test_openai_vision.yaml @@ -0,0 +1,86 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What does the first entry in the menu say?"}, {"role": + "assistant", "tool_calls": [{"id": "xyz", "type": "function", "function": {"name": + "screenshot", "arguments": "{}"}}]}, {"role": "tool", "content": [{"type": "text", + "text": "This tool result included an image that is uploaded in the next message."}], + "tool_call_id": "xyz"}, {"role": "user", "content": [{"type": "image_url", "image_url": + {"url": ""}}]}], + "model": "gpt-4o-mini"}' + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + authorization: + - Bearer test_openai_api_key + connection: + - keep-alive + content-length: + - '78932' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - python-httpx/0.27.2 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: "{\n \"id\": \"chatcmpl-ABIA0YzOHlhqb02K8Ay4Jwsw6xOpk\",\n \"object\": + \"chat.completion\",\n \"created\": 1727254780,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": \"The first entry in the menu says \\\"Ask + Goose.\\\"\",\n \"refusal\": null\n },\n \"logprobs\": null,\n + \ \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\": + 14230,\n \"completion_tokens\": 11,\n \"total_tokens\": 14241,\n \"completion_tokens_details\": + {\n \"reasoning_tokens\": 0\n }\n },\n \"system_fingerprint\": \"fp_e9627b5346\"\n}\n" + headers: + CF-Cache-Status: + - DYNAMIC + CF-RAY: + - 8c89d1c45d98a883-SYD + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Wed, 25 Sep 2024 08:59:41 GMT + Server: + - cloudflare + Set-Cookie: test_set_cookie + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + content-length: + - '613' + openai-organization: test_openai_org_key + openai-processing-ms: + - '1289' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '10000' + x-ratelimit-limit-tokens: + - '200000' + x-ratelimit-remaining-requests: + - '9999' + x-ratelimit-remaining-tokens: + - '199177' + x-ratelimit-reset-requests: + - 8.64s + x-ratelimit-reset-tokens: + - 246ms + x-request-id: + - req_9503b21e31db78c4ebd2b71b304cea72 + status: + code: 200 + message: OK +version: 1 diff --git a/packages/exchange/tests/providers/conftest.py b/packages/exchange/tests/providers/conftest.py new file mode 100644 index 000000000..010504e84 --- /dev/null +++ b/packages/exchange/tests/providers/conftest.py @@ -0,0 +1,131 @@ +import os +import re +from typing import Type, Tuple + +import pytest + +from exchange import Message, ToolUse, ToolResult, Tool +from exchange.providers import Usage, Provider +from tests.conftest import read_file + +OPENAI_API_KEY = "test_openai_api_key" +OPENAI_ORG_ID = "test_openai_org_key" +OPENAI_PROJECT_ID = "test_openai_project_id" + + +@pytest.fixture +def default_openai_env(monkeypatch): + """ + This fixture prevents OpenAIProvider.from_env() from erring on missing + environment variables. + + When running VCR tests for the first time or after deleting a cassette + recording, set required environment variables, so that real requests don't + fail. Subsequent runs use the recorded data, so don't need them. + """ + if "OPENAI_API_KEY" not in os.environ: + monkeypatch.setenv("OPENAI_API_KEY", OPENAI_API_KEY) + + +AZURE_ENDPOINT = "https://test.openai.azure.com" +AZURE_DEPLOYMENT_NAME = "test-azure-deployment" +AZURE_API_VERSION = "2024-05-01-preview" +AZURE_API_KEY = "test_azure_api_key" + + +@pytest.fixture +def default_azure_env(monkeypatch): + """ + This fixture prevents AzureProvider.from_env() from erring on missing + environment variables. + + When running VCR tests for the first time or after deleting a cassette + recording, set required environment variables, so that real requests don't + fail. Subsequent runs use the recorded data, so don't need them. + """ + if "AZURE_CHAT_COMPLETIONS_HOST_NAME" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_HOST_NAME", AZURE_ENDPOINT) + if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME", AZURE_DEPLOYMENT_NAME) + if "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION", AZURE_API_VERSION) + if "AZURE_CHAT_COMPLETIONS_KEY" not in os.environ: + monkeypatch.setenv("AZURE_CHAT_COMPLETIONS_KEY", AZURE_API_KEY) + + +@pytest.fixture(scope="module") +def vcr_config(): + """ + This scrubs sensitive data and gunzips bodies when in recording mode. + + Without this, you would leak cookies and auth tokens in the cassettes. + Also, depending on the request, some responses would be binary encoded + while others plain json. This ensures all bodies are human-readable. + """ + return { + "decode_compressed_response": True, + "filter_headers": [ + ("authorization", "Bearer " + OPENAI_API_KEY), + ("openai-organization", OPENAI_ORG_ID), + ("openai-project", OPENAI_PROJECT_ID), + ("cookie", None), + ], + "before_record_request": scrub_request_url, + "before_record_response": scrub_response_headers, + } + + +def scrub_request_url(request): + """ + This scrubs sensitive request data in provider-specific way. Note that headers + are case-sensitive! + """ + if "openai.azure.com" in request.uri: + request.uri = re.sub(r"https://[^/]+", AZURE_ENDPOINT, request.uri) + request.uri = re.sub(r"/deployments/[^/]+", f"/deployments/{AZURE_DEPLOYMENT_NAME}", request.uri) + request.headers["host"] = AZURE_ENDPOINT.replace("https://", "") + request.headers["api-key"] = AZURE_API_KEY + + return request + + +def scrub_response_headers(response): + """ + This scrubs sensitive response headers. Note they are case-sensitive! + """ + response["headers"]["openai-organization"] = OPENAI_ORG_ID + response["headers"]["Set-Cookie"] = "test_set_cookie" + return response + + +def complete(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: + provider = provider_cls.from_env() + system = "You are a helpful assistant." + messages = [Message.user("Hello")] + return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs) + + +def tools(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: + provider = provider_cls.from_env() + system = "You are a helpful assistant. Expect to need to read a file using read_file." + messages = [Message.user("What are the contents of this file? test.txt")] + return provider.complete( + model=model, system=system, messages=messages, tools=(Tool.from_function(read_file),), **kwargs + ) + + +def vision(provider_cls: Type[Provider], model: str, **kwargs) -> Tuple[Message, Usage]: + provider = provider_cls.from_env() + system = "You are a helpful assistant." + messages = [ + Message.user("What does the first entry in the menu say?"), + Message( + role="assistant", + content=[ToolUse(id="xyz", name="screenshot", parameters={})], + ), + Message( + role="user", + content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')], + ), + ] + return provider.complete(model=model, system=system, messages=messages, tools=None, **kwargs) diff --git a/packages/exchange/tests/providers/test_anthropic.py b/packages/exchange/tests/providers/test_anthropic.py new file mode 100644 index 000000000..272ebcb0f --- /dev/null +++ b/packages/exchange/tests/providers/test_anthropic.py @@ -0,0 +1,184 @@ +import os +from unittest.mock import patch + +import httpx +import pytest +from exchange import Message, Text +from exchange.content import ToolResult, ToolUse +from exchange.providers.anthropic import AnthropicProvider +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.tool import Tool + + +def example_fn(param: str) -> None: + """ + Testing function. + + Args: + param (str): Description of param1 + """ + pass + + +@pytest.fixture +@patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test_api_key"}) +def anthropic_provider(): + return AnthropicProvider.from_env() + + +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + AnthropicProvider.from_env() + assert context.value.provider == "anthropic" + assert context.value.env_variable == "ANTHROPIC_API_KEY" + assert context.value.message == "Missing environment variable: ANTHROPIC_API_KEY for provider anthropic." + + +def test_anthropic_response_to_text_message() -> None: + response = { + "content": [{"type": "text", "text": "Hello from Claude!"}], + } + message = AnthropicProvider.anthropic_response_to_message(response) + assert message.content[0].text == "Hello from Claude!" + + +def test_anthropic_response_to_tool_use_message() -> None: + response = { + "content": [ + { + "type": "tool_use", + "id": "1", + "name": "example_fn", + "input": {"param": "value"}, + } + ], + } + message = AnthropicProvider.anthropic_response_to_message(response) + assert message.content[0].id == "1" + assert message.content[0].name == "example_fn" + assert message.content[0].parameters == {"param": "value"} + + +def test_tools_to_anthropic_spec() -> None: + tools = (Tool.from_function(example_fn),) + expected_spec = [ + { + "name": "example_fn", + "description": "Testing function.", + "input_schema": { + "type": "object", + "properties": {"param": {"type": "string", "description": "Description of param1"}}, + "required": ["param"], + }, + } + ] + result = AnthropicProvider.tools_to_anthropic_spec(tools) + assert result == expected_spec + + +def test_message_text_to_anthropic_spec() -> None: + messages = [Message.user("Hello, Claude")] + expected_spec = [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello, Claude"}], + } + ] + result = AnthropicProvider.messages_to_anthropic_spec(messages) + assert result == expected_spec + + +def test_messages_to_anthropic_spec() -> None: + messages = [ + Message(role="user", content=[Text(text="Hello, Claude")]), + Message( + role="assistant", + content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]), + ] + actual_spec = AnthropicProvider.messages_to_anthropic_spec(messages) + # != + expected_spec = [ + {"role": "user", "content": [{"type": "text", "text": "Hello, Claude"}]}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "1", + "name": "example_fn", + "input": {"param": "value"}, + } + ], + }, + { + "role": "user", + "content": [{"type": "tool_result", "tool_use_id": "1", "content": "Result"}], + }, + ] + assert actual_spec == expected_spec + + +@patch("httpx.Client.post") +@patch("logging.warning") +@patch("logging.error") +def test_anthropic_completion(mock_error, mock_warning, mock_post, anthropic_provider): + mock_response = { + "content": [{"type": "text", "text": "Hello from Claude!"}], + "usage": {"input_tokens": 10, "output_tokens": 25}, + } + + # First attempts fail with status code 429, 2nd succeeds + def create_response(status_code, json_data=None): + response = httpx.Response(status_code) + response._content = httpx._content.json_dumps(json_data or {}).encode() + response._request = httpx.Request("POST", "https://api.anthropic.com/v1/messages") + return response + + mock_post.side_effect = [ + create_response(429), # 1st attempt + create_response(200, mock_response), # Final success + ] + + model = "claude-3-5-sonnet-20240620" + system = "You are a helpful assistant." + messages = [Message.user("Hello, Claude")] + + reply_message, reply_usage = anthropic_provider.complete(model=model, system=system, messages=messages) + + assert reply_message.content == [Text(text="Hello from Claude!")] + assert reply_usage.total_tokens == 35 + assert mock_post.call_count == 2 + mock_post.assert_any_call( + "https://api.anthropic.com/v1/messages", + json={ + "system": system, + "model": model, + "max_tokens": 4096, + "messages": [ + *[ + { + "role": msg.role, + "content": [{"type": "text", "text": msg.content[0].text}], + } + for msg in messages + ], + ], + }, + ) + + +@pytest.mark.integration +def test_anthropic_integration(): + provider = AnthropicProvider.from_env() + model = "claude-3-5-sonnet-20240620" # updated model to a known valid model + system = "You are a helpful assistant." + messages = [Message.user("Hello, Claude")] + + # Run the completion + reply = provider.complete(model=model, system=system, messages=messages) + + assert reply[0].content is not None + print("Completion content from Anthropic:", reply[0].content) diff --git a/packages/exchange/tests/providers/test_azure.py b/packages/exchange/tests/providers/test_azure.py new file mode 100644 index 000000000..b46be30b9 --- /dev/null +++ b/packages/exchange/tests/providers/test_azure.py @@ -0,0 +1,78 @@ +import os +from unittest.mock import patch + +import pytest + +from exchange import Text, ToolUse +from exchange.providers.azure import AzureProvider +from exchange.providers.base import MissingProviderEnvVariableError +from .conftest import complete, tools + +AZURE_MODEL = os.getenv("AZURE_MODEL", "gpt-4o-mini") + + +@pytest.mark.parametrize( + "env_var_name", + [ + ("AZURE_CHAT_COMPLETIONS_HOST_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME"), + ("AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION"), + ("AZURE_CHAT_COMPLETIONS_KEY"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AZURE_CHAT_COMPLETIONS_HOST_NAME": "test_host_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_NAME": "test_deployment_name", + "AZURE_CHAT_COMPLETIONS_DEPLOYMENT_API_VERSION": "test_api_version", + "AZURE_CHAT_COMPLETIONS_KEY": "test_api_key", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + AzureProvider.from_env() + assert context.value.provider == "azure" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider azure." + + +@pytest.mark.vcr() +def test_azure_complete(default_azure_env): + reply_message, reply_usage = complete(AzureProvider, AZURE_MODEL) + + assert reply_message.content == [Text(text="Hello! How can I assist you today?")] + assert reply_usage.total_tokens == 27 + + +@pytest.mark.integration +def test_azure_complete_integration(): + reply = complete(AzureProvider, AZURE_MODEL) + + assert reply[0].content is not None + print("Completion content from Azure:", reply[0].content) + + +@pytest.mark.vcr() +def test_azure_tools(default_azure_env): + reply_message, reply_usage = tools(AzureProvider, AZURE_MODEL) + + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "call_a47abadDxlGKIWjvYYvGVAHa" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 125 + + +@pytest.mark.integration +def test_azure_tools_integration(): + reply = tools(AzureProvider, AZURE_MODEL) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} diff --git a/packages/exchange/tests/providers/test_bedrock.py b/packages/exchange/tests/providers/test_bedrock.py new file mode 100644 index 000000000..f8fcaa4b8 --- /dev/null +++ b/packages/exchange/tests/providers/test_bedrock.py @@ -0,0 +1,255 @@ +import logging +import os +from unittest.mock import patch + +import pytest +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.providers.bedrock import BedrockProvider +from exchange.tool import Tool + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + "env_var_name", + [ + ("AWS_ACCESS_KEY_ID"), + ("AWS_SECRET_ACCESS_KEY"), + ("AWS_SESSION_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test_access_key_id", + "AWS_SECRET_ACCESS_KEY": "test_secret_access_key", + "AWS_SESSION_TOKEN": "test_session_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + BedrockProvider.from_env() + assert context.value.provider == "bedrock" + assert context.value.env_variable == env_var_name + assert context.value.message == f"Missing environment variable: {env_var_name} for provider bedrock." + + +@pytest.fixture +@patch.dict( + os.environ, + { + "AWS_REGION": "us-east-1", + "AWS_ACCESS_KEY_ID": "fake-access-key", + "AWS_SECRET_ACCESS_KEY": "fake-secret-key", + "AWS_SESSION_TOKEN": "fake-session-token", + }, +) +def bedrock_provider(): + return BedrockProvider.from_env() + + +@patch("time.time", return_value=1624250000) +def test_sign_and_get_headers(mock_time, bedrock_provider): + # Create sample values + method = "POST" + url = "https://bedrock-runtime.us-east-1.amazonaws.com/some/path" + payload = {"key": "value"} + service = "bedrock" + # Generate headers + headers = bedrock_provider.client.sign_and_get_headers( + method, + url, + payload, + service, + ) + # Assert that headers contain expected keys + assert "Authorization" in headers + assert "Content-Type" in headers + assert "X-Amz-date" in headers + assert "x-amz-content-sha256" in headers + assert "X-Amz-Security-Token" in headers + + +@patch("httpx.Client.post") +def test_complete(mock_post, bedrock_provider): + # Mocked response from the server + mock_response = { + "output": {"message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}}, + "usage": {"inputTokens": 10, "outputTokens": 15, "totalTokens": 25}, + } + mock_post.return_value.json.return_value = mock_response + + model = "test-model" + system = "You are a helpful assistant." + messages = [Message.user("Hello")] + tools = () + + reply_message, reply_usage = bedrock_provider.complete(model=model, system=system, messages=messages, tools=tools) + + # Assertions for reply message + assert reply_message.content[0].text == "Hello, world!" + assert reply_usage.total_tokens == 25 + + +def test_message_to_bedrock_spec_text(bedrock_provider): + message = Message(role="user", content=[Text("Hello, world!")]) + expected = {"role": "user", "content": [{"text": "Hello, world!"}]} + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_tool_use(bedrock_provider): + tool_use = ToolUse(id="tool-1", name="WordCount", parameters={"text": "Hello, world!"}) + message = Message(role="assistant", content=[tool_use]) + expected = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "WordCount", + "input": {"text": "Hello, world!"}, + } + } + ], + } + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_tool_result(bedrock_provider): + message = Message( + role="assistant", + content=[ToolUse(id="tool-1", name="WordCount", parameters={"text": "Hello, world!"})], + ) + expected = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "WordCount", + "input": {"text": "Hello, world!"}, + } + } + ], + } + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_tool_result_text(bedrock_provider): + tool_result = ToolResult(tool_use_id="tool-1", output="Error occurred", is_error=True) + message = Message(role="user", content=[tool_result]) + expected = { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tool-1", + "content": [{"text": "Error occurred"}], + "status": "error", + } + } + ], + } + assert bedrock_provider.message_to_bedrock_spec(message) == expected + + +def test_message_to_bedrock_spec_invalid(bedrock_provider): + with pytest.raises(Exception): + bedrock_provider.message_to_bedrock_spec(Message(role="user", content=[])) + + +def test_response_to_message_text(bedrock_provider): + response = {"role": "user", "content": [{"text": "Hello, world!"}]} + message = bedrock_provider.response_to_message(response) + assert message.role == "user" + assert message.content[0].text == "Hello, world!" + + +def test_response_to_message_tool_use(bedrock_provider): + response = { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tool-1", + "name": "WordCount", + "input": {"text": "Hello, world!"}, + } + } + ], + } + message = bedrock_provider.response_to_message(response) + assert message.role == "assistant" + assert message.content[0].name == "WordCount" + assert message.content[0].parameters == {"text": "Hello, world!"} + + +def test_response_to_message_tool_result(bedrock_provider): + response = { + "role": "user", + "content": [ + { + "toolResult": { + "toolResultId": "tool-1", + "content": [{"json": {"result": 2}}], + } + } + ], + } + message = bedrock_provider.response_to_message(response) + assert message.role == "user" + assert message.content[0].tool_use_id == "tool-1" + assert message.content[0].output == {"result": 2} + + +def test_response_to_message_invalid(bedrock_provider): + with pytest.raises(Exception): + bedrock_provider.response_to_message({}) + + +def test_tools_to_bedrock_spec(bedrock_provider): + def word_count(text: str): + return len(text.split()) + + tool = Tool( + name="WordCount", + description="Counts words.", + parameters={"text": "string"}, + function=word_count, + ) + expected = { + "tools": [ + { + "toolSpec": { + "name": "WordCount", + "description": "Counts words.", + "inputSchema": {"json": {"text": "string"}}, + } + } + ] + } + assert bedrock_provider.tools_to_bedrock_spec((tool,)) == expected + + +def test_tools_to_bedrock_spec_duplicate(bedrock_provider): + def word_count(text: str): + return len(text.split()) + + tool = Tool( + name="WordCount", + description="Counts words.", + parameters={"text": "string"}, + function=word_count, + ) + tool_duplicate = Tool( + name="WordCount", + description="Counts words.", + parameters={"text": "string"}, + function=word_count, + ) + tools = bedrock_provider.tools_to_bedrock_spec((tool, tool_duplicate)) + assert set(tool["toolSpec"]["name"] for tool in tools["tools"]) == {"WordCount"} diff --git a/packages/exchange/tests/providers/test_databricks.py b/packages/exchange/tests/providers/test_databricks.py new file mode 100644 index 000000000..4b6793abc --- /dev/null +++ b/packages/exchange/tests/providers/test_databricks.py @@ -0,0 +1,75 @@ +import os +from unittest.mock import patch + +import pytest +from exchange import Message, Text +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.providers.databricks import DatabricksProvider + + +@pytest.mark.parametrize( + "env_var_name", + [ + ("DATABRICKS_HOST"), + ("DATABRICKS_TOKEN"), + ], +) +def test_from_env_throw_error_when_missing_env_var(env_var_name): + with patch.dict( + os.environ, + { + "DATABRICKS_HOST": "test_host", + "DATABRICKS_TOKEN": "test_token", + }, + clear=True, + ): + os.environ.pop(env_var_name) + with pytest.raises(MissingProviderEnvVariableError) as context: + DatabricksProvider.from_env() + assert context.value.provider == "databricks" + assert context.value.env_variable == env_var_name + assert f"Missing environment variable: {env_var_name} for provider databricks" in context.value.message + assert "https://docs.databricks.com" in context.value.message + + +@pytest.fixture +@patch.dict( + os.environ, + {"DATABRICKS_HOST": "http://test-host", "DATABRICKS_TOKEN": "test_token"}, +) +def databricks_provider(): + return DatabricksProvider.from_env() + + +@patch("httpx.Client.post") +@patch("time.sleep", return_value=None) +@patch("logging.warning") +@patch("logging.error") +def test_databricks_completion(mock_error, mock_warning, mock_sleep, mock_post, databricks_provider): + mock_response = { + "choices": [{"message": {"role": "assistant", "content": "Hello!"}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35}, + } + mock_post.return_value.json.return_value = mock_response + + model = "my-databricks-model" + system = "You are a helpful assistant." + messages = [Message.user("Hello")] + tools = () + + reply_message, reply_usage = databricks_provider.complete( + model=model, system=system, messages=messages, tools=tools + ) + + assert reply_message.content == [Text(text="Hello!")] + assert reply_usage.total_tokens == 35 + assert mock_post.call_count == 1 + mock_post.assert_called_once_with( + "serving-endpoints/my-databricks-model/invocations", + json={ + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": "Hello"}, + ] + }, + ) diff --git a/packages/exchange/tests/providers/test_google.py b/packages/exchange/tests/providers/test_google.py new file mode 100644 index 000000000..76ae4c8d7 --- /dev/null +++ b/packages/exchange/tests/providers/test_google.py @@ -0,0 +1,158 @@ +import os +from unittest.mock import patch + +import httpx +import pytest +from exchange import Message, Text +from exchange.content import ToolResult, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.providers.google import GoogleProvider +from exchange.tool import Tool + + +def example_fn(param: str) -> None: + """ + Testing function. + + Args: + param (str): Description of param1 + """ + pass + + +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + GoogleProvider.from_env() + assert context.value.provider == "google" + assert context.value.env_variable == "GOOGLE_API_KEY" + assert "Missing environment variable: GOOGLE_API_KEY for provider google" in context.value.message + assert "https://ai.google.dev/gemini-api/docs/api-key" in context.value.message + + +@pytest.fixture +@patch.dict(os.environ, {"GOOGLE_API_KEY": "test_api_key"}) +def google_provider(): + return GoogleProvider.from_env() + + +def test_google_response_to_text_message() -> None: + response = {"candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}]} + message = GoogleProvider.google_response_to_message(response) + assert message.content[0].text == "Hello from Gemini!" + + +def test_google_response_to_tool_use_message() -> None: + response = { + "candidates": [ + { + "content": { + "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}], + "role": "model", + } + } + ] + } + + message = GoogleProvider.google_response_to_message(response) + assert message.content[0].name == "example_fn" + assert message.content[0].parameters == {"param": "value"} + + +def test_tools_to_google_spec() -> None: + tools = (Tool.from_function(example_fn),) + expected_spec = { + "functionDeclarations": [ + { + "name": "example_fn", + "description": "Testing function.", + "parameters": { + "type": "object", + "properties": {"param": {"type": "string", "description": "Description of param1"}}, + "required": ["param"], + }, + } + ] + } + result = GoogleProvider.tools_to_google_spec(tools) + assert result == expected_spec + + +def test_message_text_to_google_spec() -> None: + messages = [Message.user("Hello, Gemini")] + expected_spec = [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}] + result = GoogleProvider.messages_to_google_spec(messages) + assert result == expected_spec + + +def test_messages_to_google_spec() -> None: + messages = [ + Message(role="user", content=[Text(text="Hello, Gemini")]), + Message( + role="assistant", + content=[ToolUse(id="1", name="example_fn", parameters={"param": "value"})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="Result")]), + ] + actual_spec = GoogleProvider.messages_to_google_spec(messages) + # != + expected_spec = [ + {"role": "user", "parts": [{"text": "Hello, Gemini"}]}, + {"role": "model", "parts": [{"functionCall": {"name": "example_fn", "args": {"param": "value"}}}]}, + {"role": "user", "parts": [{"functionResponse": {"name": "1", "response": {"content": "Result"}}}]}, + ] + + assert actual_spec == expected_spec + + +@patch("httpx.Client.post") +@patch("logging.warning") +@patch("logging.error") +def test_google_completion(mock_error, mock_warning, mock_post, google_provider): + mock_response = { + "candidates": [{"content": {"parts": [{"text": "Hello from Gemini!"}], "role": "model"}}], + "usageMetadata": {"promptTokenCount": 3, "candidatesTokenCount": 10, "totalTokenCount": 13}, + } + + # First attempts fail with status code 429, 2nd succeeds + def create_response(status_code, json_data=None): + response = httpx.Response(status_code) + response._content = httpx._content.json_dumps(json_data or {}).encode() + response._request = httpx.Request("POST", "https://generativelanguage.googleapis.com/v1beta/") + return response + + mock_post.side_effect = [ + create_response(429), # 1st attempt + create_response(200, mock_response), # Final success + ] + + model = "gemini-1.5-flash" + system = "You are a helpful assistant." + messages = [Message.user("Hello, Gemini")] + + reply_message, reply_usage = google_provider.complete(model=model, system=system, messages=messages) + + assert reply_message.content == [Text(text="Hello from Gemini!")] + assert reply_usage.total_tokens == 13 + assert mock_post.call_count == 2 + mock_post.assert_any_call( + "models/gemini-1.5-flash:generateContent", + json={ + "system_instruction": {"parts": [{"text": "You are a helpful assistant."}]}, + "contents": [{"role": "user", "parts": [{"text": "Hello, Gemini"}]}], + }, + ) + + +@pytest.mark.integration +def test_google_integration(): + provider = GoogleProvider.from_env() + model = "gemini-1.5-flash" # updated model to a known valid model + system = "You are a helpful assistant." + messages = [Message.user("Hello, Gemini")] + + # Run the completion + reply = provider.complete(model=model, system=system, messages=messages) + + assert reply[0].content is not None + print("Completion content from Google:", reply[0].content) diff --git a/packages/exchange/tests/providers/test_ollama.py b/packages/exchange/tests/providers/test_ollama.py new file mode 100644 index 000000000..3ce870d36 --- /dev/null +++ b/packages/exchange/tests/providers/test_ollama.py @@ -0,0 +1,48 @@ +import os + +import pytest + +from exchange import Text, ToolUse +from exchange.providers.ollama import OllamaProvider, OLLAMA_MODEL +from .conftest import complete, tools + +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", OLLAMA_MODEL) + + +@pytest.mark.vcr() +def test_ollama_complete(): + reply_message, reply_usage = complete(OllamaProvider, OLLAMA_MODEL) + + assert reply_message.content == [Text(text="Hello! I'm here to help. How can I assist you today? Let's chat. 😊")] + assert reply_usage.total_tokens == 33 + + +@pytest.mark.integration +def test_ollama_complete_integration(): + reply = complete(OllamaProvider, OLLAMA_MODEL) + + assert reply[0].content is not None + print("Completion content from OpenAI:", reply[0].content) + + +@pytest.mark.vcr() +def test_ollama_tools(): + reply_message, reply_usage = tools(OllamaProvider, OLLAMA_MODEL) + + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "call_z6fgu3z3" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 133 + + +@pytest.mark.integration +def test_ollama_tools_integration(): + reply = tools(OllamaProvider, OLLAMA_MODEL, seed=3) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} diff --git a/packages/exchange/tests/providers/test_openai.py b/packages/exchange/tests/providers/test_openai.py new file mode 100644 index 000000000..ea979abeb --- /dev/null +++ b/packages/exchange/tests/providers/test_openai.py @@ -0,0 +1,75 @@ +import os +from unittest.mock import patch + +import pytest + +from exchange import Text, ToolUse +from exchange.providers.base import MissingProviderEnvVariableError +from exchange.providers.openai import OpenAiProvider +from .conftest import complete, vision, tools + +OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini") + + +def test_from_env_throw_error_when_missing_api_key(): + with patch.dict(os.environ, {}, clear=True): + with pytest.raises(MissingProviderEnvVariableError) as context: + OpenAiProvider.from_env() + assert context.value.provider == "openai" + assert context.value.env_variable == "OPENAI_API_KEY" + assert "Missing environment variable: OPENAI_API_KEY for provider openai" in context.value.message + assert "https://platform.openai.com" in context.value.message + + +@pytest.mark.vcr() +def test_openai_complete(default_openai_env): + reply_message, reply_usage = complete(OpenAiProvider, OPENAI_MODEL) + + assert reply_message.content == [Text(text="Hello! How can I assist you today?")] + assert reply_usage.total_tokens == 27 + + +@pytest.mark.integration +def test_openai_complete_integration(): + reply = complete(OpenAiProvider, OPENAI_MODEL) + + assert reply[0].content is not None + print("Completion content from OpenAI:", reply[0].content) + + +@pytest.mark.vcr() +def test_openai_tools(default_openai_env): + reply_message, reply_usage = tools(OpenAiProvider, OPENAI_MODEL) + + tool_use = reply_message.content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id == "call_xXYlw4A7Ud1qtCopuK5gEJrP" + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + assert reply_usage.total_tokens == 122 + + +@pytest.mark.integration +def test_openai_tools_integration(): + reply = tools(OpenAiProvider, OPENAI_MODEL) + + tool_use = reply[0].content[0] + assert isinstance(tool_use, ToolUse), f"Expected ToolUse, but was {type(tool_use).__name__}" + assert tool_use.id is not None + assert tool_use.name == "read_file" + assert tool_use.parameters == {"filename": "test.txt"} + + +@pytest.mark.vcr() +def test_openai_vision(default_openai_env): + reply_message, reply_usage = vision(OpenAiProvider, OPENAI_MODEL) + + assert reply_message.content == [Text(text='The first entry in the menu says "Ask Goose."')] + assert reply_usage.total_tokens == 14241 + + +@pytest.mark.integration +def test_openai_vision_integration(): + reply = vision(OpenAiProvider, OPENAI_MODEL) + + assert "ask goose" in reply[0].text.lower() diff --git a/packages/exchange/tests/providers/test_provider.py b/packages/exchange/tests/providers/test_provider.py new file mode 100644 index 000000000..fb7d15ce0 --- /dev/null +++ b/packages/exchange/tests/providers/test_provider.py @@ -0,0 +1,18 @@ +import pytest +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers import get_provider + + +def test_get_provider_valid(): + provider_name = "openai" + provider = get_provider(provider_name) + assert provider.__name__ == "OpenAiProvider" + + +def test_get_provider_throw_error_for_unknown_provider(): + with pytest.raises(InvalidChoiceError) as error: + get_provider("nonexistent") + assert error.value.attribute_name == "provider" + assert error.value.attribute_value == "nonexistent" + assert "openai" in error.value.available_values + assert "openai" in error.value.message diff --git a/packages/exchange/tests/providers/test_provider_utils.py b/packages/exchange/tests/providers/test_provider_utils.py new file mode 100644 index 000000000..5ad0135ea --- /dev/null +++ b/packages/exchange/tests/providers/test_provider_utils.py @@ -0,0 +1,245 @@ +from copy import deepcopy +import json +from unittest.mock import Mock +from attrs import asdict +import httpx +import pytest +from unittest.mock import patch + +from exchange.content import Text, ToolResult, ToolUse +from exchange.message import Message +from exchange.providers.utils import ( + messages_to_openai_spec, + openai_response_to_message, + raise_for_status, + tools_to_openai_spec, +) +from exchange.tool import Tool + +OPEN_AI_TOOL_USE_RESPONSE = response = { + "choices": [ + { + "role": "assistant", + "message": { + "tool_calls": [ + { + "id": "1", + "function": { + "name": "example_fn", + "arguments": json.dumps( + { + "param": "value", + } + ), + # TODO: should this handle dict's as well? + }, + } + ], + }, + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "total_tokens": 35, + }, +} + + +def example_fn(param: str) -> None: + """ + Testing function. + + Args: + param (str): Description of param1 + """ + pass + + +def example_fn_two() -> str: + """ + Second testing function + + Returns: + str: Description of return value + """ + pass + + +def test_raise_for_status_success() -> None: + response = Mock(spec=httpx.Response) + response.status_code = 200 + + result = raise_for_status(response) + + assert result == response + + +def test_raise_for_status_failure_with_text() -> None: + response = Mock(spec=httpx.Response) + response.status_code = 404 + response.text = "Not Found: John Cena" + + try: + raise_for_status(response) + except httpx.HTTPStatusError as e: + assert e.response == response + assert str(e) == "404 Not Found: John Cena" + assert e.request is None + + +def test_raise_for_status_failure_without_text() -> None: + response = Mock(spec=httpx.Response) + response.status_code = 500 + response.text = "" + + try: + raise_for_status(response) + except httpx.HTTPStatusError as e: + assert e.response == response + assert str(e) == "500 Internal Server Error" + assert e.request is None + + +def test_messages_to_openai_spec() -> None: + messages = [ + Message(role="assistant", content=[Text("Hello!")]), + Message(role="user", content=[Text("How are you?")]), + Message( + role="assistant", + content=[ToolUse(id=1, name="tool1", parameters={"param1": "value1"})], + ), + Message(role="user", content=[ToolResult(tool_use_id=1, output="Result")]), + ] + + spec = messages_to_openai_spec(messages) + + assert spec == [ + {"role": "assistant", "content": "Hello!"}, + {"role": "user", "content": "How are you?"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": 1, + "type": "function", + "function": { + "name": "tool1", + "arguments": '{"param1": "value1"}', + }, + } + ], + }, + { + "role": "tool", + "content": "Result", + "tool_call_id": 1, + }, + ] + + +def test_tools_to_openai_spec() -> None: + tools = (Tool.from_function(example_fn), Tool.from_function(example_fn_two)) + assert len(tools_to_openai_spec(tools)) == 2 + + +def test_tools_to_openai_spec_duplicate() -> None: + tools = (Tool.from_function(example_fn), Tool.from_function(example_fn)) + with pytest.raises(ValueError): + tools_to_openai_spec(tools) + + +def test_tools_to_openai_spec_single() -> None: + tools = Tool.from_function(example_fn) + expected_spec = [ + { + "type": "function", + "function": { + "name": "example_fn", + "description": "Testing function.", + "parameters": { + "type": "object", + "properties": { + "param": { + "type": "string", + "description": "Description of param1", + } + }, + "required": ["param"], + }, + }, + }, + ] + result = tools_to_openai_spec((tools,)) + assert result == expected_spec + + +def test_tools_to_openai_spec_empty() -> None: + tools = () + expected_spec = [] + assert tools_to_openai_spec(tools) == expected_spec + + +def test_openai_response_to_message_text() -> None: + response = { + "choices": [ + { + "role": "assistant", + "message": {"content": "Hello from John Cena!"}, + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 25, + "total_tokens": 35, + }, + } + + message = openai_response_to_message(response) + + actual = asdict(message) + expect = asdict( + Message( + role="assistant", + content=[Text("Hello from John Cena!")], + ) + ) + actual.pop("id") + expect.pop("id") + assert actual == expect + + +def test_openai_response_to_message_valid_tooluse() -> None: + response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE) + message = openai_response_to_message(response) + actual = asdict(message) + expect = asdict( + Message( + role="assistant", + content=[ToolUse(id=1, name="example_fn", parameters={"param": "value"})], + ) + ) + actual.pop("id") + actual["content"][0].pop("id") + expect.pop("id") + expect["content"][0].pop("id") + assert actual == expect + + +def test_openai_response_to_message_invalid_func_name() -> None: + response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE) + response["choices"][0]["message"]["tool_calls"][0]["function"]["name"] = "invalid fn" + message = openai_response_to_message(response) + assert message.content[0].name == "invalid fn" + assert json.loads(message.content[0].parameters) == {"param": "value"} + assert message.content[0].is_error + assert message.content[0].error_message.startswith("The provided function name") + + +@patch("json.loads", side_effect=json.JSONDecodeError("error", "doc", 0)) +def test_openai_response_to_message_json_decode_error(mock_json) -> None: + response = deepcopy(OPEN_AI_TOOL_USE_RESPONSE) + message = openai_response_to_message(response) + assert message.content[0].name == "example_fn" + assert message.content[0].is_error + assert message.content[0].error_message.startswith("Could not interpret tool use") diff --git a/packages/exchange/tests/test_base.py b/packages/exchange/tests/test_base.py new file mode 100644 index 000000000..4aae8bde5 --- /dev/null +++ b/packages/exchange/tests/test_base.py @@ -0,0 +1,27 @@ +from exchange.providers.base import MissingProviderEnvVariableError + + +def test_missing_provider_env_variable_error_without_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + error = MissingProviderEnvVariableError(env_variable, provider) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url is None + assert error.message == "Missing environment variable: API_KEY for provider TestProvider." + + +def test_missing_provider_env_variable_error_with_instructions_url(): + env_variable = "API_KEY" + provider = "TestProvider" + instructions_url = "http://example.com/instructions" + error = MissingProviderEnvVariableError(env_variable, provider, instructions_url) + + assert error.env_variable == env_variable + assert error.provider == provider + assert error.instructions_url == instructions_url + assert error.message == ( + "Missing environment variable: API_KEY for provider TestProvider.\n" + "Please see http://example.com/instructions for instructions" + ) diff --git a/packages/exchange/tests/test_exchange.py b/packages/exchange/tests/test_exchange.py new file mode 100644 index 000000000..f01ef4694 --- /dev/null +++ b/packages/exchange/tests/test_exchange.py @@ -0,0 +1,763 @@ +from typing import List, Tuple + +import pytest + +from exchange.checkpoint import Checkpoint, CheckpointData +from exchange.content import Text, ToolResult, ToolUse +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators import PassiveModerator +from exchange.providers import Provider, Usage +from exchange.tool import Tool + + +def dummy_tool() -> str: + """An example tool""" + return "dummy response" + + +too_long_output = "x" * (2**20 + 1) +too_long_token_output = "word " * 128000 + + +def no_overlapping_checkpoints(exchange: Exchange) -> bool: + """Assert that there are no overlapping checkpoints in the exchange.""" + for i, checkpoint in enumerate(exchange.checkpoint_data.checkpoints): + for other_checkpoint in exchange.checkpoint_data.checkpoints[i + 1 :]: + if not checkpoint.end_index < other_checkpoint.start_index: + return False + return True + + +def checkpoint_to_index_pairs(checkpoints: List[Checkpoint]) -> List[Tuple[int, int]]: + return [(checkpoint.start_index, checkpoint.end_index) for checkpoint in checkpoints] + + +class MockProvider(Provider): + def __init__(self, sequence: List[Message], usage_dicts: List[dict]): + # We'll use init to provide a preplanned reply sequence + self.sequence = sequence + self.call_count = 0 + self.usage_dicts = usage_dicts + + @staticmethod + def get_usage(data: dict) -> Usage: + usage = data.pop("usage") + input_tokens = usage.get("input_tokens") + output_tokens = usage.get("output_tokens") + total_tokens = usage.get("total_tokens") + + if total_tokens is None and input_tokens is not None and output_tokens is not None: + total_tokens = input_tokens + output_tokens + + return Usage( + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + + def complete(self, model: str, system: str, messages: List[Message], tools: List[Tool]) -> Message: + output = self.sequence[self.call_count] + usage = self.get_usage(self.usage_dicts[self.call_count]) + self.call_count += 1 + return (output, usage) + + +def test_reply_with_unsupported_tool(): + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="unsupported_tool", parameters={})], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=(Tool.from_function(dummy_tool),), + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert isinstance(content, ToolResult) and content.is_error and "no tool exists" in content.output.lower() + + +def test_invalid_tool_parameters(): + """Test handling of invalid tool parameters response""" + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="dummy_tool", parameters="invalid json")], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test invalid parameters")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert isinstance(content, ToolResult) and content.is_error and "invalid json" in content.output.lower() + + +def test_max_tool_use_when_limit_reached(): + """Test the max_tool_use parameter in the reply method.""" + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="dummy_tool", parameters={})], + ), + Message( + role="assistant", + content=[ToolUse(id="2", name="dummy_tool", parameters={})], + ), + Message( + role="assistant", + content=[ToolUse(id="3", name="dummy_tool", parameters={})], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test max tool use")])) + + response = ex.reply(max_tool_use=3) + + assert ex.provider.call_count == 3 + assert "reached the limit" in response.content[0].text.lower() + + assert isinstance(ex.messages[-2].content[0], ToolResult) and ex.messages[-2].content[0].tool_use_id == "3" + + assert ex.messages[-1].role == "assistant" + + +def test_tool_output_too_long_character_error(): + """Test tool handling when output exceeds character limit.""" + + def long_output_tool_char() -> str: + return too_long_output + + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="long_output_tool_char", parameters={})], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(long_output_tool_char)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test long output char")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert ( + isinstance(content, ToolResult) + and content.is_error + and "output that was too long to handle" in content.output.lower() + ) + + +def test_tool_output_too_long_token_error(): + """Test tool handling when output exceeds token limit.""" + + def long_output_tool_token() -> str: + return too_long_token_output + + ex = Exchange( + provider=MockProvider( + sequence=[ + Message( + role="assistant", + content=[ToolUse(id="1", name="long_output_tool_token", parameters={})], + ), + Message( + role="assistant", + content=[Text(text="Here is the completion after tool call")], + ), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(long_output_tool_token)], + moderator=PassiveModerator(), + ) + + ex.add(Message(role="user", content=[Text(text="test long output token")])) + + ex.reply() + + content = ex.messages[-2].content[0] + assert ( + isinstance(content, ToolResult) + and content.is_error + and "output that was too long to handle" in content.output.lower() + ) + + +@pytest.fixture(scope="function") +def normal_exchange() -> Exchange: + ex = Exchange( + provider=MockProvider( + sequence=[ + Message(role="assistant", content=[Text(text="Message 1")]), + Message(role="assistant", content=[Text(text="Message 2")]), + Message(role="assistant", content=[Text(text="Message 3")]), + Message(role="assistant", content=[Text(text="Message 4")]), + Message(role="assistant", content=[Text(text="Message 5")]), + ], + usage_dicts=[ + {"usage": {"total_tokens": 10, "input_tokens": 5, "output_tokens": 5}}, + {"usage": {"total_tokens": 28, "input_tokens": 10, "output_tokens": 18}}, + {"usage": {"total_tokens": 33, "input_tokens": 28, "output_tokens": 5}}, + {"usage": {"total_tokens": 40, "input_tokens": 32, "output_tokens": 8}}, + {"usage": {"total_tokens": 50, "input_tokens": 40, "output_tokens": 10}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=(Tool.from_function(dummy_tool),), + moderator=PassiveModerator(), + checkpoint_data=CheckpointData(), + ) + return ex + + +@pytest.fixture(scope="function") +def resumed_exchange() -> Exchange: + messages = [ + Message(role="user", content=[Text(text="User message 1")]), + Message(role="assistant", content=[Text(text="Assistant Message 1")]), + Message(role="user", content=[Text(text="User message 2")]), + Message(role="assistant", content=[Text(text="Assistant Message 2")]), + Message(role="user", content=[Text(text="User message 3")]), + Message(role="assistant", content=[Text(text="Assistant Message 3")]), + ] + provider = MockProvider( + sequence=[ + Message(role="assistant", content=[Text(text="Assistant Message 4")]), + ], + usage_dicts=[ + {"usage": {"total_tokens": 40, "input_tokens": 32, "output_tokens": 8}}, + ], + ) + ex = Exchange( + provider=provider, + messages=messages, + tools=[], + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + checkpoint_data=CheckpointData(), + moderator=PassiveModerator(), + ) + return ex + + +def test_checkpoints_on_exchange(normal_exchange): + """Test checkpoints on an exchange.""" + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + + # Check if checkpoints are created correctly + checkpoints = ex.checkpoint_data.checkpoints + assert len(checkpoints) == 6 + for i in range(len(ex.messages)): + # asserting that each message has a corresponding checkpoint + assert checkpoints[i].start_index == i + assert checkpoints[i].end_index == i + + # Check if the messages are ordered correctly + assert [msg.content[0].text for msg in ex.messages] == [ + "User message", + "Message 1", + "User message", + "Message 2", + "User message", + "Message 3", + ] + assert no_overlapping_checkpoints(ex) + + +def test_checkpoints_on_resumed_exchange(resumed_exchange) -> None: + ex = resumed_exchange + ex.pop_last_message() + ex.reply() + + checkpoints = ex.checkpoint_data.checkpoints + assert len(checkpoints) == 2 + assert len(ex.messages) == 6 + assert checkpoints[0].token_count == 32 + assert checkpoints[0].start_index == 0 + assert checkpoints[0].end_index == 4 + assert checkpoints[1].token_count == 8 + assert checkpoints[1].start_index == 5 + assert checkpoints[1].end_index == 5 + assert no_overlapping_checkpoints(ex) + + +def test_pop_last_checkpoint_on_resumed_exchange(resumed_exchange) -> None: + ex = resumed_exchange + ex.add(Message(role="user", content=[Text(text="Assistant Message 4")])) + ex.reply() + ex.pop_last_checkpoint() + + assert len(ex.messages) == 7 + assert len(ex.checkpoint_data.checkpoints) == 1 + + ex.pop_last_checkpoint() + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert no_overlapping_checkpoints(ex) + + +def test_pop_last_checkpoint_on_normal_exchange(normal_exchange) -> None: + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + ex.pop_last_checkpoint() + ex.pop_last_checkpoint() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert no_overlapping_checkpoints(ex) + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.pop_last_checkpoint() + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + ex.reply() + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert no_overlapping_checkpoints(ex) + + +def test_pop_first_message_no_messages(): + ex = Exchange( + provider=MockProvider(sequence=[], usage_dicts=[]), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + + with pytest.raises(ValueError) as e: + ex.pop_first_message() + assert str(e.value) == "There are no messages to pop" + + +def test_pop_first_message_checkpoint_with_many_messages(resumed_exchange): + ex = resumed_exchange + ex.pop_last_message() + ex.reply() + + assert len(ex.messages) == 6 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.checkpoints[0].start_index == 0 + assert ex.checkpoint_data.checkpoints[0].end_index == 4 + assert ex.checkpoint_data.checkpoints[1].start_index == 5 + assert ex.checkpoint_data.checkpoints[1].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 0 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 5 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 4 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 3 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 3 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 4 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 5 + assert ex.checkpoint_data.checkpoints[0].end_index == 5 + assert ex.checkpoint_data.message_index_offset == 5 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + assert no_overlapping_checkpoints(ex) + + with pytest.raises(ValueError) as e: + ex.pop_first_message() + + assert str(e.value) == "There are no messages to pop" + + +def test_varied_message_manipulation(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message 1")])) + ex.reply() + + ex.pop_first_message() + + ex.add(Message(role="user", content=[Text(text="User message 2")])) + ex.reply() + + assert len(ex.messages) == 3 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert ex.checkpoint_data.message_index_offset == 1 + # (start, end) + # (1, 1), (2, 2), (3, 3) + # actual_index_in_messages_arr = any checkpoint index - offset + assert no_overlapping_checkpoints(ex) + for i in range(3): + assert ex.checkpoint_data.checkpoints[i].start_index == i + 1 + assert ex.checkpoint_data.checkpoints[i].end_index == i + 1 + + ex.pop_last_message() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + for i in range(2): + assert ex.checkpoint_data.checkpoints[i].start_index == i + 1 + assert ex.checkpoint_data.checkpoints[i].end_index == i + 1 + + ex.add(Message(role="assistant", content=[Text(text="Assistant message")])) + ex.add(Message(role="user", content=[Text(text="User message 3")])) + ex.reply() + + assert len(ex.messages) == 5 + assert len(ex.checkpoint_data.checkpoints) == 4 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(1, 1), (2, 2), (3, 4), (5, 5)] + + ex.pop_last_checkpoint() + + assert len(ex.messages) == 4 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert ex.checkpoint_data.message_index_offset == 1 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(1, 1), (2, 2), (3, 4)] + + ex.pop_first_message() + + assert len(ex.messages) == 3 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2), (3, 4)] + + ex.pop_last_message() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)] + + ex.pop_last_message() + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)] + + ex.add(Message(role="assistant", content=[Text(text="Assistant message")])) + ex.add(Message(role="user", content=[Text(text="User message 5")])) + ex.pop_last_checkpoint() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + + ex.add(Message(role="user", content=[Text(text="User message 6")])) + ex.reply() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2), (3, 3)] + + ex.pop_last_message() + + assert len(ex.messages) == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.message_index_offset == 2 + assert no_overlapping_checkpoints(ex) + assert checkpoint_to_index_pairs(ex.checkpoint_data.checkpoints) == [(2, 2)] + + ex.pop_first_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + + ex.add(Message(role="user", content=[Text(text="User message 7")])) + ex.pop_last_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + + +def test_pop_last_message_when_no_checkpoints_but_messages_present(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + + ex.pop_last_message() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert ex.checkpoint_data.message_index_offset == 0 + + +def test_pop_first_message_when_no_checkpoints_but_message_present(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + + with pytest.raises(ValueError) as e: + ex.pop_first_message() + + assert str(e.value) == "There must be at least one checkpoint to pop the first message" + + +def test_pop_first_checkpoint_size_n(resumed_exchange): + ex = resumed_exchange + ex.pop_last_message() # needed because the last message is an assistant message + ex.reply() + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 5 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert len(ex.messages) == 1 + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert len(ex.messages) == 0 + + +def test_pop_first_checkpoint_size_1(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 1 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert len(ex.messages) == 1 + + ex.pop_first_checkpoint() + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + assert len(ex.messages) == 0 + + +def test_pop_first_checkpoint_no_checkpoints(normal_exchange): + ex = normal_exchange + + with pytest.raises(ValueError) as e: + ex.pop_first_checkpoint() + + assert str(e.value) == "There are no checkpoints to pop" + + +def test_prepend_checkpointed_message_empty_exchange(normal_exchange): + ex = normal_exchange + ex.prepend_checkpointed_message(Message(role="assistant", content=[Text(text="Assistant message")]), 10) + + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 1 + assert ex.checkpoint_data.checkpoints[0].start_index == 0 + assert ex.checkpoint_data.checkpoints[0].end_index == 0 + + ex.add(Message(role="user", content=[Text(text="User message")])) + ex.reply() + + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert len(ex.messages) == 3 + assert no_overlapping_checkpoints(ex) + + ex.pop_first_checkpoint() + + assert ex.checkpoint_data.message_index_offset == 1 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert len(ex.messages) == 2 + assert no_overlapping_checkpoints(ex) + + ex.prepend_checkpointed_message(Message(role="assistant", content=[Text(text="Assistant message")]), 10) + assert ex.checkpoint_data.message_index_offset == 0 + assert len(ex.checkpoint_data.checkpoints) == 3 + assert len(ex.messages) == 3 + assert no_overlapping_checkpoints(ex) + + +def test_generate_successful_response_on_first_try(normal_exchange): + ex = normal_exchange + ex.add(Message(role="user", content=[Text("Hello")])) + ex.generate() + + +def test_rewind_in_normal_exchange(normal_exchange): + ex = normal_exchange + ex.rewind() + + assert len(ex.messages) == 0 + assert len(ex.checkpoint_data.checkpoints) == 0 + + ex.add(Message(role="user", content=[Text("Hello")])) + ex.generate() + ex.add(Message(role="user", content=[Text("Hello")])) + + # testing if it works with a user text message at the end + ex.rewind() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + + ex.add(Message(role="user", content=[Text("Hello")])) + ex.generate() + + # testing if it works with a non-user text message at the end + ex.rewind() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + + +def test_rewind_with_tool_usage(): + # simulating a real exchange with tool usage + ex = Exchange( + provider=MockProvider( + sequence=[ + Message.assistant("Hello!"), + Message( + role="assistant", + content=[ToolUse(id="1", name="dummy_tool", parameters={})], + ), + Message( + role="assistant", + content=[ToolUse(id="2", name="dummy_tool", parameters={})], + ), + Message.assistant("Done!"), + ], + usage_dicts=[ + {"usage": {"input_tokens": 12, "output_tokens": 23}}, + {"usage": {"input_tokens": 27, "output_tokens": 44}}, + {"usage": {"input_tokens": 50, "output_tokens": 56}}, + {"usage": {"input_tokens": 60, "output_tokens": 76}}, + ], + ), + model="gpt-4o-2024-05-13", + system="You are a helpful assistant.", + tools=[Tool.from_function(dummy_tool)], + moderator=PassiveModerator(), + ) + ex.add(Message(role="user", content=[Text(text="test")])) + ex.reply() + ex.add(Message(role="user", content=[Text(text="kick it off!")])) + ex.reply() + + # removing the last message to simulate not getting a response + ex.pop_last_message() + + # calling rewind to last user message + ex.rewind() + + assert len(ex.messages) == 2 + assert len(ex.checkpoint_data.checkpoints) == 2 + assert no_overlapping_checkpoints(ex) + assert ex.messages[0].content[0].text == "test" + assert type(ex.messages[1].content[0]) is Text + assert ex.messages[1].role == "assistant" diff --git a/packages/exchange/tests/test_exchange_collect_usage.py b/packages/exchange/tests/test_exchange_collect_usage.py new file mode 100644 index 000000000..590dc709b --- /dev/null +++ b/packages/exchange/tests/test_exchange_collect_usage.py @@ -0,0 +1,33 @@ +from unittest.mock import MagicMock +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators.passive import PassiveModerator +from exchange.providers.base import Provider +from exchange.tool import Tool +from exchange.token_usage_collector import _TokenUsageCollector + +MODEL_NAME = "test-model" + + +def create_exchange(mock_provider, dummy_tool): + return Exchange( + provider=mock_provider, + model=MODEL_NAME, + system="test-system", + tools=(Tool.from_function(dummy_tool),), + messages=[], + moderator=PassiveModerator(), + ) + + +def test_exchange_generate_collect_usage(usage_factory, dummy_tool, monkeypatch): + mock_provider = MagicMock(spec=Provider) + mock_usage_collector = MagicMock(spec=_TokenUsageCollector) + usage = usage_factory() + mock_provider.complete.return_value = (Message.assistant("msg"), usage) + exchange = create_exchange(mock_provider, dummy_tool) + + monkeypatch.setattr("exchange.exchange._token_usage_collector", mock_usage_collector) + exchange.generate() + + mock_usage_collector.collect.assert_called_once_with(MODEL_NAME, usage) diff --git a/packages/exchange/tests/test_exchange_frozen.py b/packages/exchange/tests/test_exchange_frozen.py new file mode 100644 index 000000000..a3095b3a3 --- /dev/null +++ b/packages/exchange/tests/test_exchange_frozen.py @@ -0,0 +1,48 @@ +import pytest +from attr.exceptions import FrozenInstanceError +from exchange.content import Text +from exchange.exchange import Exchange +from exchange.moderators import PassiveModerator +from exchange.message import Message +from exchange.providers import Provider, Usage +from exchange.tool import Tool + + +class MockProvider(Provider): + def complete(self, model, system, messages, tools=None): + return Message(role="assistant", content=[Text(text="This is a mock response.")]), Usage.from_dict( + {"total_tokens": 35} + ) + + +def test_exchange_immutable(dummy_tool): + # Create an instance of Exchange + provider = MockProvider() + # intentionally setting a list instead of tuple on tools, it should be converted + exchange = Exchange( + provider=provider, + model="test-model", + system="test-system", + tools=(Tool.from_function(dummy_tool),), + messages=[Message(role="user", content=[Text(text="Hello!")])], + moderator=PassiveModerator(), + ) + + # Try to directly modify a field (should raise an error) + with pytest.raises(FrozenInstanceError): + exchange.system = "" + + with pytest.raises(AttributeError): + exchange.tools.append("anything") + + # Replace method should return a new instance with deepcopy of messages + new_exchange = exchange.replace(system="changed") + + assert new_exchange.system == "changed" + assert len(exchange.messages) == 1 + assert len(new_exchange.messages) == 1 + + # Ensure that the messages are deep copied + new_exchange.messages[0].content[0].text = "Changed!" + assert exchange.messages[0].content[0].text == "Hello!" + assert new_exchange.messages[0].content[0].text == "Changed!" diff --git a/packages/exchange/tests/test_image.png b/packages/exchange/tests/test_image.png new file mode 100644 index 000000000..3488b8a51 Binary files /dev/null and b/packages/exchange/tests/test_image.png differ diff --git a/packages/exchange/tests/test_integration.py b/packages/exchange/tests/test_integration.py new file mode 100644 index 000000000..1eb198082 --- /dev/null +++ b/packages/exchange/tests/test_integration.py @@ -0,0 +1,89 @@ +import os +import pytest +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators import ContextTruncate +from exchange.providers import get_provider +from exchange.providers.ollama import OLLAMA_MODEL +from exchange.tool import Tool +from tests.conftest import read_file + +too_long_chars = "x" * (2**20 + 1) + +cases = [ + # Set seed and temperature for more determinism, to avoid flakes + (get_provider("ollama"), os.getenv("OLLAMA_MODEL", OLLAMA_MODEL), dict(seed=3, temperature=0.1)), + (get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini"), dict()), + (get_provider("azure"), os.getenv("AZURE_MODEL", "gpt-4o-mini"), dict()), + (get_provider("databricks"), "databricks-meta-llama-3-70b-instruct", dict()), + (get_provider("bedrock"), "anthropic.claude-3-5-sonnet-20240620-v1:0", dict()), + (get_provider("google"), "gemini-1.5-flash", dict()), +] + + +@pytest.mark.integration +@pytest.mark.parametrize("provider,model,kwargs", cases) +def test_simple(provider, model, kwargs): + provider = provider.from_env() + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant.", + generation_args=kwargs, + ) + + ex.add(Message.user("Who is the most famous wizard from the lord of the rings")) + + response = ex.reply() + + # It's possible this can be flakey, but in experience so far haven't seen it + assert "gandalf" in response.text.lower() + + +@pytest.mark.integration +@pytest.mark.parametrize("provider,model,kwargs", cases) +def test_tools(provider, model, kwargs, tmp_path): + provider = provider.from_env() + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant. Expect to need to read a file using read_file.", + tools=(Tool.from_function(read_file),), + generation_args=kwargs, + ) + + ex.add(Message.user("What are the contents of this file? test.txt")) + + response = ex.reply() + + assert "hello exchange" in response.text.lower() + + +@pytest.mark.integration +@pytest.mark.parametrize("provider,model,kwargs", cases) +def test_tool_use_output_chars(provider, model, kwargs): + provider = provider.from_env() + + def get_password() -> str: + """Return the password for authentication""" + return too_long_chars + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant. Expect to need to authenticate using get_password.", + tools=(Tool.from_function(get_password),), + generation_args=kwargs, + ) + + ex.add(Message.user("Can you authenticate this session by responding with the password")) + + ex.reply() + + # Without our error handling, this would raise + # string too long. Expected a string with maximum length 1048576, but got a string with length ... diff --git a/packages/exchange/tests/test_integration_vision.py b/packages/exchange/tests/test_integration_vision.py new file mode 100644 index 000000000..20f165ade --- /dev/null +++ b/packages/exchange/tests/test_integration_vision.py @@ -0,0 +1,44 @@ +import os + +import pytest +from exchange.content import ToolResult, ToolUse +from exchange.exchange import Exchange +from exchange.message import Message +from exchange.moderators import ContextTruncate +from exchange.providers import get_provider + +cases = [ + (get_provider("openai"), os.getenv("OPENAI_MODEL", "gpt-4o-mini")), +] + + +@pytest.mark.integration # skipped in CI/CD +@pytest.mark.parametrize("provider,model", cases) +def test_simple(provider, model): + provider = provider.from_env() + + ex = Exchange( + provider=provider, + model=model, + moderator=ContextTruncate(model), + system="You are a helpful assistant.", + ) + + ex.add(Message.user("What does the first entry in the menu say?")) + ex.add( + Message( + role="assistant", + content=[ToolUse(id="xyz", name="screenshot", parameters={})], + ) + ) + ex.add( + Message( + role="user", + content=[ToolResult(tool_use_id="xyz", output='"image:tests/test_image.png"')], + ) + ) + + response = ex.reply() + + # It's possible this can be flakey, but in experience so far haven't seen it + assert "ask goose" in response.text.lower() diff --git a/packages/exchange/tests/test_invalid_choice_error.py b/packages/exchange/tests/test_invalid_choice_error.py new file mode 100644 index 000000000..9fad8b12f --- /dev/null +++ b/packages/exchange/tests/test_invalid_choice_error.py @@ -0,0 +1,13 @@ +from exchange.invalid_choice_error import InvalidChoiceError + + +def test_load_invalid_choice_error(): + attribute_name = "moderator" + attribute_value = "not_exist" + available_values = ["truncate", "summarizer"] + error = InvalidChoiceError(attribute_name, attribute_value, available_values) + + assert error.attribute_name == attribute_name + assert error.attribute_value == attribute_value + assert error.attribute_value == attribute_value + assert error.message == "Unknown moderator: not_exist. Available moderators: truncate, summarizer" diff --git a/packages/exchange/tests/test_message.py b/packages/exchange/tests/test_message.py new file mode 100644 index 000000000..d5442eb75 --- /dev/null +++ b/packages/exchange/tests/test_message.py @@ -0,0 +1,96 @@ +import subprocess +from pathlib import Path +import pytest + +from exchange.message import Message +from exchange.content import Text, ToolUse, ToolResult + + +def test_user_message(): + user_message = Message.user("abcd") + assert user_message.role == "user" + assert user_message.text == "abcd" + + +def test_assistant_message(): + assistant_message = Message.assistant("abcd") + assert assistant_message.role == "assistant" + assert assistant_message.text == "abcd" + + +def test_message_tool_use(): + from exchange.content import ToolUse + + tu1 = ToolUse(id="1", name="tool", parameters={}) + tu2 = ToolUse(id="2", name="tool", parameters={}) + message = Message(role="assistant", content=[tu1, tu2]) + assert len(message.tool_use) == 2 + assert message.tool_use[0].name == "tool" + + +def test_message_tool_result(): + from exchange.content import ToolResult + + tr1 = ToolResult(tool_use_id="1", output="result") + tr2 = ToolResult(tool_use_id="2", output="result") + message = Message(role="user", content=[tr1, tr2]) + assert len(message.tool_result) == 2 + assert message.tool_result[0].output == "result" + + +def test_message_load(tmpdir): + # To emulate the expected relative lookup, we need to create a mock code dir + # and run the load in a subprocess + test_dir = Path(tmpdir) + + # Create a temporary Jinja template file in the test_dir + template_content = "hello {{ name }} {% include 'relative.jinja' %}" + template_path = test_dir / "template.jinja" + template_path.write_text(template_content) + + relative_content = "and {{ name2 }}" + relative_path = test_dir / "relative.jinja" + relative_path.write_text(relative_content) + + # Create a temporary Python file in the sub_dir that calls the load method with a relative path + python_file_content = """ +from exchange.message import Message + +def test_function(): + message = Message.load('template.jinja', name="a", name2="b") + assert message.text == "hello a and b" + assert message.role == "user" + +test_function() +""" + python_file_path = test_dir / "test_script.py" + python_file_path.write_text(python_file_content) + + # Execute the temporary Python file to test the relative lookup functionality + result = subprocess.run(["python3", str(python_file_path)]) + + assert result.returncode == 0 + + +def test_message_validation(): + # Valid user message + message = Message(role="user", content=[Text(text="Hello")]) + assert message.text == "Hello" + + # Valid assistant message + message = Message(role="assistant", content=[Text(text="Hello")]) + assert message.text == "Hello" + + # Invalid message: user with tool_use + with pytest.raises(ValueError): + Message( + role="user", + content=[Text(text=""), ToolUse(id="1", name="tool", parameters={})], + ) + + # Invalid message: assistant with tool_result + with pytest.raises(ValueError): + Message( + role="assistant", + content=[Text(text=""), ToolResult(tool_use_id="1", output="result")], + ) diff --git a/packages/exchange/tests/test_moderators.py b/packages/exchange/tests/test_moderators.py new file mode 100644 index 000000000..16bcaa13b --- /dev/null +++ b/packages/exchange/tests/test_moderators.py @@ -0,0 +1,17 @@ +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.moderators import get_moderator +import pytest + + +def test_get_moderator(): + moderator = get_moderator("truncate") + assert moderator.__name__ == "ContextTruncate" + + +def test_get_moderator_raise_error_for_unknown_moderator(): + with pytest.raises(InvalidChoiceError) as error: + get_moderator("nonexistent") + assert error.value.attribute_name == "moderator" + assert error.value.attribute_value == "nonexistent" + assert "truncate" in error.value.available_values + assert "truncate" in error.value.message diff --git a/packages/exchange/tests/test_summarizer.py b/packages/exchange/tests/test_summarizer.py new file mode 100644 index 000000000..fa7281920 --- /dev/null +++ b/packages/exchange/tests/test_summarizer.py @@ -0,0 +1,227 @@ +import pytest +from exchange import Exchange, Message +from exchange.content import ToolResult, ToolUse +from exchange.moderators.passive import PassiveModerator +from exchange.moderators.summarizer import ContextSummarizer +from exchange.providers import Usage + + +class MockProvider: + def complete(self, model, system, messages, tools): + assistant_message_text = "Summarized content here." + output_tokens = len(assistant_message_text) + total_input_tokens = sum(len(msg.text) for msg in messages) + if not messages or messages[-1].role == "assistant": + message = Message.user(assistant_message_text) + else: + message = Message.assistant(assistant_message_text) + total_tokens = total_input_tokens + output_tokens + usage = Usage( + input_tokens=total_input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + return message, usage + + +@pytest.fixture +def exchange_instance(): + ex = Exchange( + provider=MockProvider(), + model="test-model", + system="test-system", + messages=[ + Message.user("Hi, can you help me with my homework?"), + Message.assistant("Of course! What do you need help with?"), + Message.user("I need help with math problems."), + Message.assistant("Sure, I can help with that. Let's get started."), + Message.user("Can you also help with my science homework?"), + Message.assistant("Yes, I can help with science too."), + Message.user("That's great! How about history?"), + Message.assistant("Of course, I can help with history as well."), + Message.user("Thanks! You're very helpful."), + Message.assistant("You're welcome! I'm here to help."), + ], + moderator=PassiveModerator(), + ) + return ex + + +@pytest.fixture +def summarizer_instance(): + return ContextSummarizer(max_tokens=300) + + +def test_context_summarizer_rewrite(exchange_instance: Exchange, summarizer_instance: ContextSummarizer): + # Pre-checks + assert len(exchange_instance.messages) == 10 + + exchange_instance.generate() + + # the exchange instance has a PassiveModerator so the messages are not truncated nor summarized + assert len(exchange_instance.messages) == 11 + assert len(exchange_instance.checkpoint_data.checkpoints) == 2 + + # we now tell the summarizer to summarize the exchange + summarizer_instance.rewrite(exchange_instance) + + assert exchange_instance.checkpoint_data.total_token_count <= 200 + assert len(exchange_instance.messages) == 2 + + # Assert that summarized content is the first message + first_message = exchange_instance.messages[0] + assert first_message.role == "user" or first_message.role == "assistant" + assert any("summarized" in content.text.lower() for content in first_message.content) + + # Ensure roles alternate in the output + for i in range(1, len(exchange_instance.messages)): + assert ( + exchange_instance.messages[i - 1].role != exchange_instance.messages[i].role + ), "Messages must alternate between user and assistant" + + +MESSAGE_SEQUENCE = [ + Message.user("Hi, can you help me with my homework?"), + Message.assistant("Of course! What do you need help with?"), + Message.user("I need help with math problems."), + Message.assistant("Sure, I can help with that. Let's get started."), + Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"), + Message( + role="assistant", + content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]), + Message( + role="assistant", + content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})], + ), + Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]), + Message( + role="assistant", + content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})], + ), + Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]), + Message( + role="assistant", + content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})], + ), + Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]), + Message( + role="assistant", + content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]), + Message.assistant("I'm done calculating the answers to your math questions."), + Message.user("Can you also help with my science homework?"), + Message.assistant("Yes, I can help with science too."), + Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"), + Message( + role="assistant", + content=[ToolUse(id="6", name="speed_of_light", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]), + Message( + role="assistant", + content=[ToolUse(id="7", name="photon_frequency", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]), + Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]), + Message( + role="user", + content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")], + ), + Message.assistant("I'm done calculating the answers to your science questions."), + Message.user("That's great! How about history?"), + Message.assistant("Of course, I can help with history as well."), + Message.user("Thanks! You're very helpful."), + Message.assistant("You're welcome! I'm here to help."), +] + + +class AnotherMockProvider: + def __init__(self): + self.sequence = MESSAGE_SEQUENCE + self.current_index = 1 + self.summarize_next = False + self.summarized_count = 0 + + def complete(self, model, system, messages, tools): + system_prompt_tokens = 100 + input_token_count = system_prompt_tokens + + message = self.sequence[self.current_index] + if self.summarize_next: + text = "Summary message here" + self.summarize_next = False + self.summarized_count += 1 + return Message.assistant(text=text), Usage( + # in this case, input tokens can really be whatever + input_tokens=40, + output_tokens=len(text) * 2, + total_tokens=40 + len(text) * 2, + ) + + if len(messages) > 0 and type(messages[0].content[0]) is ToolResult: + raise ValueError("ToolResult should not be the first message") + + if len(messages) == 1 and messages[0].text == "a": + # adding a +1 for the "a" + return Message.assistant("Getting system prompt size"), Usage( + input_tokens=80 + 1, output_tokens=20, total_tokens=system_prompt_tokens + 1 + ) + + for i in range(len(messages)): + if type(messages[i].content[0]) in (ToolResult, ToolUse): + input_token_count += 10 + else: + input_token_count += len(messages[i].text) * 2 + + if type(message.content[0]) in (ToolResult, ToolUse): + output_tokens = 10 + else: + output_tokens = len(message.text) * 2 + + total_tokens = input_token_count + output_tokens + if total_tokens > 300: + self.summarize_next = True + usage = Usage( + input_tokens=input_token_count, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + self.current_index += 2 + return message, usage + + +@pytest.fixture +def conversation_exchange_instance(): + ex = Exchange( + provider=AnotherMockProvider(), + model="test-model", + system="test-system", + moderator=ContextSummarizer(max_tokens=300), + # TODO: make it work with an offset so we don't have to send off requests basically + # at every generate step + ) + return ex + + +def test_summarizer_generic_conversation(conversation_exchange_instance: Exchange): + i = 0 + while i < len(MESSAGE_SEQUENCE): + next_message = MESSAGE_SEQUENCE[i] + conversation_exchange_instance.add(next_message) + message = conversation_exchange_instance.generate() + if message.text != "Summary message here": + i += 2 + checkpoints = conversation_exchange_instance.checkpoint_data.checkpoints + assert conversation_exchange_instance.checkpoint_data.total_token_count == 570 + assert len(checkpoints) == 10 + assert len(conversation_exchange_instance.messages) == 10 + assert checkpoints[0].start_index == 20 + assert checkpoints[0].end_index == 20 + assert checkpoints[-1].start_index == 29 + assert checkpoints[-1].end_index == 29 + assert conversation_exchange_instance.checkpoint_data.message_index_offset == 20 + assert conversation_exchange_instance.provider.summarized_count == 12 + assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 diff --git a/packages/exchange/tests/test_token_usage_collector.py b/packages/exchange/tests/test_token_usage_collector.py new file mode 100644 index 000000000..d277f63e9 --- /dev/null +++ b/packages/exchange/tests/test_token_usage_collector.py @@ -0,0 +1,24 @@ +from exchange.token_usage_collector import _TokenUsageCollector + + +def test_collect(usage_factory): + usage_collector = _TokenUsageCollector() + usage_collector.collect("model1", usage_factory(100, 1000, 1100)) + usage_collector.collect("model1", usage_factory(200, 2000, 2200)) + usage_collector.collect("model2", usage_factory(400, 4000, 4400)) + usage_collector.collect("model3", usage_factory(500, 5000, 5500)) + usage_collector.collect("model3", usage_factory(600, 6000, 6600)) + assert usage_collector.get_token_usage_group_by_model() == { + "model1": usage_factory(300, 3000, 3300), + "model2": usage_factory(400, 4000, 4400), + "model3": usage_factory(1100, 11000, 12100), + } + + +def test_collect_with_non_input_or_output_token(usage_factory): + usage_collector = _TokenUsageCollector() + usage_collector.collect("model1", usage_factory(100, None, None)) + usage_collector.collect("model1", usage_factory(None, 2000, None)) + assert usage_collector.get_token_usage_group_by_model() == { + "model1": usage_factory(100, 2000, 0), + } diff --git a/packages/exchange/tests/test_tool.py b/packages/exchange/tests/test_tool.py new file mode 100644 index 000000000..847e79fb5 --- /dev/null +++ b/packages/exchange/tests/test_tool.py @@ -0,0 +1,161 @@ +import attrs +from exchange.tool import Tool + + +def get_current_weather(location: str) -> None: + """Get the current weather in a given location + + Args: + location (str): The city and state, e.g. San Francisco, CA + """ + pass + + +def test_load(): + tool = Tool.from_function(get_current_weather) + + expected = { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + "required": ["location"], + }, + "function": get_current_weather, + } + + assert attrs.asdict(tool) == expected + + +def another_function( + param1: int, + param2: str, + param3: bool, + param4: float, + param5: list[int], + param6: dict[str, float], +) -> None: + """ + This is another example function with various types + + Args: + param1 (int): Description for param1 + param2 (str): Description for param2 + param3 (bool): Description for param3 + param4 (float): Description for param4 + param5 (list[int]): Description for param5 + param6 (dict[str, float]): Description for param6 + """ + pass + + +def test_load_types(): + tool = Tool.from_function(another_function) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + "param2": {"type": "string", "description": "Description for param2"}, + "param3": {"type": "boolean", "description": "Description for param3"}, + "param4": {"type": "number", "description": "Description for param4"}, + "param5": { + "type": "array", + "items": {"type": "integer"}, + "description": "Description for param5", + }, + "param6": { + "type": "object", + "additionalProperties": {"type": "number"}, + "description": "Description for param6", + }, + }, + "required": ["param1", "param2", "param3", "param4", "param5", "param6"], + } + assert tool.parameters == expected_schema + + +def numpy_function(param1: int, param2: str) -> None: + """ + This function uses numpy style docstrings + + Parameters + ---------- + param1 : int + Description for param1 + param2 : str + Description for param2 + """ + pass + + +def test_load_numpy_style(): + tool = Tool.from_function(numpy_function) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + "param2": {"type": "string", "description": "Description for param2"}, + }, + "required": ["param1", "param2"], + } + assert tool.parameters == expected_schema + + +def sphinx_function(param1: int, param2: str, param3: bool) -> None: + """ + This function uses sphinx style docstrings + + :param param1: Description for param1 + :type param1: int + :param param2: Description for param2 + :type param2: str + :param param3: Description for param3 + :type param3: bool + """ + pass + + +def test_load_sphinx_style(): + tool = Tool.from_function(sphinx_function) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + "param2": {"type": "string", "description": "Description for param2"}, + "param3": {"type": "boolean", "description": "Description for param3"}, + }, + "required": ["param1", "param2", "param3"], + } + assert tool.parameters == expected_schema + + +class FunctionLike: + def __init__(self, state: int) -> None: + self.state = state + + def __call__(self, param1: int) -> int: + """Example + + Args: + param1 (int): Description for param1 + """ + return self.state + param1 + + +def test_load_stateful_class(): + tool = Tool.from_function(FunctionLike(1)) + expected_schema = { + "type": "object", + "properties": { + "param1": {"type": "integer", "description": "Description for param1"}, + }, + "required": ["param1"], + } + assert tool.parameters == expected_schema + assert tool.function(2) == 3 diff --git a/packages/exchange/tests/test_truncate.py b/packages/exchange/tests/test_truncate.py new file mode 100644 index 000000000..3875303e7 --- /dev/null +++ b/packages/exchange/tests/test_truncate.py @@ -0,0 +1,132 @@ +import pytest +from exchange import Exchange +from exchange.content import ToolResult, ToolUse +from exchange.message import Message +from exchange.moderators.truncate import ContextTruncate +from exchange.providers import Provider, Usage + +MAX_TOKENS = 300 +SYSTEM_PROMPT_TOKENS = 100 + +MESSAGE_SEQUENCE = [ + Message.user("Hi, can you help me with my homework?"), + Message.assistant("Of course! What do you need help with?"), + Message.user("I need help with math problems."), + Message.assistant("Sure, I can help with that. Let's get started."), + Message.user("What is 2 + 2, 3*3, 9/5, 2*20, 14/2?"), + Message( + role="assistant", + content=[ToolUse(id="1", name="add", parameters={"a": 2, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="1", output="4")]), + Message( + role="assistant", + content=[ToolUse(id="2", name="multiply", parameters={"a": 3, "b": 3})], + ), + Message(role="user", content=[ToolResult(tool_use_id="2", output="9")]), + Message( + role="assistant", + content=[ToolUse(id="3", name="divide", parameters={"a": 9, "b": 5})], + ), + Message(role="user", content=[ToolResult(tool_use_id="3", output="1.8")]), + Message( + role="assistant", + content=[ToolUse(id="4", name="multiply", parameters={"a": 2, "b": 20})], + ), + Message(role="user", content=[ToolResult(tool_use_id="4", output="40")]), + Message( + role="assistant", + content=[ToolUse(id="5", name="divide", parameters={"a": 14, "b": 2})], + ), + Message(role="user", content=[ToolResult(tool_use_id="5", output="7")]), + Message.assistant("I'm done calculating the answers to your math questions."), + Message.user("Can you also help with my science homework?"), + Message.assistant("Yes, I can help with science too."), + Message.user("What is the speed of light? The frequency of a photon? The mass of an electron?"), + Message( + role="assistant", + content=[ToolUse(id="6", name="speed_of_light", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="6", output="299,792,458 m/s")]), + Message( + role="assistant", + content=[ToolUse(id="7", name="photon_frequency", parameters={})], + ), + Message(role="user", content=[ToolResult(tool_use_id="7", output="2.418 x 10^14 Hz")]), + Message(role="assistant", content=[ToolUse(id="8", name="electron_mass", parameters={})]), + Message( + role="user", + content=[ToolResult(tool_use_id="8", output="9.10938356 x 10^-31 kg")], + ), + Message.assistant("I'm done calculating the answers to your science questions."), + Message.user("That's great! How about history?"), + Message.assistant("Of course, I can help with history as well."), + Message.user("Thanks! You're very helpful."), + Message.assistant("You're welcome! I'm here to help."), +] + + +class TruncateLinearProvider(Provider): + def __init__(self): + self.sequence = MESSAGE_SEQUENCE + self.current_index = 1 + self.summarize_next = False + self.summarized_count = 0 + + def complete(self, model, system, messages, tools): + input_token_count = SYSTEM_PROMPT_TOKENS + + message = self.sequence[self.current_index] + + if len(messages) > 0 and type(messages[0].content[0]) is ToolResult: + raise ValueError("ToolResult should not be the first message") + + if len(messages) == 1 and messages[0].text == "a": + # adding a +1 for the "a" + return Message.assistant("Getting system prompt size"), Usage( + input_tokens=80 + 1, output_tokens=20, total_tokens=SYSTEM_PROMPT_TOKENS + 1 + ) + + for i in range(len(messages)): + if type(messages[i].content[0]) in (ToolResult, ToolUse): + input_token_count += 10 + else: + input_token_count += len(messages[i].text) * 2 + + if type(message.content[0]) in (ToolResult, ToolUse): + output_tokens = 10 + else: + output_tokens = len(message.text) * 2 + + total_tokens = input_token_count + output_tokens + usage = Usage( + input_tokens=input_token_count, + output_tokens=output_tokens, + total_tokens=total_tokens, + ) + self.current_index += 2 + return message, usage + + +@pytest.fixture +def conversation_exchange_instance(): + ex = Exchange( + provider=TruncateLinearProvider(), + model="test-model", + system="test-system", + moderator=ContextTruncate(max_tokens=500), + ) + return ex + + +def test_truncate_on_generic_conversation(conversation_exchange_instance: Exchange): + i = 0 + while i < len(MESSAGE_SEQUENCE): + next_message = MESSAGE_SEQUENCE[i] + conversation_exchange_instance.add(next_message) + message = conversation_exchange_instance.generate() + if message.text != "Summary message here": + i += 2 + # ensure the total token count is not anything exhorbitant + assert conversation_exchange_instance.checkpoint_data.total_token_count < 700 + assert conversation_exchange_instance.moderator.system_prompt_token_count == 100 diff --git a/packages/exchange/tests/test_utils.py b/packages/exchange/tests/test_utils.py new file mode 100644 index 000000000..6bc00f9e0 --- /dev/null +++ b/packages/exchange/tests/test_utils.py @@ -0,0 +1,125 @@ +import pytest +from exchange import utils +from unittest.mock import patch +from exchange.message import Message +from exchange.content import Text, ToolResult +from exchange.providers.utils import messages_to_openai_spec, encode_image + + +def test_encode_image(): + image_path = "tests/test_image.png" + encoded_image = encode_image(image_path) + + # Adjust this string based on the actual initial part of your base64-encoded image. + expected_start = "iVBORw0KGgo" + assert encoded_image.startswith(expected_start) + + +def test_create_object_id() -> None: + prefix = "test" + object_id = utils.create_object_id(prefix) + assert object_id.startswith(prefix + "_") + assert len(object_id) == len(prefix) + 1 + 24 # prefix + _ + 24 chars + + +def test_compact() -> None: + content = "This is \n\n a test" + compacted = utils.compact(content) + assert compacted == "This is a test" + + +def test_parse_docstring() -> None: + def dummy_func(a, b, c): + """ + This function does something. + + Args: + a (int): The first parameter. + b (str): The second parameter. + c (list): The third parameter. + """ + pass + + description, parameters = utils.parse_docstring(dummy_func) + assert description == "This function does something." + assert parameters == [ + {"name": "a", "annotation": "int", "description": "The first parameter."}, + {"name": "b", "annotation": "str", "description": "The second parameter."}, + {"name": "c", "annotation": "list", "description": "The third parameter."}, + ] + + +def test_parse_docstring_no_description() -> None: + def dummy_func(a, b, c): + """ + Args: + a (int): The first parameter. + b (str): The second parameter. + c (list): The third parameter. + """ + pass + + with pytest.raises(ValueError) as e: + utils.parse_docstring(dummy_func) + + assert "Attempted to load from a function" in str(e.value) + + +def test_json_schema() -> None: + def dummy_func(a: int, b: str, c: list) -> None: + pass + + schema = utils.json_schema(dummy_func) + + assert schema == { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "string"}, + "c": {"type": "string"}, + }, + "required": ["a", "b", "c"], + } + + +def test_load_plugins() -> None: + class DummyEntryPoint: + def __init__(self, name, plugin): + self.name = name + self.plugin = plugin + + def load(self): + return self.plugin + + with patch("exchange.utils.entry_points") as entry_points_mock: + entry_points_mock.return_value = [ + DummyEntryPoint("plugin1", object()), + DummyEntryPoint("plugin2", object()), + ] + + plugins = utils.load_plugins("dummy_group") + + assert "plugin1" in plugins + assert "plugin2" in plugins + assert len(plugins) == 2 + + +def test_messages_to_openai_spec(): + # Use provided test image + png_path = "tests/test_image.png" + + # Create a list of messages as input + messages = [ + Message(role="user", content=[Text(text="Hello, Assistant!")]), + Message(role="assistant", content=[Text(text="Here is a text with tool usage")]), + Message( + role="tool", + content=[ToolResult(tool_use_id="1", output=f'"image:{png_path}')], + ), + ] + + # Call the function + output = messages_to_openai_spec(messages) + + assert "This tool result included an image that is uploaded in the next message." in str(output) + assert "{'role': 'user', 'content': [{'type': 'image_url'" in str(output) diff --git a/pyproject.toml b/pyproject.toml index 03eb480cd..56d32a5a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,17 @@ [project] name = "goose-ai" description = "a programming agent that runs on your machine" -version = "0.9.0" +version = "0.9.3" readme = "README.md" requires-python = ">=3.10" dependencies = [ + "ai-exchange", "attrs>=23.2.0", "rich>=13.7.1", "ruamel-yaml>=0.18.6", - "ai-exchange>=0.9.0", "click>=8.1.7", "prompt-toolkit>=3.0.47", + "keyring>=25.4.1", ] author = [{ name = "Block", email = "ai-oss-tools@block.xyz" }] packages = [{ include = "goose", from = "src" }] @@ -52,6 +53,26 @@ build-backend = "hatchling.build" [tool.uv] dev-dependencies = [ - "pytest>=8.3.2", "codecov>=2.1.13", + "mkdocs-callouts>=1.14.0", + "mkdocs-gen-files>=0.5.0", + "mkdocs-git-authors-plugin>=0.9.0", + "mkdocs-git-committers-plugin>=0.2.3", + "mkdocs-git-revision-date-localized-plugin>=1.2.9", + "mkdocs-glightbox>=0.4.0", + "mkdocs-include-markdown-plugin>=6.2.2", + "mkdocs-literate-nav>=0.6.1", + "mkdocs-material>=9.5.34", + "mkdocs-redirects>=1.2.1", + "mkdocs-section-index>=0.3.9", + "mkdocstrings-python>=1.11.1", + "mkdocstrings>=0.26.1", + "pytest-mock>=3.14.0", + "pytest>=8.3.2" ] + +[tool.uv.sources] +ai-exchange = { workspace = true } + +[tool.uv.workspace] +members = ["packages/*"] diff --git a/src/goose/_logger.py b/src/goose/_logger.py new file mode 100644 index 000000000..a364ceed7 --- /dev/null +++ b/src/goose/_logger.py @@ -0,0 +1,19 @@ +import logging +from pathlib import Path + +_LOGGER_NAME = "goose" +_LOGGER_FILE_NAME = "goose.log" + + +def setup_logging(log_file_directory: Path, log_level: str = "INFO") -> None: + logger = logging.getLogger(_LOGGER_NAME) + logger.setLevel(getattr(logging, log_level)) + log_file_directory.mkdir(parents=True, exist_ok=True) + file_handler = logging.FileHandler(log_file_directory / _LOGGER_FILE_NAME) + logger.addHandler(file_handler) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + file_handler.setFormatter(formatter) + + +def get_logger() -> logging.Logger: + return logging.getLogger(_LOGGER_NAME) diff --git a/src/goose/cli/config.py b/src/goose/cli/config.py index d5c60c038..7bede0be5 100644 --- a/src/goose/cli/config.py +++ b/src/goose/cli/config.py @@ -1,22 +1,22 @@ from functools import cache -from io import StringIO from pathlib import Path -from typing import Callable, Dict, Mapping, Tuple +from typing import Callable, Dict, Mapping, Optional, Tuple from rich import print from rich.panel import Panel -from rich.prompt import Confirm -from rich.text import Text from ruamel.yaml import YAML +from exchange.providers.ollama import OLLAMA_MODEL + from goose.profile import Profile from goose.utils import load_plugins -from goose.utils.diff import pretty_diff GOOSE_GLOBAL_PATH = Path("~/.config/goose").expanduser() PROFILES_CONFIG_PATH = GOOSE_GLOBAL_PATH.joinpath("profiles.yaml") SESSIONS_PATH = GOOSE_GLOBAL_PATH.joinpath("sessions") SESSION_FILE_SUFFIX = ".jsonl" +LOG_PATH = GOOSE_GLOBAL_PATH.joinpath("logs") +RECOMMENDED_DEFAULT_PROVIDER = "openai" @cache @@ -38,15 +38,18 @@ def write_config(profiles: Dict[str, Profile]) -> None: yaml.dump(converted, f) -def ensure_config(name: str) -> Profile: +def ensure_config(name: Optional[str]) -> Tuple[str, Profile]: """Ensure that the config exists and has the default section""" # TODO we should copy a templated default config in to better document # but this is complicated a bit by autodetecting the provider - + default_profile_name = "default" + name = name or default_profile_name + default_profiles_dict = default_profiles() provider, processor, accelerator = default_model_configuration() - profile = default_profiles()[name](provider, processor, accelerator) + default_profile = default_profiles_dict.get(name, default_profiles_dict[default_profile_name])( + provider, processor, accelerator + ) - profiles = {} if not PROFILES_CONFIG_PATH.exists(): print( Panel( @@ -55,50 +58,16 @@ def ensure_config(name: str) -> Profile: + "You can add your own profile in this file to further configure goose!" ) ) - default = profile - profiles = {name: default} - write_config(profiles) - return profile + write_config({name: default_profile}) + return (name, default_profile) profiles = read_config() - if name not in profiles: - print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]")) - profiles.update({name: profile}) - write_config(profiles) - elif name in profiles: - # if the profile stored differs from the default one, we should prompt the user to see if they want - # to update it! we need to recursively compare the two profiles, as object comparison will always return false - is_profile_eq = profile.to_dict() == profiles[name].to_dict() - if not is_profile_eq: - yaml = YAML() - before = StringIO() - after = StringIO() - yaml.dump(profiles[name].to_dict(), before) - yaml.dump(profile.to_dict(), after) - before.seek(0) - after.seek(0) - - print( - Panel( - Text( - f"Your profile uses one of the default options - '{name}'" - + " - but it differs from the latest version:\n\n", - ) - + pretty_diff(before.read(), after.read()) - ) - ) - # should_update = Confirm.ask( - # "Do you want to update your profile to use the latest?", - # default=False, - # ) - should_update = False - if should_update: - profiles[name] = profile - write_config(profiles) - else: - profile = profiles[name] - - return profile + if name in profiles: + return (name, profiles[name]) + print(Panel(f"[yellow]Your configuration doesn't have a profile named '{name}', adding one now[/yellow]")) + profiles.update({name: default_profile}) + write_config(profiles) + return (name, default_profile) def read_config() -> Dict[str, Profile]: @@ -116,17 +85,13 @@ def default_model_configuration() -> Tuple[str, str, str]: for provider, cls in providers.items(): try: cls.from_env() - print(Panel(f"[green]Detected an available provider: [/]{provider}")) break except Exception: pass else: - raise ValueError( - "Could not detect an available provider," - + " make sure to plugin a provider or set an env var such as OPENAI_API_KEY" - ) - + provider = RECOMMENDED_DEFAULT_PROVIDER recommended = { + "ollama": (OLLAMA_MODEL, OLLAMA_MODEL), "openai": ("gpt-4o", "gpt-4o-mini"), "anthropic": ( "claude-3-5-sonnet-20240620", diff --git a/src/goose/cli/main.py b/src/goose/cli/main.py index 0e266dd70..7d1359889 100644 --- a/src/goose/cli/main.py +++ b/src/goose/cli/main.py @@ -1,6 +1,7 @@ +import os from datetime import datetime from pathlib import Path -from typing import Dict, Optional +from typing import Optional import click from rich import print @@ -8,7 +9,9 @@ from goose.cli.config import SESSIONS_PATH from goose.cli.session import Session +from goose.toolkit.utils import render_template, parse_plan from goose.utils import load_plugins +from goose.utils.autocomplete import SUPPORTED_SHELLS, setup_autocomplete from goose.utils.session_file import list_sorted_session_files @@ -17,8 +20,8 @@ def goose_cli() -> None: pass -@goose_cli.command() -def version() -> None: +@goose_cli.command(name="version") +def get_version() -> None: """Lists the version of goose and any plugins""" from importlib.metadata import entry_points, version @@ -42,6 +45,38 @@ def version() -> None: print(f" [red]Could not retrieve version for {module}: {e}[/red]") +def get_current_shell() -> str: + return os.getenv("SHELL", "").split("/")[-1] + + +@goose_cli.command(name="shell-completions", help="Manage shell completions for goose") +@click.option("--install", is_flag=True, help="Install shell completions") +@click.option("--generate", is_flag=True, help="Generate shell completions") +@click.argument( + "shell", + type=click.Choice(SUPPORTED_SHELLS), + default=get_current_shell(), +) +@click.pass_context +def shell_completions(ctx: click.Context, install: bool, generate: bool, shell: str) -> None: + """Generate or install shell completions for goose + + Args: + shell (str): shell to install completions for + install (bool): installs completions if true, otherwise generates + completions + """ + if not any([install, generate]): + print("[red]One of --install or --generate must be specified[/red]\n") + raise click.UsageError(ctx.get_help()) + + if sum([install, generate]) > 1: + print("[red]Only one of --install or --generate can be specified[/red]\n") + raise click.UsageError(ctx.get_help()) + + setup_autocomplete(shell, install=install) + + @goose_cli.group() def session() -> None: """Start or manage sessions""" @@ -62,10 +97,46 @@ def list_toolkits() -> None: print(f" - [bold]{toolkit_name}[/bold]: {first_line_of_doc}") +@goose_cli.group() +def providers() -> None: + """Manage providers""" + pass + + +@providers.command(name="list") +def list_providers() -> None: + providers = load_plugins(group="exchange.provider") + + for provider_name, provider in providers.items(): + lines_doc = provider.__doc__.split("\n") + first_line_of_doc = lines_doc[0] + print(f" - [bold]{provider_name}[/bold]: {first_line_of_doc}") + envs = provider.REQUIRED_ENV_VARS + if envs: + env_required_str = ", ".join(envs) + print(f" [dim]env vars required: {env_required_str}") + + print("\n") + + +def autocomplete_session_files(ctx: click.Context, args: str, incomplete: str) -> None: + return [ + f"{session_name}" + for session_name in sorted(get_session_files().keys(), reverse=True, key=lambda x: x.lower()) + if session_name.startswith(incomplete) + ] + + +def get_session_files() -> dict[str, Path]: + return list_sorted_session_files(SESSIONS_PATH) + + @session.command(name="start") +@click.argument("name", required=False, shell_complete=autocomplete_session_files) @click.option("--profile") @click.option("--plan", type=click.Path(exists=True)) -def session_start(profile: str, plan: Optional[str] = None) -> None: +@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +def session_start(name: Optional[str], profile: str, log_level: str, plan: Optional[str] = None) -> None: """Start a new goose session""" if plan: yaml = YAML() @@ -73,28 +144,71 @@ def session_start(profile: str, plan: Optional[str] = None) -> None: _plan = yaml.load(f) else: _plan = None + session = Session(name=name, profile=profile, plan=_plan, log_level=log_level) + session.run() + + +def parse_args(ctx: click.Context, param: click.Parameter, value: str) -> dict[str, str]: + if not value: + return {} + args = {} + for item in value.split(","): + key, val = item.split(":") + args[key.strip()] = val.strip() - session = Session(profile=profile, plan=_plan) + return args + + +@session.command(name="planned") +@click.option("--plan", type=click.Path(exists=True)) +@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +@click.option("-a", "--args", callback=parse_args, help="Args in the format arg1:value1,arg2:value2") +def session_planned(plan: str, log_level: str, args: Optional[dict[str, str]]) -> None: + plan_templated = render_template(Path(plan), context=args) + _plan = parse_plan(plan_templated) + session = Session(plan=_plan, log_level=log_level) session.run() @session.command(name="resume") -@click.argument("name", required=False) +@click.argument("name", required=False, shell_complete=autocomplete_session_files) @click.option("--profile") -def session_resume(name: Optional[str], profile: str) -> None: +@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +def session_resume(name: Optional[str], profile: str, log_level: str) -> None: """Resume an existing goose session""" + session_files = get_session_files() if name is None: - session_files = get_session_files() if session_files: name = list(session_files.keys())[0] print(f"Resuming most recent session: {name} from {session_files[name]}") else: print("No sessions found.") return - session = Session(name=name, profile=profile) + else: + if name in session_files: + print(f"Resuming session: {name}") + else: + print(f"Creating new session: {name}") + session = Session(name=name, profile=profile, log_level=log_level) session.run() +@goose_cli.command(name="run") +@click.argument("message_file", required=False, type=click.Path(exists=True)) +@click.option("--profile") +@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]), default="INFO") +def run(message_file: Optional[str], profile: str, log_level: str) -> None: + """Run a single-pass session with a message from a markdown input file""" + if message_file: + with open(message_file, "r") as f: + initial_message = f.read() + else: + initial_message = click.get_text_stream("stdin").read() + + session = Session(profile=profile, log_level=log_level) + session.single_pass(initial_message=initial_message) + + @session.command(name="list") def session_list() -> None: """List goose sessions""" @@ -112,18 +226,19 @@ def session_clear(keep: int) -> None: session_file.unlink() -def get_session_files() -> Dict[str, Path]: - return list_sorted_session_files(SESSIONS_PATH) - - @click.group( invoke_without_command=True, name="goose", help="AI-powered tool to assist in solving programming and operational tasks", ) +@click.option("-V", "--version", is_flag=True, help="List the version of goose and any plugins") @click.pass_context -def cli(_: click.Context, **kwargs: Dict) -> None: - pass +def cli(ctx: click.Context, version: bool, **kwargs: dict) -> None: + if version: + ctx.invoke(get_version) + ctx.exit() + elif ctx.invoked_subcommand is None: + click.echo(ctx.get_help()) all_cli_group_options = load_plugins("goose.cli.group_option") diff --git a/src/goose/cli/prompt/create.py b/src/goose/cli/prompt/create.py index 628c86ccb..899890824 100644 --- a/src/goose/cli/prompt/create.py +++ b/src/goose/cli/prompt/create.py @@ -6,10 +6,16 @@ from goose.cli.prompt.completer import GoosePromptCompleter from goose.cli.prompt.lexer import PromptLexer -from goose.command import get_commands +from goose.command.base import Command -def create_prompt() -> PromptSession: +def create_prompt(commands: dict[str, Command]) -> PromptSession: + """ + Create a prompt session with the given commands. + + Args: + commands (dict[str, Command]): A dictionary of command names, and instances of Command classes. + """ # Define custom style style = Style.from_dict( { @@ -52,12 +58,6 @@ def _(event: KeyPressEvent) -> None: # accept completion buffer.complete_state = None - # instantiate the commands available in the prompt - commands = dict() - command_plugins = get_commands() - for command, command_cls in command_plugins.items(): - commands[command] = command_cls() - return PromptSession( completer=GoosePromptCompleter(commands=commands), lexer=PromptLexer(list(commands.keys())), diff --git a/src/goose/cli/prompt/goose_prompt_session.py b/src/goose/cli/prompt/goose_prompt_session.py index cfcedd80a..5ba54427f 100644 --- a/src/goose/cli/prompt/goose_prompt_session.py +++ b/src/goose/cli/prompt/goose_prompt_session.py @@ -1,34 +1,88 @@ from typing import Optional from prompt_toolkit import PromptSession +from prompt_toolkit.document import Document from prompt_toolkit.formatted_text import FormattedText from prompt_toolkit.validation import DummyValidator from goose.cli.prompt.create import create_prompt +from goose.cli.prompt.lexer import PromptLexer from goose.cli.prompt.prompt_validator import PromptValidator from goose.cli.prompt.user_input import PromptAction, UserInput +from goose.command import get_commands class GoosePromptSession: - def __init__(self, prompt_session: PromptSession) -> None: - self.prompt_session = prompt_session + def __init__(self) -> None: + # instantiate the commands available in the prompt + self.commands = dict() + command_plugins = get_commands() + for command, command_cls in command_plugins.items(): + self.commands[command] = command_cls() - @staticmethod - def create_prompt_session() -> "GoosePromptSession": - return GoosePromptSession(create_prompt()) + # the main prompt session that is used to interact with the llm + self.main_prompt_session = create_prompt(self.commands) + + # a text-only prompt session that doesn't contain any commands + self.text_prompt_session = PromptSession() + + def get_message_after_commands(self, message: str) -> str: + lexer = PromptLexer(command_names=list(self.commands.keys())) + doc = Document(message) + lines = [] + # iterate through each line of the document + for line_num in range(len(doc.lines)): + classes_in_line = lexer.lex_document(doc)(line_num) + line_result = [] + i = 0 + while i < len(classes_in_line): + # if a command is found and it is not the last part of the line + if classes_in_line[i][0] == "class:command" and i + 1 < len(classes_in_line): + # extract the command name + command_name = classes_in_line[i][1].strip("/").strip(":") + # get the value following the command + if classes_in_line[i + 1][0] == "class:parameter": + command_value = classes_in_line[i + 1][1] + else: + command_value = "" + + # execute the command with the given argument, expecting a return value + value_after_execution = self.commands[command_name].execute(command_value, message) + + # if the command returns None, raise an error - this should never happen + # since the command should always return a string + if value_after_execution is None: + raise ValueError(f"Command {command_name} returned None") + + # append the result of the command execution to the line results + line_result.append(value_after_execution) + i += 1 + + # if the part is plain text, just append it to the line results + elif classes_in_line[i][0] == "class:text": + line_result.append(classes_in_line[i][1]) + i += 1 + + # join all processed parts of the current line and add it to the lines list + lines.append("".join(line_result)) + + # join all processed lines into a single string with newline characters and return + return "\n".join(lines) def get_user_input(self) -> "UserInput": try: text = FormattedText([("#00AEAE", "G❯ ")]) # Define the prompt style and text. - message = self.prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False) + message = self.main_prompt_session.prompt(text, validator=PromptValidator(), validate_while_typing=False) if message.strip() in ("exit", ":q"): return UserInput(PromptAction.EXIT) + + message = self.get_message_after_commands(message) return UserInput(PromptAction.CONTINUE, message) except (EOFError, KeyboardInterrupt): return UserInput(PromptAction.EXIT) def get_save_session_name(self) -> Optional[str]: - return self.prompt_session.prompt( + return self.text_prompt_session.prompt( "Enter a name to save this session under. A name will be generated for you if empty: ", validator=DummyValidator(), - ) + ).strip(" ") diff --git a/src/goose/cli/prompt/lexer.py b/src/goose/cli/prompt/lexer.py index b21cae7ac..0e2bb0c91 100644 --- a/src/goose/cli/prompt/lexer.py +++ b/src/goose/cli/prompt/lexer.py @@ -5,6 +5,11 @@ from prompt_toolkit.lexers import Lexer +# These are lexers for the commands in the prompt. This is how we +# are extracting the different parts of a command (here, used for styling), +# but likely will be used to parse the command as well in the future. + + def completion_for_command(target_string: str) -> re.Pattern[str]: escaped_string = re.escape(target_string) vals = [f"(?:{escaped_string[:i]}(?=$))" for i in range(len(escaped_string), 0, -1)] @@ -13,22 +18,21 @@ def completion_for_command(target_string: str) -> re.Pattern[str]: def command_itself(target_string: str) -> re.Pattern[str]: escaped_string = re.escape(target_string) - return re.compile(rf"(? re.Pattern[str]: - escaped_string = re.escape(command_string) - return re.compile(rf"(?<=(? None: self.patterns = [] for command_name in command_names: - full_command = command_name + ":" - self.patterns.append((completion_for_command(full_command), "class:command")) - self.patterns.append((value_for_command(full_command), "class:parameter")) - self.patterns.append((command_itself(full_command), "class:command")) + self.patterns.append((completion_for_command(command_name), "class:command")) + self.patterns.append((value_for_command(command_name), "class:parameter")) + self.patterns.append((command_itself(command_name), "class:command")) def lex_document(self, document: Document) -> Callable[[int], list]: def get_line_tokens(line_number: int) -> Tuple[str, str]: diff --git a/src/goose/cli/session.py b/src/goose/cli/session.py index 78767115d..113c98707 100644 --- a/src/goose/cli/session.py +++ b/src/goose/cli/session.py @@ -4,27 +4,21 @@ from typing import Any, Dict, List, Optional from exchange import Message, ToolResult, ToolUse, Text -from prompt_toolkit.shortcuts import confirm from rich import print -from rich.console import RenderableType -from rich.live import Live from rich.markdown import Markdown from rich.panel import Panel from rich.status import Status -from goose.build import build_exchange -from goose.cli.config import ( - default_profiles, - ensure_config, - read_config, - session_path, -) +from goose.cli.config import ensure_config, session_path, LOG_PATH +from goose._logger import get_logger, setup_logging from goose.cli.prompt.goose_prompt_session import GoosePromptSession -from goose.notifier import Notifier +from goose.cli.session_notifier import SessionNotifier from goose.profile import Profile from goose.toolkit.language_server import LanguageServerCoordinator from goose.utils import droid, load_plugins -from goose.utils.session_file import read_from_file, write_to_file +from goose.utils._cost_calculator import get_total_cost_message +from goose.utils._create_exchange import create_exchange +from goose.utils.session_file import read_or_create_file, save_latest_session RESUME_MESSAGE = "I see we were interrupted. How can I help you?" @@ -51,33 +45,8 @@ def load_provider() -> str: def load_profile(name: Optional[str]) -> Profile: - if name is None: - name = "default" - - # If the name is one of the default values, we ensure a valid configuration - if name in default_profiles(): - return ensure_config(name) - - # Otherwise this is a custom config and we return it from the config file - return read_config()[name] - - -class SessionNotifier(Notifier): - def __init__(self, status_indicator: Status) -> None: - self.status_indicator = status_indicator - self.live = Live(self.status_indicator, refresh_per_second=8, transient=True) - - def log(self, content: RenderableType) -> None: - print(content) - - def status(self, status: str) -> None: - self.status_indicator.update(status) - - def start(self) -> None: - self.live.start() - - def stop(self) -> None: - self.live.stop() + _, profile = ensure_config(name) + return profile class Session: @@ -92,39 +61,48 @@ def __init__( name: Optional[str] = None, profile: Optional[str] = None, plan: Optional[dict] = None, + log_level: Optional[str] = "INFO", **kwargs: Dict[str, Any], ) -> None: - self.name = name + if name is None: + self.name = droid() + else: + self.name = name + self.profile_name = profile + self.prompt_session = GoosePromptSession() self.status_indicator = Status("", spinner="dots") self.notifier = SessionNotifier(self.status_indicator) self.profile = load_profile(profile) - self.exchange = build_exchange(profile=self.profile, notifier=self.notifier) - - if name is not None and self.session_file_path.exists(): - messages = self.load_session() - - if messages and messages[-1].role == "user": - if type(messages[-1].content[-1]) is Text: - # remove the last user message - messages.pop() - elif type(messages[-1].content[-1]) is ToolResult: - # if we remove this message, we would need to remove - # the previous assistant message as well. instead of doing - # that, we just add a new assistant message to prompt the user - messages.append(Message.assistant(RESUME_MESSAGE)) - if messages and type(messages[-1].content[-1]) is ToolUse: - # remove the last request for a tool to be used - messages.pop() + self.exchange = create_exchange(profile=self.profile, notifier=self.notifier) + setup_logging(log_file_directory=LOG_PATH, log_level=log_level) - # add a new assistant text message to prompt the user - messages.append(Message.assistant(RESUME_MESSAGE)) - self.exchange.messages.extend(messages) + self.exchange.messages.extend(self._get_initial_messages()) if len(self.exchange.messages) == 0 and plan: self.setup_plan(plan=plan) - self.prompt_session = GoosePromptSession.create_prompt_session() + self.prompt_session = GoosePromptSession() + + def _get_initial_messages(self) -> List[Message]: + messages = self.load_session() + + if messages and messages[-1].role == "user": + if type(messages[-1].content[-1]) is Text: + # remove the last user message + messages.pop() + elif type(messages[-1].content[-1]) is ToolResult: + # if we remove this message, we would need to remove + # the previous assistant message as well. instead of doing + # that, we just add a new assistant message to prompt the user + messages.append(Message.assistant(RESUME_MESSAGE)) + if messages and type(messages[-1].content[-1]) is ToolUse: + # remove the last request for a tool to be used + messages.pop() + + # add a new assistant text message to prompt the user + messages.append(Message.assistant(RESUME_MESSAGE)) + return messages def setup_plan(self, plan: dict) -> None: if len(self.exchange.messages): @@ -152,11 +130,39 @@ def process_first_message(self) -> Optional[Message]: return Message.user(text=user_input.text) return self.exchange.messages.pop() + def single_pass(self, initial_message: str) -> None: + """ + Handles a single input message and processes a reply + without entering a loop for additional inputs. + + Args: + initial_message (str): The initial user message to process. + """ + profile = self.profile_name or "default" + print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{profile}[/]") + print(f"[dim]saving to {self.session_file_path}") + print() + + # Process initial message + message = Message.user(initial_message) + + self.exchange.add(message) + self.reply() # Process the user message + + save_latest_session(self.session_file_path, self.exchange.messages) + print() # Print a newline for separation. + + print(f"[dim]ended run | name:[cyan]{self.name}[/] profile:[cyan]{profile}[/]") + print(f"[dim]to resume: [magenta]goose session resume {self.name} --profile {profile}[/][/]") + def run(self) -> None: """ Runs the main loop to handle user inputs and responses. Continues until an empty string is returned from the prompt. """ + print(f"[dim]starting session | name:[cyan]{self.name}[/] profile:[cyan]{self.profile_name or 'default'}[/]") + print(f"[dim]saving to {self.session_file_path}") + print() message = self.process_first_message() with self.setup_language_server()() as _: while message: # Loop until no input (empty string). @@ -177,19 +183,16 @@ def run(self) -> None: + " - [yellow]depending on the error you may be able to continue[/]" ) self.notifier.stop() + save_latest_session(self.session_file_path, self.exchange.messages) print() # Print a newline for separation. user_input = self.prompt_session.get_user_input() message = Message.user(text=user_input.text) if user_input.to_continue() else None - self.save_session() + self._log_cost() def reply(self) -> None: - """Reply to the last user message, calling tools as needed - - Args: - text (str): The text input from the user. - """ + """Reply to the last user message, calling tools as needed""" self.status_indicator.update("responding") response = self.exchange.generate() @@ -242,29 +245,12 @@ def interrupt_reply(self) -> None: def session_file_path(self) -> Path: return session_path(self.name) - def save_session(self) -> None: - """Save the current session to a file in JSON format.""" - if self.name is None: - self.generate_session_name() - - try: - if self.session_file_path.exists(): - if not confirm(f"Session {self.name} exists in {self.session_file_path}, overwrite?"): - self.generate_session_name() - write_to_file(self.session_file_path, self.exchange.messages) - except PermissionError as e: - raise RuntimeError(f"Failed to save session due to permissions: {e}") - except (IOError, OSError) as e: - raise RuntimeError(f"Failed to save session due to I/O error: {e}") - def load_session(self) -> List[Message]: - """Load a session from a JSON file.""" - return read_from_file(self.session_file_path) + return read_or_create_file(self.session_file_path) - def generate_session_name(self) -> None: - user_entered_session_name = self.prompt_session.get_save_session_name() - self.name = user_entered_session_name if user_entered_session_name else droid() - print(f"Saving to [bold cyan]{self.session_file_path}[/bold cyan]") + def _log_cost(self) -> None: + get_logger().info(get_total_cost_message(self.exchange.get_token_usage())) + print(f"[dim]you can view the cost and token usage in the log directory {LOG_PATH}") if __name__ == "__main__": diff --git a/src/goose/cli/session_notifier.py b/src/goose/cli/session_notifier.py new file mode 100644 index 000000000..d29ce944a --- /dev/null +++ b/src/goose/cli/session_notifier.py @@ -0,0 +1,24 @@ +from rich.status import Status +from rich.live import Live +from rich.console import RenderableType +from rich import print + +from goose.notifier import Notifier + + +class SessionNotifier(Notifier): + def __init__(self, status_indicator: Status) -> None: + self.status_indicator = status_indicator + self.live = Live(self.status_indicator, refresh_per_second=8, transient=True) + + def log(self, content: RenderableType) -> None: + print(content) + + def status(self, status: str) -> None: + self.status_indicator.update(status) + + def start(self) -> None: + self.live.start() + + def stop(self) -> None: + self.live.stop() diff --git a/src/goose/command/base.py b/src/goose/command/base.py index dbf2d19db..5a8c346ff 100644 --- a/src/goose/command/base.py +++ b/src/goose/command/base.py @@ -8,9 +8,19 @@ class Command(ABC): """A command that can be executed by the CLI.""" def get_completions(self, query: str) -> List[Completion]: - """Get completions for the command.""" + """ + Get completions for the command. + + Args: + query (str): The current query. + """ return [] def execute(self, query: str) -> Optional[str]: - """Execute's the command and replaces it with the output.""" + """ + Execute's the command and replaces it with the output. + + Args: + query (str): The query to execute. + """ return "" diff --git a/src/goose/command/file.py b/src/goose/command/file.py index 7bbf7d9e3..cb8bdfd67 100644 --- a/src/goose/command/file.py +++ b/src/goose/command/file.py @@ -57,5 +57,4 @@ def get_completions(self, query: str) -> List[Completion]: return completions def execute(self, query: str) -> str | None: - # GOOSE-TODO: return the query - pass + return query diff --git a/src/goose/profile.py b/src/goose/profile.py index 88a7c3f7b..cdc34fb85 100644 --- a/src/goose/profile.py +++ b/src/goose/profile.py @@ -39,6 +39,10 @@ def check_toolkit_requirements(self, _: Type["ToolkitSpec"], toolkits: List[Tool def to_dict(self) -> Dict[str, Any]: return asdict(self) + def profile_info(self) -> str: + tookit_names = [toolkit.name for toolkit in self.toolkits] + return f"provider:{self.provider}, processor:{self.processor} toolkits: {', '.join(tookit_names)}" + def default_profile(provider: str, processor: str, accelerator: str, **kwargs: Dict[str, Any]) -> Profile: """Get the default profile""" diff --git a/src/goose/toolkit/__init__.py b/src/goose/toolkit/__init__.py index a3a97d41f..fc561ee67 100644 --- a/src/goose/toolkit/__init__.py +++ b/src/goose/toolkit/__init__.py @@ -1,9 +1,12 @@ from functools import cache - +from exchange.invalid_choice_error import InvalidChoiceError from goose.toolkit.base import Toolkit from goose.utils import load_plugins @cache def get_toolkit(name: str) -> type[Toolkit]: - return load_plugins(group="goose.toolkit")[name] + toolkits = load_plugins(group="goose.toolkit") + if name not in toolkits: + raise InvalidChoiceError("toolkit", name, toolkits.keys()) + return toolkits[name] diff --git a/src/goose/toolkit/developer.py b/src/goose/toolkit/developer.py index 450ce2444..ba600d921 100644 --- a/src/goose/toolkit/developer.py +++ b/src/goose/toolkit/developer.py @@ -1,19 +1,23 @@ -from pathlib import Path -from subprocess import CompletedProcess, run -from typing import List, Dict import os -from goose.utils.check_shell_command import is_dangerous_command +import re +import subprocess +import time +from pathlib import Path +from typing import Dict, List from exchange import Message -from rich import box +from goose.toolkit.base import Toolkit, tool +from goose.toolkit.utils import get_language, render_template +from goose.utils.ask import ask_an_ai +from goose.utils.check_shell_command import is_dangerous_command from rich.markdown import Markdown -from rich.panel import Panel from rich.prompt import Confirm from rich.table import Table from rich.text import Text +from rich.rule import Rule -from goose.toolkit.base import Toolkit, tool -from goose.toolkit.utils import get_language, render_template +RULESTYLE = "bold" +RULEPREFIX = f"[{RULESTYLE}]───[/] " def keep_unsafe_command_prompt(command: str) -> bool: @@ -39,9 +43,15 @@ def system(self) -> str: """Retrieve system configuration details for developer""" hints_path = Path(".goosehints") system_prompt = Message.load("prompts/developer.jinja").text + home_hints_path = Path.home() / ".config/goose/.goosehints" + hints = [] if hints_path.is_file(): - goosehints = render_template(hints_path) - system_prompt = f"{system_prompt}\n\nHints:\n{goosehints}" + hints.append(render_template(hints_path)) + if home_hints_path.is_file(): + hints.append(render_template(home_hints_path)) + if hints: + hints_text = "\n".join(hints) + system_prompt = f"{system_prompt}\n\nHints:\n{hints_text}" return system_prompt @tool @@ -116,7 +126,8 @@ def patch_file(self, path: str, before: str, after: str) -> str: {after} ``` """ - self.notifier.log(Panel.fit(Markdown(output), title=path)) + self.notifier.log(Rule(RULEPREFIX + path, style=RULESTYLE, align="left")) + self.notifier.log(Markdown(output)) return "Succesfully replaced before with after." @tool @@ -128,7 +139,7 @@ def read_file(self, path: str) -> str: """ language = get_language(path) content = Path(path).expanduser().read_text() - self.notifier.log(Panel.fit(Markdown(f"```\ncat {path}\n```"), box=box.MINIMAL)) + self.notifier.log(Markdown(f"```\ncat {path}\n```")) # Record the last read timestamp self.timestamps[path] = os.path.getmtime(path) return f"```{language}\n{content}\n```" @@ -136,7 +147,7 @@ def read_file(self, path: str) -> str: @tool def shell(self, command: str) -> str: """ - Execute a command on the shell (in OSX) + Execute a command on the shell This will return the output and error concatenated into a single string, as you would see from running on the command line. There will also be an indication @@ -146,12 +157,9 @@ def shell(self, command: str) -> str: command (str): The shell command to run. It can support multiline statements if you need to run more than one at a time """ - self.notifier.status("planning to run shell command") # Log the command being executed in a visually structured format (Markdown). - # The `.log` method is used here to log the command execution in the application's UX - # this method is dynamically attached to functions in the Goose framework to handle user-visible - # logging and integrates with the overall UI logging system - self.notifier.log(Panel.fit(Markdown(f"```bash\n{command}\n```"), title="shell")) + self.notifier.log(Rule(RULEPREFIX + "shell", style=RULESTYLE, align="left")) + self.notifier.log(Markdown(f"```bash\n{command}\n```")) if is_dangerous_command(command): # Stop the notifications so we can prompt @@ -159,16 +167,86 @@ def shell(self, command: str) -> str: if not keep_unsafe_command_prompt(command): raise RuntimeError( f"The command {command} was rejected as dangerous by the user." - + " Do not proceed further, instead ask for instructions." + " Do not proceed further, instead ask for instructions." ) self.notifier.start() self.notifier.status("running shell command") - result: CompletedProcess = run(command, shell=True, text=True, capture_output=True, check=False) - if result.returncode == 0: - output = "Command succeeded" + + # Define patterns that might indicate the process is waiting for input + interaction_patterns = [ + r"Do you want to", # Common prompt phrase + r"Enter password", # Password prompt + r"Are you sure", # Confirmation prompt + r"\(y/N\)", # Yes/No prompt + r"Press any key to continue", # Awaiting keypress + r"Waiting for input", # General waiting message + ] + compiled_patterns = [re.compile(pattern, re.IGNORECASE) for pattern in interaction_patterns] + + proc = subprocess.Popen( + command, + shell=True, + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + # this enables us to read lines without blocking + os.set_blocking(proc.stdout.fileno(), False) + + # Accumulate the output logs while checking if it might be blocked + output_lines = [] + last_output_time = time.time() + cutoff = 10 + while proc.poll() is None: + self.notifier.status("running shell command") + line = proc.stdout.readline() + if line: + output_lines.append(line) + last_output_time = time.time() + + # If we see a clear pattern match, we plan to abort + exit_criteria = any(pattern.search(line) for pattern in compiled_patterns) + + # and if we haven't seen a new line in 10+s, check with AI to see if it may be stuck + if not exit_criteria and time.time() - last_output_time > cutoff: + self.notifier.status("checking on shell status") + response = ask_an_ai( + input="\n".join([command] + output_lines), + prompt=( + "You will evaluate the output of shell commands to see if they may be stuck." + " Look for commands that appear to be awaiting user input, or otherwise running indefinitely (such as a web service)." # noqa + " A command that will take a while, such as downloading resources is okay." # noqa + " return [Yes] if stuck, [No] otherwise." + ), + exchange=self.exchange_view.processor, + with_tools=False, + ) + exit_criteria = "[yes]" in response.content[0].text.lower() + # We add exponential backoff for how often we check for the command being stuck + cutoff *= 10 + + if exit_criteria: + proc.terminate() + raise ValueError( + f"The command `{command}` looks like it will run indefinitely or is otherwise stuck." + f"You may be able to specify inputs if it applies to this command." + f"Otherwise to enable continued iteration, you'll need to ask the user to run this command in another terminal." # noqa + ) + + # read any remaining lines + while line := proc.stdout.readline(): + output_lines.append(line) + output = "".join(output_lines) + + # Determine the result based on the return code + if proc.returncode == 0: + result = "Command succeeded" else: - output = f"Command failed with returncode {result.returncode}" - return "\n".join([output, result.stdout, result.stderr]) + result = f"Command failed with returncode {proc.returncode}" + + # Return the combined result and outputs if we made it this far + return "\n".join([result, output]) @tool def write_file(self, path: str, content: str) -> str: @@ -188,7 +266,8 @@ def write_file(self, path: str, content: str) -> str: # Log the content that will be written to the file # .log` method is used here to log the command execution in the application's UX # this method is dynamically attached to functions in the Goose framework - self.notifier.log(Panel.fit(Markdown(md), title=path)) + self.notifier.log(Rule(RULEPREFIX + path, style=RULESTYLE, align="left")) + self.notifier.log(Markdown(md)) _path = Path(path) if path in self.timestamps: diff --git a/src/goose/toolkit/lint.py b/src/goose/toolkit/lint.py index 0f08f222d..a12335c74 100644 --- a/src/goose/toolkit/lint.py +++ b/src/goose/toolkit/lint.py @@ -10,3 +10,14 @@ def lint_toolkits() -> None: assert first_line_of_docstring[ 0 ].isupper(), f"`{toolkit_name}` toolkit docstring must start with a capital letter" + + +def lint_providers() -> None: + for provider_name, provider in load_plugins(group="exchange.provider").items(): + assert provider.__doc__ is not None, f"`{provider_name}` provider must have a docstring" + first_line_of_docstring = provider.__doc__.split("\n")[0] + assert len(first_line_of_docstring.split(" ")) > 5, f"`{provider_name}` provider docstring is too short" + assert len(first_line_of_docstring.split(" ")) < 20, f"`{provider_name}` provider docstring is too long" + assert first_line_of_docstring[ + 0 + ].isupper(), f"`{provider_name}` provider docstring must start with a capital letter" diff --git a/src/goose/toolkit/prompts/developer.jinja b/src/goose/toolkit/prompts/developer.jinja index d84c3d902..6da404e9e 100644 --- a/src/goose/toolkit/prompts/developer.jinja +++ b/src/goose/toolkit/prompts/developer.jinja @@ -15,23 +15,23 @@ the actions you will need to take, such as writing files or executing shell comm For example, here's a plan to unstage all edited files in a git repo -{"description": "Use the git status command to find edited files", "status": "pending"} -{"description": "For each file with changes, call git restore on the file", "status": "pending"} +{"description": "Use the git status command to find edited files", "status": "planned"} +{"description": "For each file with changes, call git restore on the file", "status": "planned"} After running git status, you would update to {"description": "Use the git status command to find edited files", "status": "complete"} -{"description": "For each file with changes, call git restore on the file", "status": "pending"} +{"description": "For each file with changes, call git restore on the file", "status": "planned"} Here's another plan example to get the sum of the "payment" column in data.csv -{"description": "Install pandas", "status": "pending"} -{"description": "Write a python script in the file sum.py that loads the csv in pandas and sums the column", "status": "pending"} -{"description": "Run the python script with bash", "status": "pending"} +{"description": "Install pandas", "status": "planned"} +{"description": "Write a python script in the file sum.py that loads the csv in pandas and sums the column", "status": "planned"} +{"description": "Run the python script with bash", "status": "planned"} If you were to encounter an error along the way, you can update the plan to specify a new approach. -Always call update_plan before any other tool calls! You should specify the whole plan upfront as pending, +Always call update_plan before any other tool calls! You should specify the whole plan upfront as planned, and then update status at each step. **Do not describe the plan in your text response, only communicate the plan through the tool** diff --git a/src/goose/toolkit/screen.py b/src/goose/toolkit/screen.py index ce5524881..f0cc5722d 100644 --- a/src/goose/toolkit/screen.py +++ b/src/goose/toolkit/screen.py @@ -1,6 +1,9 @@ import subprocess import uuid +from rich.markdown import Markdown +from rich.panel import Panel + from goose.toolkit.base import Toolkit, tool @@ -8,17 +11,30 @@ class Screen(Toolkit): """Provides an instructions on when and how to work with screenshots""" @tool - def take_screenshot(self) -> str: + def take_screenshot(self, display: int = 1) -> str: """ - Take a screenshot to assist the user in debugging or designing an app. They may need help with moving a screen element, or interacting in some way where you could do with seeing the screen. + Take a screenshot to assist the user in debugging or designing an app. They may need + help with moving a screen element, or interacting in some way where you could do with + seeing the screen. - Return: - (str) a path to the screenshot file, in the format of image: followed by the path to the file. + Args: + display (int): Display to take the screen shot in. Default is the main display (1). Must be a value greater than 1. """ # noqa: E501 # Generate a random tmp filename for screenshot - filename = f"/tmp/goose_screenshot_{uuid.uuid4().hex}.png" + filename = f"/tmp/goose_screenshot_{uuid.uuid4().hex}.jpg" + screen_capture_command = ["screencapture", "-x", "-D", str(display), filename, "-f", "jpg"] + + subprocess.run(screen_capture_command, check=True, capture_output=True) + + resize_command = ["sips", "--resampleWidth", "768", filename, "-s", "format", "jpeg"] + subprocess.run(resize_command, check=True, capture_output=True) - subprocess.run(["screencapture", "-x", filename]) + self.notifier.log( + Panel.fit( + Markdown(f"```bash\n{' '.join(screen_capture_command)}"), + title="screen", + ) + ) return f"image:{filename}" diff --git a/src/goose/toolkit/utils.py b/src/goose/toolkit/utils.py index 61632b776..ad97360f2 100644 --- a/src/goose/toolkit/utils.py +++ b/src/goose/toolkit/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Dict from pygments.lexers import get_lexer_for_filename from pygments.util import ClassNotFound @@ -42,3 +42,38 @@ def render_template(template_path: Path, context: Optional[dict] = None) -> str: env = Environment(loader=FileSystemLoader(template_path.parent)) template = env.get_template(template_path.name) return template.render(context or {}) + + +def find_last_task_group_index(input_str: str) -> int: + lines = input_str.splitlines() + last_group_start_index = -1 + current_group_start_index = -1 + + for i, line in enumerate(lines): + line = line.strip() + if line.startswith("-"): + # If this is the first line of a new group, mark its start + if current_group_start_index == -1: + current_group_start_index = i + else: + # If we encounter a non-hyphenated line and had a group, update last group start + if current_group_start_index != -1: + last_group_start_index = current_group_start_index + current_group_start_index = -1 # Reset for potential future groups + + # If the input ended in a task group, update the last group index + if current_group_start_index != -1: + last_group_start_index = current_group_start_index + return last_group_start_index + + +def parse_plan(input_plan_str: str) -> Dict: + last_group_start_index = find_last_task_group_index(input_plan_str) + if last_group_start_index == -1: + return {"kickoff_message": input_plan_str, "tasks": []} + + kickoff_message_list = input_plan_str.splitlines()[:last_group_start_index] + kickoff_message = "\n".join(kickoff_message_list).strip() + tasks_list = input_plan_str.splitlines()[last_group_start_index:] + tasks_list_output = [s[1:] for s in tasks_list if s.strip()] # filter leading - + return {"kickoff_message": kickoff_message, "tasks": tasks_list_output} diff --git a/src/goose/utils/_cost_calculator.py b/src/goose/utils/_cost_calculator.py new file mode 100644 index 000000000..5cdf69509 --- /dev/null +++ b/src/goose/utils/_cost_calculator.py @@ -0,0 +1,40 @@ +from typing import Optional +from exchange.providers.base import Usage + +PRICES = { + "gpt-4o": (2.50, 10.00), + "gpt-4o-2024-08-06": (2.50, 10.00), + "gpt-4o-2024-05-13": (5.00, 15.00), + "gpt-4o-mini": (0.150, 0.600), + "gpt-4o-mini-2024-07-18": (0.150, 0.600), + "o1-preview": (15.00, 60.00), + "o1-mini": (3.00, 12.00), + "claude-3-5-sonnet-20240620": (3.00, 15.00), + "anthropic.claude-3-5-sonnet-20240620-v1:0": (3.00, 15.00), + "claude-3-opus-20240229": (15.00, 75.00), + "anthropic.claude-3-opus-20240229-v1:0": (15.00, 75.00), + "claude-3-haiku-20240307": (0.25, 1.25), + "anthropic.claude-3-haiku-20240307-v1:0": (0.25, 1.25), +} + + +def _calculate_cost(model: str, token_usage: Usage) -> Optional[float]: + model_name = model.lower() + if model_name in PRICES: + input_token_price, output_token_price = PRICES[model_name] + return (input_token_price * token_usage.input_tokens + output_token_price * token_usage.output_tokens) / 1000000 + return None + + +def get_total_cost_message(token_usages: dict[str, Usage]) -> str: + total_cost = 0 + message = "" + for model, token_usage in token_usages.items(): + cost = _calculate_cost(model, token_usage) + if cost is not None: + message += f"Cost for model {model} {str(token_usage)}: ${cost:.2f}\n" + total_cost += cost + else: + message += f"Cost for model {model} {str(token_usage)}: Not available\n" + message += f"Total cost: ${total_cost:.2f}" + return message diff --git a/src/goose/utils/_create_exchange.py b/src/goose/utils/_create_exchange.py new file mode 100644 index 000000000..d1aa318f5 --- /dev/null +++ b/src/goose/utils/_create_exchange.py @@ -0,0 +1,52 @@ +import os +import sys +from typing import Optional +import keyring + +from prompt_toolkit import prompt +from prompt_toolkit.shortcuts import confirm +from rich import print +from rich.panel import Panel + +from goose.build import build_exchange +from goose.cli.config import PROFILES_CONFIG_PATH +from goose.cli.session_notifier import SessionNotifier +from goose.profile import Profile +from exchange import Exchange +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers.base import MissingProviderEnvVariableError + + +def create_exchange(profile: Profile, notifier: SessionNotifier) -> Exchange: + try: + return build_exchange(profile, notifier=notifier) + except InvalidChoiceError as e: + error_message = ( + f"[bold red]{e.message}[/bold red].\nPlease check your configuration file at {PROFILES_CONFIG_PATH}.\n" + + "Configuration doc: https://block-open-source.github.io/goose/configuration.html" + ) + print(error_message) + sys.exit(1) + except MissingProviderEnvVariableError as e: + api_key = _get_api_key_from_keychain(e.env_variable, e.provider) + if api_key is None or api_key == "": + error_message = f"{e.message}. Please set the required environment variable to continue." + print(Panel(error_message, style="red")) + sys.exit(1) + else: + os.environ[e.env_variable] = api_key + return build_exchange(profile=profile, notifier=notifier) + + +def _get_api_key_from_keychain(env_variable: str, provider: str) -> Optional[str]: + api_key = keyring.get_password("goose", env_variable) + if api_key is not None: + print(f"Using {env_variable} value for {provider} from your keychain") + else: + api_key = prompt(f"Enter {env_variable} value for {provider}:".strip()) + if api_key is not None and len(api_key) > 0: + save_to_keyring = confirm(f"Would you like to save the {env_variable} value to your keychain?") + if save_to_keyring: + keyring.set_password("goose", env_variable, api_key) + print(f"Saved {env_variable} to your key_chain. service_name: goose, user_name: {env_variable}") + return api_key diff --git a/src/goose/utils/ask.py b/src/goose/utils/ask.py index 0e34b444a..c0fee1bcd 100644 --- a/src/goose/utils/ask.py +++ b/src/goose/utils/ask.py @@ -1,7 +1,13 @@ from exchange import Exchange, Message, CheckpointData -def ask_an_ai(input: str, exchange: Exchange, prompt: str = "", no_history: bool = True) -> Message: +def ask_an_ai( + input: str, + exchange: Exchange, + prompt: str = "", + no_history: bool = True, + with_tools: bool = True, +) -> Message: """Sends a separate message to an LLM using a separate Exchange than the one underlying the Goose session. Can be used to summarize a file, or submit any other request that you'd like to an AI. The Exchange can have a @@ -36,6 +42,9 @@ def ask_an_ai(input: str, exchange: Exchange, prompt: str = "", no_history: bool if no_history: exchange = clear_exchange(exchange) + if not with_tools: + exchange = exchange.replace(tools=()) + if prompt: exchange = replace_prompt(exchange, prompt) diff --git a/src/goose/utils/autocomplete.py b/src/goose/utils/autocomplete.py new file mode 100644 index 000000000..6feb0807e --- /dev/null +++ b/src/goose/utils/autocomplete.py @@ -0,0 +1,100 @@ +import sys +from pathlib import Path + +from rich import print + +SUPPORTED_SHELLS = ["bash", "zsh", "fish"] + + +def is_autocomplete_installed(file: Path) -> bool: + if not file.exists(): + print(f"[yellow]{file} does not exist, creating file") + with open(file, "w") as f: + f.write("") + + # https://click.palletsprojects.com/en/8.1.x/shell-completion/#enabling-completion + if "_GOOSE_COMPLETE" in open(file).read(): + print(f"auto-completion already installed in {file}") + return True + return False + + +def setup_bash(install: bool) -> None: + bashrc = Path("~/.bashrc").expanduser() + if install: + if is_autocomplete_installed(bashrc): + return + f = open(bashrc, "a") + else: + f = sys.stdout + print(f"# add the following to your bash config, typically {bashrc}") + + with f: + f.write('eval "$(_GOOSE_COMPLETE=bash_source goose)"\n') + + if install: + print(f"installed auto-completion to {bashrc}") + print(f"run `source {bashrc}` to enable auto-completion") + + +def setup_fish(install: bool) -> None: + completion_dir = Path("~/.config/fish/completions").expanduser() + if not completion_dir.exists(): + completion_dir.mkdir(parents=True, exist_ok=True) + + completion_file = completion_dir / "goose.fish" + if install: + if is_autocomplete_installed(completion_file): + return + f = open(completion_file, "a") + else: + f = sys.stdout + print(f"# add the following to your fish config, typically {completion_file}") + + with f: + f.write("_GOOSE_COMPLETE=fish_source goose | source\n") + + if install: + print(f"installed auto-completion to {completion_file}") + + +def setup_zsh(install: bool) -> None: + zshrc = Path("~/.zshrc").expanduser() + if install: + if is_autocomplete_installed(zshrc): + return + f = open(zshrc, "a") + else: + f = sys.stdout + print(f"# add the following to your zsh config, typically {zshrc}") + + with f: + f.write("autoload -U +X compinit && compinit\n") + f.write("autoload -U +X bashcompinit && bashcompinit\n") + f.write('eval "$(_GOOSE_COMPLETE=zsh_source goose)"\n') + + if install: + print(f"installed auto-completion to {zshrc}") + print(f"run `source {zshrc}` to enable auto-completion") + + +def setup_autocomplete(shell: str, install: bool) -> None: + """Installs shell completions for goose + + Args: + shell (str): shell to install completions for + install (bool): whether to install or generate completions + """ + + match shell: + case "bash": + setup_bash(install=install) + + case "zsh": + setup_zsh(install=install) + + case "fish": + setup_fish(install=install) + + case _: + print(f"Shell {shell} not supported") diff --git a/src/goose/utils/diff.py b/src/goose/utils/diff.py deleted file mode 100644 index e3583be01..000000000 --- a/src/goose/utils/diff.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import List - -from rich.text import Text - - -def diff(a: str, b: str) -> List[str]: - """Returns a string containing the unified diff of two strings.""" - - import difflib - - a_lines = a.splitlines() - b_lines = b.splitlines() - - # Create a Differ object - d = difflib.Differ() - - # Generate the diff - diff = list(d.compare(a_lines, b_lines)) - - return diff - - -def pretty_diff(a: str, b: str) -> Text: - """Returns a pretty-printed diff of two strings.""" - - diff_lines = diff(a, b) - result = Text() - for line in diff_lines: - if line.startswith("+"): - result.append(line, style="green") - elif line.startswith("-"): - result.append(line, style="red") - elif line.startswith("?"): - result.append(line, style="yellow") - else: - result.append(line, style="dim grey") - result.append("\n") - - return result diff --git a/src/goose/utils/session_file.py b/src/goose/utils/session_file.py index a47efcb1e..435186ce5 100644 --- a/src/goose/utils/session_file.py +++ b/src/goose/utils/session_file.py @@ -1,5 +1,7 @@ import json +import os from pathlib import Path +import tempfile from typing import Dict, Iterator, List from exchange import Message @@ -9,9 +11,15 @@ def write_to_file(file_path: Path, messages: List[Message]) -> None: with open(file_path, "w") as f: - for m in messages: - json.dump(m.to_dict(), f) - f.write("\n") + _write_messages_to_file(f, messages) + + +def read_or_create_file(file_path: Path) -> List[Message]: + if file_path.exists(): + return read_from_file(file_path) + with open(file_path, "w"): + pass + return [] def read_from_file(file_path: Path) -> List[Message]: @@ -37,3 +45,17 @@ def session_file_exists(session_files_directory: Path) -> bool: if not session_files_directory.exists(): return False return any(list_session_files(session_files_directory)) + + +def save_latest_session(file_path: Path, messages: List[Message]) -> None: + with tempfile.NamedTemporaryFile("w", delete=False) as temp_file: + _write_messages_to_file(temp_file, messages) + temp_file_path = temp_file.name + + os.replace(temp_file_path, file_path) + + +def _write_messages_to_file(file: any, messages: List[Message]) -> None: + for m in messages: + json.dump(m.to_dict(), file) + file.write("\n") diff --git a/tests/cli/prompt/test_goose_prompt_session.py b/tests/cli/prompt/test_goose_prompt_session.py index eca44cc67..1c9578fa2 100644 --- a/tests/cli/prompt/test_goose_prompt_session.py +++ b/tests/cli/prompt/test_goose_prompt_session.py @@ -1,5 +1,6 @@ from unittest.mock import patch +from prompt_toolkit import PromptSession import pytest from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput @@ -7,41 +8,48 @@ @pytest.fixture def mock_prompt_session(): - with patch("prompt_toolkit.PromptSession") as mock_prompt_session: + with patch("goose.cli.prompt.goose_prompt_session.PromptSession") as mock_prompt_session: yield mock_prompt_session def test_get_save_session_name(mock_prompt_session): - mock_prompt_session.prompt.return_value = "my_session" - goose_prompt_session = GoosePromptSession(mock_prompt_session) + mock_prompt_session.return_value.prompt.return_value = "my_session" + goose_prompt_session = GoosePromptSession() assert goose_prompt_session.get_save_session_name() == "my_session" -def test_get_user_input_to_continue(mock_prompt_session): - mock_prompt_session.prompt.return_value = "input_value" - goose_prompt_session = GoosePromptSession(mock_prompt_session) +def test_get_save_session_name_with_space(mock_prompt_session): + mock_prompt_session.return_value.prompt.return_value = "my_session " + goose_prompt_session = GoosePromptSession() - user_input = goose_prompt_session.get_user_input() + assert goose_prompt_session.get_save_session_name() == "my_session" + + +def test_get_user_input_to_continue(): + with patch.object(PromptSession, "prompt", return_value="input_value"): + goose_prompt_session = GoosePromptSession() + + user_input = goose_prompt_session.get_user_input() - assert user_input == UserInput(PromptAction.CONTINUE, "input_value") + assert user_input == UserInput(PromptAction.CONTINUE, "input_value") @pytest.mark.parametrize("exit_input", ["exit", ":q"]) def test_get_user_input_to_exit(exit_input, mock_prompt_session): - mock_prompt_session.prompt.return_value = exit_input - goose_prompt_session = GoosePromptSession(mock_prompt_session) + with patch.object(PromptSession, "prompt", return_value=exit_input): + goose_prompt_session = GoosePromptSession() - user_input = goose_prompt_session.get_user_input() + user_input = goose_prompt_session.get_user_input() - assert user_input == UserInput(PromptAction.EXIT) + assert user_input == UserInput(PromptAction.EXIT) @pytest.mark.parametrize("error", [EOFError, KeyboardInterrupt]) def test_get_user_input_to_exit_when_error_occurs(error, mock_prompt_session): - mock_prompt_session.prompt.side_effect = error - goose_prompt_session = GoosePromptSession(mock_prompt_session) + with patch.object(PromptSession, "prompt", side_effect=error): + goose_prompt_session = GoosePromptSession() - user_input = goose_prompt_session.get_user_input() + user_input = goose_prompt_session.get_user_input() - assert user_input == UserInput(PromptAction.EXIT) + assert user_input == UserInput(PromptAction.EXIT) diff --git a/tests/cli/prompt/test_lexer.py b/tests/cli/prompt/test_lexer.py index 585bead9b..790bed40b 100644 --- a/tests/cli/prompt/test_lexer.py +++ b/tests/cli/prompt/test_lexer.py @@ -232,22 +232,45 @@ def test_lex_document_ending_char_of_parameter_is_symbol(): assert actual_tokens == expected_tokens -def test_command_itself(): - pattern = command_itself("file:") - matches = pattern.match("/file:example.txt") +def assert_pattern_matches(pattern, text, expected_group): + matches = pattern.search(text) assert matches is not None - assert matches.group(1) == "/file:" + assert matches.group() == expected_group + + +def test_command_itself(): + pattern = command_itself("file") + assert_pattern_matches(pattern, "/file:example.txt", "/file:") + assert_pattern_matches(pattern, "/file asdf", "/file") + assert_pattern_matches(pattern, "some /file", "/file") + assert_pattern_matches(pattern, "some /file:", "/file:") + assert_pattern_matches(pattern, "/file /file", "/file") + + assert pattern.search("file") is None + assert pattern.search("/anothercommand") is None def test_value_for_command(): - pattern = value_for_command("file:") - matches = pattern.search("/file:example.txt") - assert matches is not None - assert matches.group(1) == "example.txt" + pattern = value_for_command("file") + assert_pattern_matches(pattern, "/file:example.txt", "example.txt") + assert_pattern_matches(pattern, '/file:"example space.txt"', '"example space.txt"') + assert_pattern_matches(pattern, '/file:"example.txt" some other string', '"example.txt"') + assert_pattern_matches(pattern, "something before /file:example.txt", "example.txt") + + # assert no pattern matches when there is no value + assert pattern.search("/file:").group() == "" + assert pattern.search("/file: other").group() == "" + assert pattern.search("/file: ").group() == "" + assert pattern.search("/file other") is None def test_completion_for_command(): - pattern = completion_for_command("file:") - matches = pattern.search("/file:") - assert matches is not None - assert matches.group(1) == "file:" + pattern = completion_for_command("file") + assert_pattern_matches(pattern, "/file", "/file") + assert_pattern_matches(pattern, "/fi", "/fi") + assert_pattern_matches(pattern, "before /fi", "/fi") + assert_pattern_matches(pattern, "some /f", "/f") + + assert pattern.search("/file after") is None + assert pattern.search("/ file") is None + assert pattern.search("/file:") is None diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py index b857f8b99..0694034dc 100644 --- a/tests/cli/test_config.py +++ b/tests/cli/test_config.py @@ -28,53 +28,66 @@ def test_read_write_config(mock_profile_config_path, profile_factory): assert read_config() == profiles -def test_ensure_config_create_profiles_file_with_default_profile( +def test_ensure_config_create_profiles_file_with_default_profile_with_name_default( mock_profile_config_path, mock_default_model_configuration ): assert not mock_profile_config_path.exists() - ensure_config(name="default") + (profile_name, profile) = ensure_config(name=None) + + expected_profile = default_profile(*mock_default_model_configuration()) + + assert profile_name == "default" + assert profile == expected_profile assert mock_profile_config_path.exists() + assert read_config() == {"default": expected_profile} - assert read_config() == {"default": default_profile(*mock_default_model_configuration())} +def test_ensure_config_create_profiles_file_with_default_profile_with_profile_name( + mock_profile_config_path, mock_default_model_configuration +): + assert not mock_profile_config_path.exists() -def test_ensure_config_add_default_profile(mock_profile_config_path, profile_factory, mock_default_model_configuration): - existing_profile = profile_factory({"provider": "providerA"}) - write_config({"profile1": existing_profile}) + (profile_name, profile) = ensure_config(name="my_profile") - ensure_config(name="default") + expected_profile = default_profile(*mock_default_model_configuration()) - assert read_config() == { - "profile1": existing_profile, - "default": default_profile(*mock_default_model_configuration()), - } + assert profile_name == "my_profile" + assert profile == expected_profile + assert mock_profile_config_path.exists() + assert read_config() == {"my_profile": expected_profile} -@patch("goose.cli.config.Confirm.ask", return_value=True) -def test_ensure_config_overwrite_default_profile( - mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration +def test_ensure_config_add_default_profile_when_profile_not_exist( + mock_profile_config_path, profile_factory, mock_default_model_configuration ): existing_profile = profile_factory({"provider": "providerA"}) - profile_name = "default" - write_config({profile_name: existing_profile}) + write_config({"profile1": existing_profile}) + + (profile_name, new_profile) = ensure_config(name="my_new_profile") - expected_default_profile = default_profile(*mock_default_model_configuration()) - assert ensure_config(name="default") == expected_default_profile - assert read_config() == {"default": expected_default_profile} + expected_profile = default_profile(*mock_default_model_configuration()) + assert profile_name == "my_new_profile" + assert new_profile == expected_profile + assert read_config() == { + "profile1": existing_profile, + "my_new_profile": expected_profile, + } -@patch("goose.cli.config.Confirm.ask", return_value=False) -def test_ensure_config_keep_original_default_profile( - mock_confirm, mock_profile_config_path, profile_factory, mock_default_model_configuration +def test_ensure_config_get_existing_profile_not_exist( + mock_profile_config_path, profile_factory, mock_default_model_configuration ): existing_profile = profile_factory({"provider": "providerA"}) - profile_name = "default" - write_config({profile_name: existing_profile}) + write_config({"profile1": existing_profile}) - assert ensure_config(name="default") == existing_profile + (profile_name, profile) = ensure_config(name="profile1") - assert read_config() == {"default": existing_profile} + assert profile_name == "profile1" + assert profile == existing_profile + assert read_config() == { + "profile1": existing_profile, + } def test_session_path(mock_sessions_path): diff --git a/tests/cli/test_main.py b/tests/cli/test_main.py index 617b3d5c1..9f6f2f66f 100644 --- a/tests/cli/test_main.py +++ b/tests/cli/test_main.py @@ -30,11 +30,19 @@ def mock_session(): yield mock_session_class, mock_session_instance +def test_session_start_command_with_session_name(mock_session): + mock_session_class, mock_session_instance = mock_session + runner = CliRunner() + runner.invoke(goose_cli, ["session", "start", "session1", "--profile", "default"]) + mock_session_class.assert_called_once_with(name="session1", profile="default", plan=None, log_level="INFO") + mock_session_instance.run.assert_called_once() + + def test_session_resume_command_with_session_name(mock_session): mock_session_class, mock_session_instance = mock_session runner = CliRunner() runner.invoke(goose_cli, ["session", "resume", "session1", "--profile", "default"]) - mock_session_class.assert_called_once_with(name="session1", profile="default") + mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO") mock_session_instance.run.assert_called_once() @@ -59,7 +67,7 @@ def test_session_resume_command_without_session_name_use_latest_session( second_file_path = mock_session_files_path / "second.jsonl" mock_print.assert_called_once_with(f"Resuming most recent session: second from {second_file_path}") - mock_session_class.assert_called_once_with(name="second", profile="default") + mock_session_class.assert_called_once_with(name="second", profile="default", log_level="INFO") mock_session_instance.run.assert_called_once() @@ -121,5 +129,35 @@ def test_combined_group_commands(mock_session): mock_session_class, mock_session_instance = mock_session runner = CliRunner() runner.invoke(cli, ["session", "resume", "session1", "--profile", "default"]) - mock_session_class.assert_called_once_with(name="session1", profile="default") + mock_session_class.assert_called_once_with(name="session1", profile="default", log_level="INFO") mock_session_instance.run.assert_called_once() + + +def test_version_long_option(): + runner = CliRunner() + result = runner.invoke(cli, ["--version"]) + assert result.exit_code == 0 + assert "version" in result.output.lower() + + +def test_version_short_option(): + runner = CliRunner() + result = runner.invoke(cli, ["-V"]) + assert result.exit_code == 0 + assert "version" in result.output.lower() + + +def test_version_subcommand(): + runner = CliRunner() + result = runner.invoke(cli, ["version"]) + assert result.exit_code == 0 + assert "version" in result.output.lower() + + +def test_goose_no_args_print_help(): + runner = CliRunner() + result = runner.invoke(cli, []) + assert result.exit_code == 0 + assert "Usage:" in result.output + assert "Options:" in result.output + assert "Commands:" in result.output diff --git a/tests/cli/test_session.py b/tests/cli/test_session.py index 79a7c4a2b..f2437462c 100644 --- a/tests/cli/test_session.py +++ b/tests/cli/test_session.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from exchange import Message, ToolUse, ToolResult +from exchange import Exchange, Message, ToolUse, ToolResult from goose.cli.prompt.goose_prompt_session import GoosePromptSession from goose.cli.prompt.user_input import PromptAction, UserInput from goose.cli.session import Session @@ -19,12 +19,14 @@ def mock_specified_session_name(): @pytest.fixture def create_session_with_mock_configs(mock_sessions_path, exchange_factory, profile_factory): - with patch("goose.cli.session.build_exchange", return_value=exchange_factory()), patch( - "goose.cli.session.load_profile", return_value=profile_factory() - ), patch("goose.cli.session.SessionNotifier") as mock_session_notifier, patch( - "goose.cli.session.load_provider", return_value="provider" + with ( + patch("goose.cli.session.create_exchange") as mock_exchange, + patch("goose.cli.session.load_profile", return_value=profile_factory()), + patch("goose.cli.session.SessionNotifier") as mock_session_notifier, + patch("goose.cli.session.load_provider", return_value="provider"), ): mock_session_notifier.return_value = MagicMock() + mock_exchange.return_value = exchange_factory() def create_session(session_attributes: dict = {}): return Session(**session_attributes) @@ -79,59 +81,6 @@ def test_session_removes_tool_use_and_adds_resume_message_if_last_message_is_too ] -def test_save_session_create_session(mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name): - session = create_session_with_mock_configs() - session.exchange.messages.append(Message.assistant("Hello")) - - session.save_session() - session_file = mock_sessions_path / f"{SPECIFIED_SESSION_NAME}.jsonl" - assert session_file.exists() - - saved_messages = session.load_session() - assert len(saved_messages) == 1 - assert saved_messages[0].text == "Hello" - - -def test_save_session_resume_session_new_file( - mock_sessions_path, create_session_with_mock_configs, mock_specified_session_name, create_session_file -): - with patch("goose.cli.session.confirm", return_value=False): - existing_messages = [Message.assistant("existing_message")] - existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" - create_session_file(existing_messages, existing_session_file) - - new_session_file = mock_sessions_path / f"{SPECIFIED_SESSION_NAME}.jsonl" - assert not new_session_file.exists() - - session = create_session_with_mock_configs({"name": SESSION_NAME}) - session.exchange.messages.append(Message.assistant("new_message")) - - session.save_session() - - assert new_session_file.exists() - assert existing_session_file.exists() - - saved_messages = session.load_session() - assert [message.text for message in saved_messages] == ["existing_message", "new_message"] - - -def test_save_session_resume_session_existing_session_file( - mock_sessions_path, create_session_with_mock_configs, create_session_file -): - with patch("goose.cli.session.confirm", return_value=True): - existing_messages = [Message.assistant("existing_message")] - existing_session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" - create_session_file(existing_messages, existing_session_file) - - session = create_session_with_mock_configs({"name": SESSION_NAME}) - session.exchange.messages.append(Message.assistant("new_message")) - - session.save_session() - - saved_messages = session.load_session() - assert [message.text for message in saved_messages] == ["existing_message", "new_message"] - - def test_process_first_message_return_message(create_session_with_mock_configs): session = create_session_with_mock_configs() with patch.object( @@ -161,9 +110,49 @@ def test_process_first_message_return_last_exchange_message(create_session_with_ assert len(session.exchange.messages) == 0 -def test_generate_session_name(create_session_with_mock_configs): +def test_log_log_cost(create_session_with_mock_configs): session = create_session_with_mock_configs() - with patch.object(GoosePromptSession, "get_save_session_name", return_value=SPECIFIED_SESSION_NAME): - session.generate_session_name() + mock_logger = MagicMock() + cost_message = "You have used 100 tokens" + with ( + patch("exchange.Exchange.get_token_usage", return_value={}), + patch("goose.cli.session.get_total_cost_message", return_value=cost_message), + patch("goose.cli.session.get_logger", return_value=mock_logger), + ): + session._log_cost() + mock_logger.info.assert_called_once_with(cost_message) + + +def test_run_should_auto_save_session(create_session_with_mock_configs, mock_sessions_path): + def custom_exchange_generate(self, *args, **kwargs): + message = Message.assistant("Response") + self.add(message) + return message + + user_inputs = [ + UserInput(action=PromptAction.CONTINUE, text="Question1"), + UserInput(action=PromptAction.CONTINUE, text="Question2"), + UserInput(action=PromptAction.EXIT), + ] + + session = create_session_with_mock_configs({"name": SESSION_NAME}) + with ( + patch.object(GoosePromptSession, "get_user_input", side_effect=user_inputs), + patch.object(Exchange, "generate") as mock_generate, + patch("goose.cli.session.save_latest_session") as mock_save_latest_session, + ): + mock_generate.side_effect = lambda *args, **kwargs: custom_exchange_generate(session.exchange, *args, **kwargs) + session.run() + + session_file = mock_sessions_path / f"{SESSION_NAME}.jsonl" + assert session.exchange.generate.call_count == 2 + assert mock_save_latest_session.call_count == 2 + assert mock_save_latest_session.call_args_list[0][0][0] == session_file + assert session_file.exists() + - assert session.name == SPECIFIED_SESSION_NAME +def test_set_generated_session_name(create_session_with_mock_configs, mock_sessions_path): + generated_session_name = "generated_session_name" + with patch("goose.cli.session.droid", return_value=generated_session_name): + session = create_session_with_mock_configs({"name": None}) + assert session.name == generated_session_name diff --git a/tests/test_autocomplete.py b/tests/test_autocomplete.py new file mode 100644 index 000000000..789b5ec23 --- /dev/null +++ b/tests/test_autocomplete.py @@ -0,0 +1,34 @@ +import sys +import unittest.mock as mock + +from goose.utils.autocomplete import SUPPORTED_SHELLS, is_autocomplete_installed, setup_autocomplete + + +def test_supported_shells(): + assert SUPPORTED_SHELLS == ["bash", "zsh", "fish"] + + +def test_install_autocomplete(tmp_path): + file = tmp_path / "test_bash_autocomplete" + assert is_autocomplete_installed(file) is False + + file.write_text("_GOOSE_COMPLETE") + assert is_autocomplete_installed(file) is True + + +@mock.patch("sys.stdout") +def test_setup_bash(mocker: mock.MagicMock): + setup_autocomplete("bash", install=False) + sys.stdout.write.assert_called_with('eval "$(_GOOSE_COMPLETE=bash_source goose)"\n') + + +@mock.patch("sys.stdout") +def test_setup_zsh(mocker: mock.MagicMock): + setup_autocomplete("zsh", install=False) + sys.stdout.write.assert_called_with('eval "$(_GOOSE_COMPLETE=zsh_source goose)"\n') + + +@mock.patch("sys.stdout") +def test_setup_fish(mocker: mock.MagicMock): + setup_autocomplete("fish", install=False) + sys.stdout.write.assert_called_with("_GOOSE_COMPLETE=fish_source goose | source\n") diff --git a/tests/test_cli_main.py b/tests/test_cli_main.py new file mode 100644 index 000000000..9160be1bf --- /dev/null +++ b/tests/test_cli_main.py @@ -0,0 +1,21 @@ +from click.testing import CliRunner +from goose.cli.main import get_current_shell, shell_completions + + +def test_get_current_shell(mocker): + mocker.patch("os.getenv", return_value="/bin/bash") + assert get_current_shell() == "bash" + + +def test_shell_completions_install_invalid_combination(): + runner = CliRunner() + result = runner.invoke(shell_completions, ["--install", "--generate", "bash"]) + assert result.exit_code != 0 + assert "Only one of --install or --generate can be specified" in result.output + + +def test_shell_completions_install_no_option(): + runner = CliRunner() + result = runner.invoke(shell_completions, ["bash"]) + assert result.exit_code != 0 + assert "One of --install or --generate must be specified" in result.output diff --git a/tests/test_linting.py b/tests/test_linting.py index f6e246ff6..cae0d1a73 100644 --- a/tests/test_linting.py +++ b/tests/test_linting.py @@ -1,5 +1,11 @@ from goose.toolkit.lint import lint_toolkits +from goose.toolkit.lint import lint_providers + def test_lint_toolkits(): lint_toolkits() + + +def test_lint_providers(): + lint_providers() diff --git a/tests/test_profile.py b/tests/test_profile.py new file mode 100644 index 000000000..3c022f740 --- /dev/null +++ b/tests/test_profile.py @@ -0,0 +1,12 @@ +from goose.profile import ToolkitSpec + + +def test_profile_info(profile_factory): + profile = profile_factory( + { + "provider": "provider", + "processor": "processor", + "toolkits": [ToolkitSpec("developer"), ToolkitSpec("github")], + } + ) + assert profile.profile_info() == "provider:provider, processor:processor toolkits: developer, github" diff --git a/tests/toolkit/test_developer.py b/tests/toolkit/test_developer.py index e36c498c0..ff67b8fb7 100644 --- a/tests/toolkit/test_developer.py +++ b/tests/toolkit/test_developer.py @@ -55,7 +55,50 @@ def test_system_prompt_with_goosehints(temp_dir, developer_toolkit): assert system_prompt.endswith(expected_end) -def test_update_plan(developer_toolkit): +def test_system_prompt_with_goosehints_only_from_home_dir(temp_dir, developer_toolkit): + readme_file_home = Path.home() / ".config/goose/README.md" + readme_file_home.parent.mkdir(parents=True, exist_ok=True) + readme_file_home.write_text("This is from the README.md file in home.") + + home_hints_file = Path.home() / ".config/goose/.goosehints" + home_jinja_template_content = "Hints from home:\n\n{% include 'README.md' %}\nEnd." + home_hints_file.write_text(home_jinja_template_content) + + try: + with change_dir(temp_dir): + system_prompt = developer_toolkit.system() + expected_content_home = "Hints from home:\n\nThis is from the README.md file in home.\nEnd." + expected_end = f"Hints:\n{expected_content_home}" + assert system_prompt.endswith(expected_end) + finally: + home_hints_file.unlink() + readme_file_home.unlink() + + readme_file = temp_dir / "README.md" + readme_file.write_text("This is from the README.md file.") + + hints_file = temp_dir / ".goosehints" + jinja_template_content = "Hints from local:\n\n{% include 'README.md' %}\nEnd." + hints_file.write_text(jinja_template_content) + + home_hints_file = Path.home() / ".config/goose/.goosehints" + home_jinja_template_content = "Hints from home:\n\n{% include 'README.md' %}\nEnd." + home_hints_file.write_text(home_jinja_template_content) + + home_readme_file = Path.home() / ".config/goose/README.md" + home_readme_file.write_text("This is from the README.md file in home.") + + try: + with change_dir(temp_dir): + system_prompt = developer_toolkit.system() + expected_content_local = "Hints from local:\n\nThis is from the README.md file.\nEnd." + expected_content_home = "Hints from home:\n\nThis is from the README.md file in home.\nEnd." + expected_end = f"Hints:\n{expected_content_local}\n{expected_content_home}" + assert system_prompt.endswith(expected_end) + finally: + home_hints_file.unlink() + home_readme_file.unlink() + tasks = [ {"description": "Task 1", "status": "planned"}, {"description": "Task 2", "status": "complete"}, diff --git a/tests/toolkit/test_utils.py b/tests/toolkit/test_utils.py new file mode 100644 index 000000000..b5b45ac90 --- /dev/null +++ b/tests/toolkit/test_utils.py @@ -0,0 +1,65 @@ +from goose.toolkit.utils import parse_plan + + +def test_parse_plan_simple(): + plan_str = ( + "Here is python repo\n" + "-use uv\n" + "-do not use poetry\n\n" + "Now you should:\n\n" + "-Open a file\n" + "-Run a test" + ) + expected_result = { + "kickoff_message": "Here is python repo\n-use uv\n-do not use poetry\n\nNow you should:", + "tasks": ["Open a file", "Run a test"], + } + assert expected_result == parse_plan(plan_str) + + +def test_parse_plan_multiple_groups(): + plan_str = ( + "Here is python repo\n" + "-use uv\n" + "-do not use poetry\n\n" + "Now you should:\n\n" + "-Open a file\n" + "-Run a test\n\n" + "Now actually follow the steps:\n" + "-Step1\n" + "-Step2" + ) + expected_result = { + "kickoff_message": ( + "Here is python repo\n" + "-use uv\n" + "-do not use poetry\n\n" + "Now you should:\n\n" + "-Open a file\n" + "-Run a test\n\n" + "Now actually follow the steps:" + ), + "tasks": ["Step1", "Step2"], + } + assert expected_result == parse_plan(plan_str) + + +def test_parse_plan_empty_tasks(): + plan_str = "Here is python repo" + expected_result = {"kickoff_message": "Here is python repo", "tasks": []} + assert expected_result == parse_plan(plan_str) + + +def test_parse_plan_empty_kickoff_message(): + plan_str = "-task1\n-task2" + expected_result = {"kickoff_message": "", "tasks": ["task1", "task2"]} + assert expected_result == parse_plan(plan_str) + + +def test_parse_plan_with_numbers(): + plan_str = "Here is python repo\n" "Now you should:\n\n" "-1 Open a file\n" "-2 Run a test" + expected_result = { + "kickoff_message": "Here is python repo\nNow you should:", + "tasks": ["1 Open a file", "2 Run a test"], + } + assert expected_result == parse_plan(plan_str) diff --git a/tests/utils/test_cost_calculator.py b/tests/utils/test_cost_calculator.py new file mode 100644 index 000000000..edde1daa2 --- /dev/null +++ b/tests/utils/test_cost_calculator.py @@ -0,0 +1,47 @@ +from unittest.mock import patch +import pytest +from goose.utils._cost_calculator import _calculate_cost, get_total_cost_message +from exchange.providers.base import Usage + + +@pytest.fixture +def mock_prices(): + prices = {"gpt-4o": (5.00, 15.00), "gpt-4o-mini": (0.150, 0.600)} + with patch("goose.utils._cost_calculator.PRICES", prices) as mock_prices: + yield mock_prices + + +def test_calculate_cost(mock_prices): + cost = _calculate_cost("gpt-4o", Usage(input_tokens=10000, output_tokens=600, total_tokens=10600)) + assert cost == 0.059 + + +def test_get_total_cost_message(mock_prices): + message = get_total_cost_message( + { + "gpt-4o": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600), + "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000), + } + ) + expected_message = ( + "Cost for model gpt-4o Usage(input_tokens=10000, output_tokens=600, total_tokens=10600): $0.06\n" + + "Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000)" + + ": $2.85\nTotal cost: $2.91" + ) + assert message == expected_message + + +def test_get_total_cost_message_with_non_available_pricing(mock_prices): + message = get_total_cost_message( + { + "non_pricing_model": Usage(input_tokens=10000, output_tokens=600, total_tokens=10600), + "gpt-4o-mini": Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000), + } + ) + expected_message = ( + "Cost for model non_pricing_model Usage(input_tokens=10000, output_tokens=600, total_tokens=10600): " + + "Not available\n" + + "Cost for model gpt-4o-mini Usage(input_tokens=3000000, output_tokens=4000000, total_tokens=7000000)" + + ": $2.85\nTotal cost: $2.85" + ) + assert message == expected_message diff --git a/tests/utils/test_create_exchange.py b/tests/utils/test_create_exchange.py new file mode 100644 index 000000000..62fdde5f2 --- /dev/null +++ b/tests/utils/test_create_exchange.py @@ -0,0 +1,151 @@ +import os +from unittest.mock import MagicMock, patch + +from exchange.exchange import Exchange +from exchange.invalid_choice_error import InvalidChoiceError +from exchange.providers.base import MissingProviderEnvVariableError +import pytest + +from goose.notifier import Notifier +from goose.profile import Profile +from goose.utils._create_exchange import create_exchange + +TEST_PROFILE = MagicMock(spec=Profile) +TEST_EXCHANGE = MagicMock(spec=Exchange) +TEST_NOTIFIER = MagicMock(spec=Notifier) + + +@pytest.fixture +def mock_print(): + with patch("goose.utils._create_exchange.print") as mock_print: + yield mock_print + + +@pytest.fixture +def mock_prompt(): + with patch("goose.utils._create_exchange.prompt") as mock_prompt: + yield mock_prompt + + +@pytest.fixture +def mock_confirm(): + with patch("goose.utils._create_exchange.confirm") as mock_confirm: + yield mock_confirm + + +@pytest.fixture +def mock_sys_exit(): + with patch("sys.exit") as mock_exit: + yield mock_exit + + +@pytest.fixture +def mock_keyring_get_password(): + with patch("keyring.get_password") as mock_get_password: + yield mock_get_password + + +@pytest.fixture +def mock_keyring_set_password(): + with patch("keyring.set_password") as mock_set_password: + yield mock_set_password + + +def test_create_exchange_success(mock_print): + with patch("goose.utils._create_exchange.build_exchange", return_value=TEST_EXCHANGE): + assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + +def test_create_exchange_fail_with_invalid_choice_error(mock_print, mock_sys_exit): + expected_error = InvalidChoiceError( + attribute_name="provider", attribute_value="wrong_provider", available_values=["openai"] + ) + with patch("goose.utils._create_exchange.build_exchange", side_effect=expected_error): + create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) + + assert "Unknown provider: wrong_provider. Available providers: openai" in mock_print.call_args_list[0][0][0] + mock_sys_exit.assert_called_once_with(1) + + +class TestWhenProviderEnvVarNotFound: + API_KEY_ENV_VAR = "OPENAI_API_KEY" + API_KEY_ENV_VALUE = "api_key_value" + PROVIDER_NAME = "openai" + SERVICE_NAME = "goose" + EXPECTED_ERROR = MissingProviderEnvVariableError(env_variable=API_KEY_ENV_VAR, provider=PROVIDER_NAME) + + def test_create_exchange_get_api_key_from_keychain( + self, mock_print, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = self.API_KEY_ENV_VALUE + + assert create_exchange(profile=TEST_PROFILE, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_get_password.assert_called_once_with(self.SERVICE_NAME, self.API_KEY_ENV_VAR) + mock_print.assert_called_once_with( + f"Using {self.API_KEY_ENV_VAR} value for {self.PROVIDER_NAME} from your keychain" + ) + mock_sys_exit.assert_not_called() + mock_keyring_set_password.assert_not_called() + + def test_create_exchange_ask_api_key_and_user_set_in_keychain( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password, mock_print + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = self.API_KEY_ENV_VALUE + mock_confirm.return_value = True + + assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_set_password.assert_called_once_with( + self.SERVICE_NAME, self.API_KEY_ENV_VAR, self.API_KEY_ENV_VALUE + ) + mock_confirm.assert_called_once_with( + f"Would you like to save the {self.API_KEY_ENV_VAR} value to your keychain?" + ) + mock_print.assert_called_once_with( + f"Saved {self.API_KEY_ENV_VAR} to your key_chain. " + + f"service_name: goose, user_name: {self.API_KEY_ENV_VAR}" + ) + mock_sys_exit.assert_not_called() + + def test_create_exchange_ask_api_key_and_user_not_set_in_keychain( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_keyring_set_password + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=[self.EXPECTED_ERROR, TEST_EXCHANGE]): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = self.API_KEY_ENV_VALUE + mock_confirm.return_value = False + + assert create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) == TEST_EXCHANGE + + assert os.environ[self.API_KEY_ENV_VAR] == self.API_KEY_ENV_VALUE + mock_keyring_set_password.assert_not_called() + mock_sys_exit.assert_not_called() + + def test_create_exchange_fails_when_user_not_provide_api_key( + self, mock_prompt, mock_confirm, mock_sys_exit, mock_keyring_get_password, mock_print + ): + self._clean_env() + with patch("goose.utils._create_exchange.build_exchange", side_effect=self.EXPECTED_ERROR): + mock_keyring_get_password.return_value = None + mock_prompt.return_value = None + mock_confirm.return_value = False + + create_exchange(profile=TEST_NOTIFIER, notifier=TEST_NOTIFIER) + + assert ( + "Please set the required environment variable to continue." + in mock_print.call_args_list[0][0][0].renderable + ) + mock_sys_exit.assert_called_once_with(1) + + def _clean_env(self): + os.environ.pop(self.API_KEY_ENV_VAR, None) diff --git a/tests/utils/test_session_file.py b/tests/utils/test_session_file.py index d922bd81d..6a2a64981 100644 --- a/tests/utils/test_session_file.py +++ b/tests/utils/test_session_file.py @@ -1,8 +1,16 @@ +import os from pathlib import Path import pytest from exchange import Message -from goose.utils.session_file import list_sorted_session_files, read_from_file, session_file_exists, write_to_file +from goose.utils.session_file import ( + list_sorted_session_files, + read_from_file, + read_or_create_file, + save_latest_session, + session_file_exists, + write_to_file, +) @pytest.fixture @@ -32,6 +40,23 @@ def test_read_from_file_non_jsonl_file(file_path): read_from_file(file_path) +def test_read_or_create_file_when_file_not_exist(tmp_path): + file_path = tmp_path / "no_existing.json" + + assert read_or_create_file(file_path) == [] + assert os.path.exists(file_path) + + +def test_read_or_create_file_when_file_exists(file_path): + messages = [ + Message.user("prompt1"), + ] + write_to_file(file_path, messages) + + assert file_path.exists() + assert read_from_file(file_path) == messages + + def test_list_sorted_session_files(tmp_path): session_files_directory = tmp_path / "session_files_dir" session_files_directory.mkdir() @@ -71,6 +96,21 @@ def test_session_file_exists_return_true_when_session_file_exists(tmp_path): assert session_file_exists(session_files_directory) +def test_save_latest_session(file_path, tmp_path): + messages = [ + Message.user("prompt1"), + Message.user("prompt2"), + ] + write_to_file(file_path, messages) + + messages.append(Message.user("prompt3")) + save_latest_session(file_path, messages) + + messages_in_file = read_from_file(file_path) + assert messages_in_file == messages + assert len(messages_in_file) == 3 + + def create_session_file(file_path, file_name) -> Path: file = file_path / f"{file_name}.jsonl" file.touch()