Add ZhipuAI provider #89

Open · wants to merge 2 commits into main
4 changes: 2 additions & 2 deletions README.md
@@ -7,7 +7,7 @@ Simple, unified interface to multiple Generative AI providers.
`aisuite` makes it easy for developers to use multiple LLMs through a standardized interface. Using an interface similar to OpenAI's, `aisuite` lets you interact with the most popular LLMs and compare their results. It is a thin wrapper around Python client libraries, and allows creators to seamlessly swap out and test responses from different LLM providers without changing their code. Today, the library is primarily focused on chat completions. We will expand it to cover more use cases in the near future.

Currently supported providers are -
OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace and Ollama.
OpenAI, Anthropic, Azure, Google, AWS, Groq, Mistral, HuggingFace, Ollama, and ZhipuAI.
To maximize stability, `aisuite` uses either the HTTP endpoint or the SDK for making calls to the provider.

## Installation
@@ -107,4 +107,4 @@ We follow a convention-based approach for loading providers, which relies on str
```
in providers/openai_provider.py

This convention simplifies the addition of new providers and ensures consistency across provider implementations.
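The strict naming convention mentioned above can be sketched as a small resolver. This is an illustrative sketch only, assuming the `aisuite.providers.<key>_provider` module / `<Key>Provider` class pattern this PR follows; `resolve_provider_names` and `load_provider_class` are hypothetical helpers, not part of aisuite:

```python
import importlib


def resolve_provider_names(provider_key: str) -> tuple[str, str]:
    # Derive module path and class name from a provider key,
    # e.g. "zhipuai" -> ("aisuite.providers.zhipuai_provider", "ZhipuaiProvider").
    return (
        f"aisuite.providers.{provider_key}_provider",
        f"{provider_key.capitalize()}Provider",
    )


def load_provider_class(provider_key: str):
    # Import the provider module and return its class (requires aisuite installed).
    module_name, class_name = resolve_provider_names(provider_key)
    return getattr(importlib.import_module(module_name), class_name)
```

This is why the new file below must be named `zhipuai_provider.py` and define a class named `ZhipuaiProvider`.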
26 changes: 26 additions & 0 deletions aisuite/providers/zhipuai_provider.py
@@ -0,0 +1,26 @@
import os

from zhipuai import ZhipuAI
from aisuite.provider import Provider


class ZhipuaiProvider(Provider):
    def __init__(self, **config):
        """
        Initialize the ZhipuAI provider with the given configuration.
        Pass the entire configuration dictionary to the ZhipuAI client constructor.
        """
        # Ensure API key is provided either in config or via environment variable
        config.setdefault("api_key", os.getenv("ZHIPUAI_API_KEY"))
        if not config["api_key"]:
            raise ValueError(
                "API key is missing. Please provide it in the config or set the ZHIPUAI_API_KEY environment variable."
            )
        self.client = ZhipuAI(**config)

    def chat_completions_create(self, model, messages, **kwargs):
        return self.client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs  # Pass any additional arguments to the ZhipuAI API
        )
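Because `chat_completions_create` forwards `**kwargs` verbatim, provider-specific parameters such as `temperature` or `top_p` reach the ZhipuAI client untouched. A minimal sketch of that pass-through behavior, using a mock in place of the real client so no network call is made:

```python
from unittest.mock import MagicMock

# Stand-in for the ZhipuAI client; no network calls are made.
client = MagicMock()


def chat_completions_create(model, messages, **kwargs):
    # Mirrors the provider method: everything extra is forwarded as-is.
    return client.chat.completions.create(model=model, messages=messages, **kwargs)


chat_completions_create(
    "glm-4",
    [{"role": "user", "content": "Hello"}],
    temperature=0.7,
    top_p=0.9,
)
forwarded = client.chat.completions.create.call_args.kwargs
```

Inspecting `forwarded` confirms `temperature` and `top_p` arrive alongside `model` and `messages`.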
44 changes: 44 additions & 0 deletions guides/zhipuai.md
@@ -0,0 +1,44 @@
# ZhipuAI

To use ZhipuAI with `aisuite`, you'll need a [ZhipuAI account](https://open.zhipuai.cn/). After logging in, obtain your API key from your account settings. Once you have your key, add it to your environment as follows:

```shell
export ZHIPUAI_API_KEY="your-zhipuai-api-key"
```

## Create a Chat Completion

Install the `zhipuai` Python client:

Example with pip:
```shell
pip install zhipuai
```

Example with poetry:
```shell
poetry add zhipuai
```
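Alternatively, since this PR also registers a `zhipuai` extra in `pyproject.toml`, installing via the extra should pull the dependency in as well (assuming the extra ships in a released version of `aisuite`):

```shell
pip install "aisuite[zhipuai]"
```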

In your code:
```python
import aisuite as ai
client = ai.Client()

provider = "zhipuai"
model_id = "glm-4"

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What's the weather like in Beijing?"},
]

response = client.chat.completions.create(
model=f"{provider}:{model_id}",
messages=messages,
)

print(response.choices[0].message.content)
```

Happy coding! If you'd like to contribute, please read our [Contributing Guide](../CONTRIBUTING.md).
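The `f"{provider}:{model_id}"` string above is how aisuite routes the request to the right provider; the split can be sketched as follows (illustrative only, not aisuite's actual parsing code):

```python
def split_model_string(model: str) -> tuple[str, str]:
    # "zhipuai:glm-4" -> ("zhipuai", "glm-4")
    provider, _, model_id = model.partition(":")
    return provider, model_id
```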
5,912 changes: 3,069 additions & 2,843 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -13,6 +13,7 @@ vertexai = { version = "^1.63.0", optional = true }
groq = { version = "^0.9.0", optional = true }
mistralai = { version = "^1.0.3", optional = true }
openai = { version = "^1.35.8", optional = true }
zhipuai = { version = "^2.1.5", optional = true }

# Optional dependencies for different providers
[tool.poetry.extras]
@@ -25,7 +26,8 @@ huggingface = []
mistral = ["mistralai"]
ollama = []
openai = ["openai"]
all = ["anthropic", "aws", "google", "groq", "mistral", "openai"] # To install all providers
zhipuai = ["zhipuai"]
all = ["anthropic", "aws", "google", "groq", "mistral", "openai", "zhipuai"] # To install all providers

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
@@ -44,6 +46,7 @@ chromadb = "^0.5.4"
sentence-transformers = "^3.0.1"
datasets = "^2.20.0"
vertexai = "^1.63.0"
zhipuai = "^2.1.5"

[build-system]
requires = ["poetry-core"]
14 changes: 14 additions & 0 deletions tests/client/test_client.py
@@ -4,6 +4,7 @@


class TestClient(unittest.TestCase):
    @patch("aisuite.providers.zhipuai_provider.ZhipuaiProvider.chat_completions_create")
    @patch("aisuite.providers.mistral_provider.MistralProvider.chat_completions_create")
    @patch("aisuite.providers.groq_provider.GroqProvider.chat_completions_create")
    @patch("aisuite.providers.openai_provider.OpenaiProvider.chat_completions_create")
@@ -26,6 +27,7 @@ def test_client_chat_completions(
        mock_openai,
        mock_groq,
        mock_mistral,
        mock_zhipuai,
    ):
        # Mock responses from providers
        mock_openai.return_value = "OpenAI Response"
@@ -36,6 +38,7 @@ def test_client_chat_completions(
        mock_mistral.return_value = "Mistral Response"
        mock_google.return_value = "Google Response"
        mock_fireworks.return_value = "Fireworks Response"
        mock_zhipuai.return_value = "ZhipuAI Response"

        # Provider configurations
        provider_configs = {
@@ -64,6 +67,9 @@ def test_client_chat_completions(
            "fireworks": {
                "api_key": "fireworks-api-key",
            },
            "zhipuai": {
                "api_key": "zhipuai-api-key",
            },
        }

        # Initialize the client
@@ -134,6 +140,14 @@ def test_client_chat_completions(
        self.assertEqual(fireworks_response, "Fireworks Response")
        mock_fireworks.assert_called_once()

        # Test ZhipuAI model
        zhipuai_model = "zhipuai" + ":" + "glm-4"
        zhipuai_response = client.chat.completions.create(
            zhipuai_model, messages=messages
        )
        self.assertEqual(zhipuai_response, "ZhipuAI Response")
        mock_zhipuai.assert_called_once()

        # Test that new instances of Completion are not created each time we make an inference call.
        compl_instance = client.chat.completions
        next_compl_instance = client.chat.completions
46 changes: 46 additions & 0 deletions tests/providers/test_zhipuai_provider.py
@@ -0,0 +1,46 @@
from unittest.mock import MagicMock, patch

import pytest

from aisuite.providers.zhipuai_provider import ZhipuaiProvider


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
    """Fixture to set environment variables for tests."""
    monkeypatch.setenv("ZHIPUAI_API_KEY", "test-api-key")


def test_zhipuai_provider():
    """High-level test that the provider is initialized and chat completions are requested successfully."""

    user_greeting = "Hello!"
    message_history = [{"role": "user", "content": user_greeting}]
    selected_model = "glm-4"
    chosen_temperature = 0.75
    response_text_content = "mocked-text-response-from-model"

    provider = ZhipuaiProvider()
    mock_response = MagicMock()
    mock_response.choices = [MagicMock()]
    mock_response.choices[0].message = MagicMock()
    mock_response.choices[0].message.content = response_text_content

    with patch.object(
        provider.client.chat.completions,
        "create",
        return_value=mock_response,
    ) as mock_create:
        response = provider.chat_completions_create(
            messages=message_history,
            model=selected_model,
            temperature=chosen_temperature,
        )

    mock_create.assert_called_with(
        messages=message_history,
        model=selected_model,
        temperature=chosen_temperature,
    )

    assert response.choices[0].message.content == response_text_content