Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Chat model with tools docs and response format #202

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ messages = [

chat_completions = client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
)
```

Expand Down Expand Up @@ -207,7 +207,7 @@ client = AsyncAI21Client(
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
)

print(response)
Expand All @@ -227,8 +227,9 @@ A more detailed example can be found [here](examples/studio/chat/chat_completion
### Supported Models:

- j2-light
- j2-mid
- j2-ultra
- [j2-ultra](#Chat)
- [j2-mid](#Completion)
- [jamba-instruct](#Chat-Completion)

you can read more about the models [here](https://docs.ai21.com/reference/j2-complete-api-ref#jurassic-2-models).

Expand Down Expand Up @@ -270,6 +271,36 @@ completion_response = client.completion.create(
)
```

### Chat Completion

```python
from ai21 import AI21Client
from ai21.models.chat import ChatMessage

system = "You're a support engineer in a SaaS company"
messages = [
ChatMessage(content=system, role="system"),
ChatMessage(content="Hello, I need help with a signup process.", role="user"),
ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"),
ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"),
]

client = AI21Client()

response = client.chat.completions.create(
messages=messages,
model="jamba-instruct",
max_tokens=100,
temperature=0.7,
top_p=1.0,
stop=["\n"],
)

print(response)
```

Note that jamba-instruct supports async streaming as well.

</details>

For a more detailed example, see the completion [examples](examples/studio/completion.py).
Expand All @@ -290,7 +321,7 @@ client = AI21Client()

response = client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-instruct",
stream=True,
)
for chunk in response:
Expand All @@ -314,7 +345,7 @@ client = AsyncAI21Client()
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
stream=True,
)
async for chunk in response:
Expand Down Expand Up @@ -700,7 +731,7 @@ messages = [
]

response = client.chat.completions.create(
model="jamba-instruct",
model="jamba-1.5-mini",
messages=messages,
)
```
Expand Down
17 changes: 16 additions & 1 deletion ai21/clients/studio/resources/chat/async_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.async_stream import AsyncStream
from ai21.types import NotGiven, NOT_GIVEN

Expand All @@ -23,7 +26,10 @@ async def create(
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
stream: Optional[False] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
pass
Expand All @@ -39,6 +45,9 @@ async def create(
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> AsyncStream[ChatCompletionChunk]:
pass
Expand All @@ -53,6 +62,9 @@ async def create(
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse | AsyncStream[ChatCompletionChunk]:
if any(isinstance(item, J2ChatMessage) for item in messages):
Expand All @@ -70,6 +82,9 @@ async def create(
top_p=top_p,
n=n,
stream=stream or False,
tools=tools,
response_format=response_format,
documents=documents,
**kwargs,
)

Expand Down
13 changes: 13 additions & 0 deletions ai21/clients/studio/resources/chat/base_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from typing import List, Optional, Union, Any, Dict, Literal

from ai21.models.chat import ChatMessage
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.types import NotGiven
from ai21.utils.typing import remove_not_given
from ai21.models._pydantic_compatibility import _to_dict
Expand Down Expand Up @@ -40,6 +43,9 @@ def _create_body(
stop: Optional[Union[str, List[str]]] | NotGiven,
n: Optional[int] | NotGiven,
stream: Literal[False] | Literal[True] | NotGiven,
tools: List[ToolDefinition] | NotGiven,
response_format: ResponseFormat | NotGiven,
documents: List[DocumentSchema] | NotGiven,
**kwargs: Any,
) -> Dict[str, Any]:
return remove_not_given(
Expand All @@ -52,6 +58,13 @@ def _create_body(
"stop": stop,
"n": n,
"stream": stream,
"tools": tools,
"response_format": (
_to_dict(response_format) if not isinstance(response_format, NotGiven) else response_format
),
"documents": (
[_to_dict(document) for document in documents] if not isinstance(documents, NotGiven) else documents
),
**kwargs,
}
)
15 changes: 15 additions & 0 deletions ai21/clients/studio/resources/chat/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.stream import Stream
from ai21.types import NotGiven, NOT_GIVEN

Expand All @@ -24,6 +27,9 @@ def create(
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NOT_GIVEN = NOT_GIVEN,
response_format: ResponseFormat | NOT_GIVEN = NOT_GIVEN,
documents: List[DocumentSchema] | NOT_GIVEN = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
pass
Expand All @@ -39,6 +45,9 @@ def create(
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NOT_GIVEN = NOT_GIVEN,
response_format: ResponseFormat | NOT_GIVEN = NOT_GIVEN,
documents: List[DocumentSchema] | NOT_GIVEN = NOT_GIVEN,
**kwargs: Any,
) -> Stream[ChatCompletionChunk]:
pass
Expand All @@ -53,6 +62,9 @@ def create(
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse | Stream[ChatCompletionChunk]:
if any(isinstance(item, J2ChatMessage) for item in messages):
Expand All @@ -70,6 +82,9 @@ def create(
top_p=top_p,
n=n,
stream=stream or False,
tools=tools,
response_format=response_format,
documents=documents,
**kwargs,
)

Expand Down
3 changes: 2 additions & 1 deletion ai21/http_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def execute_http_request(

if response.status_code != httpx.codes.OK:
_logger.error(
f"Calling {method} {self._base_url} failed with a non-200 response code: {response.status_code}"
f"Calling {method} {self._base_url} failed with a non-200 "
f"response code: {response.status_code} headers: {response.headers}"
)
handle_non_success_response(response.status_code, response.text)

Expand Down
21 changes: 20 additions & 1 deletion ai21/models/chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

from .chat_completion_response import ChatCompletionResponse
from .chat_completion_response import ChatCompletionResponseChoice
from .chat_message import ChatMessage
from .chat_message import ChatMessage, AssistantMessage, ToolMessage, UserMessage, SystemMessage
from .document_schema import DocumentSchema
from .function_tool_definition import FunctionToolDefinition
from .response_format import ResponseFormat
from .role_type import RoleType as RoleType
from .chat_completion_chunk import ChatCompletionChunk, ChoicesChunk, ChoiceDelta
from .tool_call import ToolCall
from .tool_defintions import ToolDefinition
from .tool_function import ToolFunction

__all__ = [
"ChatCompletionResponse",
Expand All @@ -14,4 +20,17 @@
"ChatCompletionChunk",
"ChoicesChunk",
"ChoiceDelta",
"AssistantMessage",
"ToolMessage",
"UserMessage",
"SystemMessage",
"DocumentSchema",
"FunctionToolDefinition",
"ResponseFormat",
"ToolCall",
"ToolDefinition",
"ToolFunction",
"ToolParameters",
]

from .tool_parameters import ToolParameters
4 changes: 2 additions & 2 deletions ai21/models/chat/chat_completion_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.logprobs import Logprobs
from ai21.models.usage_info import UsageInfo
from .chat_message import ChatMessage
from .chat_message import AssistantMessage


class ChatCompletionResponseChoice(AI21BaseModel):
index: int
message: ChatMessage
message: AssistantMessage
logprobs: Optional[Logprobs] = None
finish_reason: Optional[str] = None

Expand Down
22 changes: 21 additions & 1 deletion ai21/models/chat/chat_message.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
from __future__ import annotations
from typing import Literal, List, Optional

from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.chat.tool_call import ToolCall


class ChatMessage(AI21BaseModel):
role: str
content: str


class AssistantMessage(ChatMessage):
role: Literal["assistant"] = "assistant"
tool_calls: Optional[List[ToolCall]] = None
content: Optional[str] = None


class ToolMessage(ChatMessage):
role: Literal["tool"] = "tool"
tool_call_id: str


class UserMessage(ChatMessage):
role: Literal["user"] = "user"


class SystemMessage(ChatMessage):
role: Literal["system"] = "system"
9 changes: 9 additions & 0 deletions ai21/models/chat/document_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Optional, Dict

from ai21.models.ai21_base_model import AI21BaseModel


class DocumentSchema(AI21BaseModel):
content: str
id: Optional[str] = None
metadata: Optional[Dict[str, str]] = None
9 changes: 9 additions & 0 deletions ai21/models/chat/function_tool_definition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing_extensions import TypedDict, Required

from ai21.models.chat.tool_parameters import ToolParameters


class FunctionToolDefinition(TypedDict, total=False):
name: Required[str]
description: str
parameters: ToolParameters
7 changes: 7 additions & 0 deletions ai21/models/chat/response_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Literal

from ai21.models.ai21_base_model import AI21BaseModel


class ResponseFormat(AI21BaseModel):
type: Literal["text", "json_object"]
2 changes: 2 additions & 0 deletions ai21/models/chat/role_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
class RoleType(str, Enum):
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
SYSTEM = "system"
10 changes: 10 additions & 0 deletions ai21/models/chat/tool_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Literal

from ai21.models.ai21_base_model import AI21BaseModel
from ai21.models.chat.tool_function import ToolFunction


class ToolCall(AI21BaseModel):
id: str
function: ToolFunction
type: Literal["function"] = "function"
8 changes: 8 additions & 0 deletions ai21/models/chat/tool_defintions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing_extensions import Literal, TypedDict, Required

from ai21.models.chat import FunctionToolDefinition


class ToolDefinition(TypedDict, total=False):
type: Required[Literal["function"]]
function: Required[FunctionToolDefinition]
6 changes: 6 additions & 0 deletions ai21/models/chat/tool_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ai21.models.ai21_base_model import AI21BaseModel


class ToolFunction(AI21BaseModel):
name: str
arguments: str
7 changes: 7 additions & 0 deletions ai21/models/chat/tool_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing_extensions import Literal, Any, Dict, List, TypedDict, Required


class ToolParameters(TypedDict, total=False):
type: Literal["object"]
properties: Required[Dict[str, Any]]
required: List[str]
2 changes: 1 addition & 1 deletion examples/studio/chat/async_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
async def main():
response = await client.chat.completions.create(
messages=messages,
model="jamba-instruct-preview",
model="jamba-1.5-mini",
max_tokens=100,
temperature=0.7,
top_p=1.0,
Expand Down
Loading
Loading