diff --git a/README.md b/README.md
index db9106dc..e0daa085 100644
--- a/README.md
+++ b/README.md
@@ -177,7 +177,7 @@ messages = [
 
 chat_completions = client.chat.completions.create(
     messages=messages,
-    model="jamba-instruct-preview",
+    model="jamba-1.5-mini",
 )
 ```
 
@@ -207,7 +207,7 @@ client = AsyncAI21Client(
 async def main():
     response = await client.chat.completions.create(
         messages=messages,
-        model="jamba-instruct-preview",
+        model="jamba-1.5-mini",
     )
     print(response)
 
@@ -227,8 +227,9 @@ A more detailed example can be found [here](examples/studio/chat/chat_completion
 ### Supported Models:
 
 - j2-light
-- j2-mid
-- j2-ultra
+- [j2-ultra](#Chat)
+- [j2-mid](#Completion)
+- [jamba-instruct](#Chat-Completion)
 
 You can read more about the models [here](https://docs.ai21.com/reference/j2-complete-api-ref#jurassic-2-models).
 
@@ -270,6 +271,36 @@ completion_response = client.completion.create(
 )
 ```
 
+### Chat Completion
+
+```python
+from ai21 import AI21Client
+from ai21.models.chat import ChatMessage
+
+system = "You're a support engineer in a SaaS company"
+messages = [
+    ChatMessage(content=system, role="system"),
+    ChatMessage(content="Hello, I need help with a signup process.", role="user"),
+    ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"),
+    ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"),
+]
+
+client = AI21Client()
+
+response = client.chat.completions.create(
+    messages=messages,
+    model="jamba-instruct",
+    max_tokens=100,
+    temperature=0.7,
+    top_p=1.0,
+    stop=["\n"],
+)
+
+print(response)
+```
+
+Note that jamba-instruct supports async and streaming as well.
+
 For a more detailed example, see the completion [examples](examples/studio/completion.py).
@@ -290,7 +321,7 @@ client = AI21Client()
 
 response = client.chat.completions.create(
     messages=messages,
-    model="jamba-instruct-preview",
+    model="jamba-instruct",
     stream=True,
 )
 for chunk in response:
@@ -314,7 +345,7 @@ client = AsyncAI21Client()
 async def main():
     response = await client.chat.completions.create(
         messages=messages,
-        model="jamba-instruct-preview",
+        model="jamba-1.5-mini",
         stream=True,
     )
     async for chunk in response:
@@ -700,7 +731,7 @@ messages = [
 ]
 
 response = client.chat.completions.create(
-    model="jamba-instruct",
+    model="jamba-1.5-mini",
     messages=messages,
 )
 ```
diff --git a/ai21/clients/studio/resources/chat/async_chat_completions.py b/ai21/clients/studio/resources/chat/async_chat_completions.py
index e2881fda..8a746f7f 100644
--- a/ai21/clients/studio/resources/chat/async_chat_completions.py
+++ b/ai21/clients/studio/resources/chat/async_chat_completions.py
@@ -6,6 +6,9 @@
 from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
+from ai21.models.chat.document_schema import DocumentSchema
+from ai21.models.chat.response_format import ResponseFormat
+from ai21.models.chat.tool_defintions import ToolDefinition
 from ai21.stream.async_stream import AsyncStream
 from ai21.types import NotGiven, NOT_GIVEN
@@ -23,7 +26,10 @@ async def create(
         top_p: float | NotGiven = NOT_GIVEN,
         stop: str | List[str] | NotGiven = NOT_GIVEN,
         n: int | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
+        tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
+        response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+        documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
         **kwargs: Any,
     ) -> ChatCompletionResponse:
         pass
@@ -39,6 +45,9 @@ async def create(
         top_p: float | NotGiven = NOT_GIVEN,
         stop: str | List[str] | NotGiven = NOT_GIVEN,
         n: int | NotGiven = NOT_GIVEN,
+        tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
+        response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+        documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
         **kwargs: Any,
     ) -> AsyncStream[ChatCompletionChunk]:
         pass
@@ -53,6 +62,9 @@ async def create(
         stop: str | List[str] | NotGiven = NOT_GIVEN,
         n: int | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+        tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
+        response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+        documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
         **kwargs: Any,
     ) -> ChatCompletionResponse | AsyncStream[ChatCompletionChunk]:
         if any(isinstance(item, J2ChatMessage) for item in messages):
@@ -70,6 +82,9 @@
             top_p=top_p,
             n=n,
             stream=stream or False,
+            tools=tools,
+            response_format=response_format,
+            documents=documents,
             **kwargs,
         )
diff --git a/ai21/clients/studio/resources/chat/base_chat_completions.py b/ai21/clients/studio/resources/chat/base_chat_completions.py
index 0bbfce35..f66528dd 100644
--- a/ai21/clients/studio/resources/chat/base_chat_completions.py
+++ b/ai21/clients/studio/resources/chat/base_chat_completions.py
@@ -5,6 +5,9 @@
 from typing import List, Optional, Union, Any, Dict, Literal
 
 from ai21.models.chat import ChatMessage
+from ai21.models.chat.document_schema import DocumentSchema
+from ai21.models.chat.response_format import ResponseFormat
+from ai21.models.chat.tool_defintions import ToolDefinition
 from ai21.types import NotGiven
 from ai21.utils.typing import remove_not_given
 from ai21.models._pydantic_compatibility import _to_dict
@@ -40,6 +43,9 @@ def _create_body(
         stop: Optional[Union[str, List[str]]] | NotGiven,
         n: Optional[int] | NotGiven,
         stream: Literal[False] | Literal[True] | NotGiven,
+        tools: List[ToolDefinition] | NotGiven,
+        response_format: ResponseFormat | NotGiven,
+        documents: List[DocumentSchema] | NotGiven,
         **kwargs: Any,
     ) -> Dict[str, Any]:
         return remove_not_given(
@@ -52,6 +58,13 @@
                 "stop": stop,
                 "n": n,
                 "stream": stream,
+                "tools": tools,
+                "response_format": (
+                    _to_dict(response_format) if not isinstance(response_format, NotGiven) else response_format
+                ),
+                "documents": (
+                    [_to_dict(document) for document in documents] if not isinstance(documents, NotGiven) else documents
+                ),
                 **kwargs,
             }
         )
diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py
index e71dc4bd..cbb67e8c 100644
--- a/ai21/clients/studio/resources/chat/chat_completions.py
+++ b/ai21/clients/studio/resources/chat/chat_completions.py
@@ -6,6 +6,9 @@
 from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
 from ai21.models import ChatMessage as J2ChatMessage
 from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
+from ai21.models.chat.document_schema import DocumentSchema
+from ai21.models.chat.response_format import ResponseFormat
+from ai21.models.chat.tool_defintions import ToolDefinition
 from ai21.stream.stream import Stream
 from ai21.types import NotGiven, NOT_GIVEN
@@ -24,6 +27,9 @@ def create(
         stop: str | List[str] | NotGiven = NOT_GIVEN,
         n: int | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
+        tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
+        response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+        documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
         **kwargs: Any,
     ) -> ChatCompletionResponse:
         pass
@@ -39,6 +45,9 @@ def create(
         top_p: float | NotGiven = NOT_GIVEN,
         stop: str | List[str] | NotGiven = NOT_GIVEN,
         n: int | NotGiven = NOT_GIVEN,
+        tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
+        response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+        documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
         **kwargs: Any,
     ) -> Stream[ChatCompletionChunk]:
         pass
@@ -53,6 +62,9 @@ def create(
         stop: str | List[str] | NotGiven = NOT_GIVEN,
         n: int | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+        tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
+        response_format: ResponseFormat | NotGiven = NOT_GIVEN,
+        documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
         **kwargs: Any,
     ) -> ChatCompletionResponse | Stream[ChatCompletionChunk]:
         if any(isinstance(item, J2ChatMessage) for item in messages):
@@ -70,6 +82,9 @@
             top_p=top_p,
             n=n,
             stream=stream or False,
+            tools=tools,
+            response_format=response_format,
+            documents=documents,
             **kwargs,
         )
diff --git a/ai21/http_client/http_client.py b/ai21/http_client/http_client.py
index 7f5c4600..bacc10bb 100644
--- a/ai21/http_client/http_client.py
+++ b/ai21/http_client/http_client.py
@@ -99,7 +99,8 @@ def execute_http_request(
 
         if response.status_code != httpx.codes.OK:
             _logger.error(
-                f"Calling {method} {self._base_url} failed with a non-200 response code: {response.status_code}"
+                f"Calling {method} {self._base_url} failed with a non-200 "
+                f"response code: {response.status_code} headers: {response.headers}"
             )
            handle_non_success_response(response.status_code, response.text)
diff --git a/ai21/models/chat/__init__.py b/ai21/models/chat/__init__.py
index d9332ff1..8963aada 100644
--- a/ai21/models/chat/__init__.py
+++ b/ai21/models/chat/__init__.py
@@ -2,9 +2,16 @@
 from .chat_completion_response import ChatCompletionResponse
 from .chat_completion_response import ChatCompletionResponseChoice
-from .chat_message import ChatMessage
+from .chat_message import ChatMessage, AssistantMessage, ToolMessage, UserMessage, SystemMessage
+from .document_schema import DocumentSchema
+from .function_tool_definition import FunctionToolDefinition
+from .response_format import ResponseFormat
 from .role_type import RoleType as RoleType
 from .chat_completion_chunk import ChatCompletionChunk, ChoicesChunk, ChoiceDelta
+from .tool_call import ToolCall
+from .tool_defintions import ToolDefinition
+from .tool_function import ToolFunction
+from .tool_parameters import ToolParameters
 
 __all__ = [
     "ChatCompletionResponse",
@@ -14,4 +21,15 @@
     "ChatCompletionChunk",
     "ChoicesChunk",
     "ChoiceDelta",
+    "AssistantMessage",
+    "ToolMessage",
+    "UserMessage",
+    "SystemMessage",
+    "DocumentSchema",
+    "FunctionToolDefinition",
+    "ResponseFormat",
+    "ToolCall",
+    "ToolDefinition",
+    "ToolFunction",
+    "ToolParameters",
 ]
diff --git a/ai21/models/chat/chat_completion_response.py b/ai21/models/chat/chat_completion_response.py
index 4a831845..d1aff878 100644
--- a/ai21/models/chat/chat_completion_response.py
+++ b/ai21/models/chat/chat_completion_response.py
@@ -3,12 +3,12 @@
 from ai21.models.ai21_base_model import AI21BaseModel
 from ai21.models.logprobs import Logprobs
 from ai21.models.usage_info import UsageInfo
-from .chat_message import ChatMessage
+from .chat_message import AssistantMessage
 
 
 class ChatCompletionResponseChoice(AI21BaseModel):
     index: int
-    message: ChatMessage
+    message: AssistantMessage
     logprobs: Optional[Logprobs] = None
     finish_reason: Optional[str] = None
diff --git a/ai21/models/chat/chat_message.py b/ai21/models/chat/chat_message.py
index 6cc93ce9..2b3c5827 100644
--- a/ai21/models/chat/chat_message.py
+++ b/ai21/models/chat/chat_message.py
@@ -1,8 +1,28 @@
-from __future__ import annotations
+from typing import Literal, List, Optional
 
 from ai21.models.ai21_base_model import AI21BaseModel
+from ai21.models.chat.tool_call import ToolCall
 
 
 class ChatMessage(AI21BaseModel):
     role: str
     content: str
+
+
+class AssistantMessage(ChatMessage):
+    role: Literal["assistant"] = "assistant"
+    tool_calls: Optional[List[ToolCall]] = None
+    content: Optional[str] = None
+
+
+class ToolMessage(ChatMessage):
+    role: Literal["tool"] = "tool"
+    tool_call_id: str
+
+
+class UserMessage(ChatMessage):
+    role: Literal["user"] = "user"
+
+
+class SystemMessage(ChatMessage):
+    role: Literal["system"] = "system"
diff --git a/ai21/models/chat/document_schema.py b/ai21/models/chat/document_schema.py
new file mode 100644
index 00000000..445868f3
--- /dev/null
+++ b/ai21/models/chat/document_schema.py
@@ -0,0 +1,9 @@
+from typing import Optional, Dict
+
+from ai21.models.ai21_base_model import AI21BaseModel
+
+
+class DocumentSchema(AI21BaseModel):
+    content: str
+    id: Optional[str] = None
+    metadata: Optional[Dict[str, str]] = None
diff --git a/ai21/models/chat/function_tool_definition.py b/ai21/models/chat/function_tool_definition.py
new file mode 100644
index 00000000..dcce3fdc
--- /dev/null
+++ b/ai21/models/chat/function_tool_definition.py
@@ -0,0 +1,9 @@
+from typing_extensions import TypedDict, Required
+
+from
ai21.models.chat.tool_parameters import ToolParameters + + +class FunctionToolDefinition(TypedDict, total=False): + name: Required[str] + description: str + parameters: ToolParameters diff --git a/ai21/models/chat/response_format.py b/ai21/models/chat/response_format.py new file mode 100644 index 00000000..70bdfe85 --- /dev/null +++ b/ai21/models/chat/response_format.py @@ -0,0 +1,7 @@ +from typing import Literal + +from ai21.models.ai21_base_model import AI21BaseModel + + +class ResponseFormat(AI21BaseModel): + type: Literal["text", "json_object"] diff --git a/ai21/models/chat/role_type.py b/ai21/models/chat/role_type.py index a1630a23..eda033bc 100644 --- a/ai21/models/chat/role_type.py +++ b/ai21/models/chat/role_type.py @@ -4,3 +4,5 @@ class RoleType(str, Enum): USER = "user" ASSISTANT = "assistant" + TOOL = "tool" + SYSTEM = "system" diff --git a/ai21/models/chat/tool_call.py b/ai21/models/chat/tool_call.py new file mode 100644 index 00000000..8cb50105 --- /dev/null +++ b/ai21/models/chat/tool_call.py @@ -0,0 +1,10 @@ +from typing import Literal + +from ai21.models.ai21_base_model import AI21BaseModel +from ai21.models.chat.tool_function import ToolFunction + + +class ToolCall(AI21BaseModel): + id: str + function: ToolFunction + type: Literal["function"] = "function" diff --git a/ai21/models/chat/tool_defintions.py b/ai21/models/chat/tool_defintions.py new file mode 100644 index 00000000..cfa63371 --- /dev/null +++ b/ai21/models/chat/tool_defintions.py @@ -0,0 +1,8 @@ +from typing_extensions import Literal, TypedDict, Required + +from ai21.models.chat import FunctionToolDefinition + + +class ToolDefinition(TypedDict, total=False): + type: Required[Literal["function"]] + function: Required[FunctionToolDefinition] diff --git a/ai21/models/chat/tool_function.py b/ai21/models/chat/tool_function.py new file mode 100644 index 00000000..97f234ad --- /dev/null +++ b/ai21/models/chat/tool_function.py @@ -0,0 +1,6 @@ +from ai21.models.ai21_base_model import AI21BaseModel + + +class ToolFunction(AI21BaseModel): + name: str + arguments: str diff --git a/ai21/models/chat/tool_parameters.py b/ai21/models/chat/tool_parameters.py new file mode 100644 index 00000000..ba89d688 --- /dev/null +++ b/ai21/models/chat/tool_parameters.py @@ -0,0 +1,7 @@ +from typing_extensions import Literal, Any, Dict, List, TypedDict, Required + + +class ToolParameters(TypedDict, total=False): + type: Literal["object"] + properties: Required[Dict[str, Any]] + required: List[str] diff --git a/examples/studio/chat/async_chat_completions.py b/examples/studio/chat/async_chat_completions.py index 2dcca1f4..d7bbb2df 100644 --- a/examples/studio/chat/async_chat_completions.py +++ b/examples/studio/chat/async_chat_completions.py @@ -18,7 +18,7 @@ async def main(): response = await client.chat.completions.create( messages=messages, - model="jamba-instruct-preview", + model="jamba-1.5-mini", max_tokens=100, temperature=0.7, top_p=1.0, diff --git a/examples/studio/chat/async_stream_chat_completions.py b/examples/studio/chat/async_stream_chat_completions.py index b8309152..994205dc 100644 --- a/examples/studio/chat/async_stream_chat_completions.py +++ b/examples/studio/chat/async_stream_chat_completions.py @@ -17,7 +17,7 @@ async def main(): response = await client.chat.completions.create( messages=messages, - model="jamba-instruct-preview", + model="jamba-1.5-large", max_tokens=100, stream=True, ) diff --git a/examples/studio/chat/chat_completions.py b/examples/studio/chat/chat_completions.py index 5d5c8a77..727312bd 100644 --- 
a/examples/studio/chat/chat_completions.py +++ b/examples/studio/chat/chat_completions.py @@ -1,19 +1,19 @@ from ai21 import AI21Client -from ai21.models.chat import ChatMessage +from ai21.models.chat.chat_message import SystemMessage, UserMessage, AssistantMessage system = "You're a support engineer in a SaaS company" messages = [ - ChatMessage(content=system, role="system"), - ChatMessage(content="Hello, I need help with a signup process.", role="user"), - ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), - ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), + SystemMessage(content=system, role="system"), + UserMessage(content="Hello, I need help with a signup process.", role="user"), + AssistantMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), + UserMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), ] client = AI21Client() response = client.chat.completions.create( messages=messages, - model="jamba-instruct-preview", + model="jamba-1.5-mini", max_tokens=100, temperature=0.7, top_p=1.0, diff --git a/examples/studio/chat/chat_completions_jamba_instruct.py b/examples/studio/chat/chat_completions_jamba_instruct.py new file mode 100644 index 00000000..af84d0bc --- /dev/null +++ b/examples/studio/chat/chat_completions_jamba_instruct.py @@ -0,0 +1,23 @@ +from ai21 import AI21Client +from ai21.models.chat.chat_message import SystemMessage, UserMessage, AssistantMessage + +system = "You're a support engineer in a SaaS company" +messages = [ + SystemMessage(content=system, role="system"), + UserMessage(content="Hello, I need help with a signup process.", role="user"), + AssistantMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"), + UserMessage(content="I am having trouble signing up for your product with my Google account.", role="user"), +] + +client = AI21Client() + +response = client.chat.completions.create( + messages=messages, + model="jamba-instruct", + max_tokens=100, + temperature=0.7, + top_p=1.0, + stop=["\n"], +) + +print(response) diff --git a/examples/studio/chat/chat_documents.py b/examples/studio/chat/chat_documents.py new file mode 100644 index 00000000..1a69b91d --- /dev/null +++ b/examples/studio/chat/chat_documents.py @@ -0,0 +1,45 @@ +import uuid + +from ai21 import AI21Client +from ai21.logger import set_verbose +from ai21.models.chat import ChatMessage +from ai21.models.chat.document_schema import DocumentSchema + +set_verbose(True) + +schnoodel = DocumentSchema( + id=str(uuid.uuid4()), + content="Schnoodel Inc. Annual Report - 2024. Schnoodel Inc., a leader in innovative culinary technology, saw a " + "15% revenue growth this year, reaching $120 million. The launch of SchnoodelChef Pro has significantly " + "contributed, making up 35% of total sales. We've expanded into the Asian market, notably Japan, " + "and increased our global presence. Committed to sustainability, we reduced our carbon footprint " + "by 20%. Looking ahead, we plan to integrate more advanced machine learning features and expand " + "into South America.", + metadata={"topic": "revenue"}, +) +shnokel = DocumentSchema( + id=str(uuid.uuid4()), + content="Shnokel Corp. TL;DR Annual Report - 2024. Shnokel Corp., a pioneer in renewable energy solutions, " + "reported a 20% increase in revenue this year, reaching $200 million. 
The successful deployment of "
+    "our advanced solar panels, SolarFlex, accounted for 40% of our sales. We entered new markets in Europe "
+    "and have plans to develop wind energy projects next year. Our commitment to reducing environmental "
+    "impact saw a 25% decrease in operational emissions. Upcoming initiatives include a significant "
+    "investment in R&D for sustainable technologies.",
+    metadata={"topic": "revenue"},
+)
+
+documents = [schnoodel, shnokel]
+
+messages = [
+    ChatMessage(
+        role="system",
+        content="You are a helpful assistant that receives revenue documents and answers related questions",
+    ),
+    ChatMessage(role="user", content="Hi, which company earned more during 2024 - Schnoodel or Shnokel?"),
+]
+
+client = AI21Client()
+
+response = client.chat.completions.create(messages=messages, model="jamba-1.5-mini", documents=documents)
+
+print(response)
diff --git a/examples/studio/chat/chat_function_calling.py b/examples/studio/chat/chat_function_calling.py
new file mode 100644
index 00000000..cdfb9385
--- /dev/null
+++ b/examples/studio/chat/chat_function_calling.py
@@ -0,0 +1,75 @@
+import json
+
+from ai21 import AI21Client
+from ai21.logger import set_verbose
+from ai21.models.chat import ChatMessage, ToolMessage
+from ai21.models.chat.function_tool_definition import FunctionToolDefinition
+from ai21.models.chat.tool_defintions import ToolDefinition
+from ai21.models.chat.tool_parameters import ToolParameters
+
+set_verbose(True)
+
+
+def get_order_delivery_date(order_id: str) -> str:
+    print(f"Getting delivery date from database for order ID: {order_id}...")
+    return "2025-05-04"
+
+
+messages = [
+    ChatMessage(
+        role="system",
+        content="You are a helpful customer support assistant. Use the supplied tools to assist the user.",
+    ),
+    ChatMessage(role="user", content="Hi, can you tell me the delivery date for my order?"),
+    ChatMessage(role="assistant", content="Hi there! I can help with that. Can you please provide your order ID?"),
+    ChatMessage(role="user", content="i think it is order_12345"),
+]
+
+tool_definition = ToolDefinition(
+    type="function",
+    function=FunctionToolDefinition(
+        name="get_order_delivery_date",
+        description="Get the delivery date for a given order ID",
+        parameters=ToolParameters(
+            type="object",
+            properties={"order_id": {"type": "string", "description": "The customer's order ID."}},
+            required=["order_id"],
+        ),
+    ),
+)
+
+tools = [tool_definition]
+
+client = AI21Client()
+
+response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)
+
+print(response)
+
+assistant_message = response.choices[0].message
+tool_calls = assistant_message.tool_calls
+
+delivery_date = None
+if tool_calls:
+    tool_call = tool_calls[0]
+    if tool_call.function.name == "get_order_delivery_date":
+        # the arguments arrive as a JSON string, so parse them before extracting the order ID
+        args = json.loads(tool_call.function.arguments)
+        if "order_id" in args:
+            order_id = args["order_id"]
+            delivery_date = get_order_delivery_date(order_id)
+            print(f"Delivery date for order ID {order_id}: {delivery_date}")
+        else:
+            print("order_id not found in function arguments")
+    else:
+        print(f"Unexpected tool call found - {tool_call.function.name}")
+else:
+    print("No tool calls found.")
+
+if delivery_date is not None:
+    tool_message = ToolMessage(role="tool", tool_call_id=tool_calls[0].id, content=delivery_date)
+    messages.append(assistant_message)
+    messages.append(tool_message)
+    response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)
+    print(response)
diff --git a/examples/studio/chat/chat_function_calling_multiple_tools.py b/examples/studio/chat/chat_function_calling_multiple_tools.py
new file mode 100644
index 00000000..9655e91c
--- /dev/null
+++ b/examples/studio/chat/chat_function_calling_multiple_tools.py
@@ -0,0 +1,114 @@
+import json
+
+from ai21 import AI21Client
+from ai21.logger import set_verbose
+from ai21.models.chat import ChatMessage, ToolMessage
+from ai21.models.chat.function_tool_definition import FunctionToolDefinition
+from ai21.models.chat.tool_defintions import ToolDefinition
+from ai21.models.chat.tool_parameters import ToolParameters
+
+set_verbose(True)
+
+
+def get_weather(place: str, date: str) -> str:
+    print(f"Getting the expected weather at {place} during {date} from the internet...")
+    return "32 celsius"
+
+
+def get_sunset_hour(place: str, date: str):
+    print(f"Getting the expected sunset hour at {place} during {date} from the internet...")
+    return "7 pm"
+
+
+messages = [
+    ChatMessage(
+        role="system",
+        content="You are a helpful assistant. Use the supplied tools to assist the user.",
+    ),
+    ChatMessage(
+        role="user", content="Hi, can you assist me to get info about the weather and expected sunset in Tel Aviv?"
+    ),
+    ChatMessage(role="assistant", content="Hi there! I can help with that. On which date?"),
+    ChatMessage(role="user", content="At 2024-08-27"),
+]
+
+get_sunset_tool = ToolDefinition(
+    type="function",
+    function=FunctionToolDefinition(
+        name="get_sunset_hour",
+        description="Search the internet for the sunset hour at a given place on a given date",
+        parameters=ToolParameters(
+            type="object",
+            properties={
+                "place": {"type": "string", "description": "The place to look for the weather at"},
+                "date": {"type": "string", "description": "The date to look for the weather at"},
+            },
+            required=["place", "date"],
+        ),
+    ),
+)
+
+get_weather_tool = ToolDefinition(
+    type="function",
+    function=FunctionToolDefinition(
+        name="get_weather",
+        description="Search the internet for the weather at a given place on a given date",
+        parameters=ToolParameters(
+            type="object",
+            properties={
+                "place": {"type": "string", "description": "The place to look for the weather at"},
+                "date": {"type": "string", "description": "The date to look for the weather at"},
+            },
+            required=["place", "date"],
+        ),
+    ),
+)
+
+tools = [get_sunset_tool, get_weather_tool]
+
+client = AI21Client()
+
+response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)
+
+print(response.choices[0].message)
+
+assistant_message = response.choices[0].message
+messages.append(assistant_message)
+
+tool_calls = assistant_message.tool_calls
+
+tool_call_id_to_result = {}
+if tool_calls:
+    for tool_call in tool_calls:
+        if tool_call.function.name == "get_weather":
+            func_arguments = tool_call.function.arguments
+            args = json.loads(func_arguments)
+            if "place" in args and "date" in args:
+                result = get_weather(args["place"], args["date"])
+                tool_call_id_to_result[tool_call.id] = result
+            else:
+                print(f"Got unexpected arguments in function call - {args}")
+        elif tool_call.function.name == "get_sunset_hour":
+            func_arguments = tool_call.function.arguments
+            args = json.loads(func_arguments)
+            if "place" in args and "date" in args:
+                result = get_sunset_hour(args["place"], args["date"])
+                tool_call_id_to_result[tool_call.id] = result
+            else:
+                print(f"Got unexpected arguments in function call - {args}")
+        else:
+            print(f"Unexpected tool call found - {tool_call.function.name}")
+else:
+    print("No tool calls found.")
+
+
+if tool_call_id_to_result:
+    for tool_id_called, result in tool_call_id_to_result.items():
+        tool_message = ToolMessage(role="tool", tool_call_id=tool_id_called, content=str(result))
+        messages.append(tool_message)
+
+    for message in messages:
+        print(message)
+
+    response = client.chat.completions.create(messages=messages, model="jamba-1.5-large", tools=tools)
+    print(response.choices[0].message)
diff --git a/examples/studio/chat/chat_response_format.py b/examples/studio/chat/chat_response_format.py
new file mode 100644
index 00000000..5c393b1f
--- /dev/null
+++ b/examples/studio/chat/chat_response_format.py
@@ -0,0 +1,52 @@
+from enum import Enum
+
+from ai21 import AI21Client
+from ai21.logger import set_verbose
+from ai21.models.chat import ChatMessage
+from ai21.models.chat.response_format import ResponseFormat
+from pydantic import BaseModel
+
+set_verbose(True)
+
+
+class TicketType(Enum):
+    ADULT = "adult"
+    CHILD = "child"
+
+
+class ZooTicket(BaseModel):
+    ticket_type: TicketType
+    quantity: int
+
+
+class ZooTicketsOrder(BaseModel):
+    date: str
+    tickets: list[ZooTicket]
+
+
+messages = [
+    ChatMessage(
+        role="system",
+        content="As an assistant, your task is to generate structured data from user requests for zoo tickets. "
+        "Create a valid JSON object specifying the date of the visit and the number of tickets. Your "
+        "answer should be in JSON format, no extra spaces or new lines or any character that is not "
+        f"part of the JSON. Here is the JSON format to follow: {ZooTicketsOrder.model_json_schema()}",
+    ),
+    ChatMessage(role="user", content="Can I order a ticket for September 22, 2024, for myself and two kids?"),
+]
+
+client = AI21Client()
+
+response = client.chat.completions.create(
+    messages=messages,
+    model="jamba-1.5-mini",
+    max_tokens=2000,
+    temperature=0,
+    response_format=ResponseFormat(type="json_object"),
+)
+
+print(response)
+
+order = ZooTicketsOrder.model_validate_json(response.choices[0].message.content)
+
+print(order)
diff --git a/examples/studio/chat/stream_chat_completions.py b/examples/studio/chat/stream_chat_completions.py
index fd079962..415b3260 100644
--- a/examples/studio/chat/stream_chat_completions.py
+++ b/examples/studio/chat/stream_chat_completions.py
@@ -13,7 +13,7 @@
 
 response = client.chat.completions.create(
     messages=messages,
-    model="jamba-instruct-preview",
+    model="jamba-1.5-large",
     max_tokens=100,
     stream=True,
 )
diff --git a/poetry.lock b/poetry.lock
index d517d1a8..109881d1 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "ai21-tokenizer"
@@ -1245,7 +1245,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
     {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
     {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
     {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@@ -1361,51 +1360,37 @@ python-versions = ">=3.6"
 files = [
     {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"},
     {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"},
+    {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d92f81886165cb14d7b067ef37e142256f1c6a90a65cd156b063a43da1708cfd"},
     {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"},
-    {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"},
-    {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash =
"sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, - {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:b5edda50e5e9e15e54a6a8a0070302b00c518a9d32accc2346ad6c984aacd279"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, - {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:7048c338b6c86627afb27faecf418768acb6331fc24cfa56c93e8c9780f815fa"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, - {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = 
"ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3fcc54cb0c8b811ff66082de1680b4b14cf8a81dce0d4fbf665c2265a81e07a1"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, - {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:665f58bfd29b167039f714c6998178d27ccd83984084c286110ef26b230f259f"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, - {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:9eb5dee2772b0f704ca2e45b1713e4e5198c18f515b52743576d196348f374d3"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, - {file = 
"ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, - {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, {file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"}, diff --git a/tests/integration_tests/clients/test_studio.py b/tests/integration_tests/clients/test_studio.py index 48e11248..8b157c5e 100644 --- a/tests/integration_tests/clients/test_studio.py +++ b/tests/integration_tests/clients/test_studio.py @@ -28,6 +28,7 @@ ("summarize_by_segment.py",), ("tokenization.py",), ("chat/chat_completions.py",), + ("chat/chat_completions_jamba_instruct.py",), ("chat/stream_chat_completions.py",), # ("custom_model.py", ), # ('custom_model_completion.py', ), @@ -48,6 +49,7 @@ "when_summarize_by_segment__should_return_ok", "when_tokenization__should_return_ok", "when_chat_completions__should_return_ok", + "when_chat_completions_jamba_instruct__should_return_ok", "when_stream_chat_completions__should_return_ok", # "when_custom_model__should_return_ok", # "when_custom_model_completion__should_return_ok", diff --git a/tests/unittests/clients/studio/resources/chat/test_chat_completions.py b/tests/unittests/clients/studio/resources/chat/test_chat_completions.py index d6228ac8..5987db9a 100644 --- a/tests/unittests/clients/studio/resources/chat/test_chat_completions.py +++ b/tests/unittests/clients/studio/resources/chat/test_chat_completions.py @@ -1,6 +1,17 @@ +import uuid + +import httpx import pytest +from pytest_mock import MockerFixture + from ai21 import AI21Client, AsyncAI21Client from ai21.models import ChatMessage, RoleType +from ai21.models.chat import ChatCompletionResponse +from ai21.models.chat.chat_message import UserMessage, SystemMessage, AssistantMessage +from ai21.models.chat.document_schema import DocumentSchema +from ai21.models.chat.function_tool_definition import FunctionToolDefinition +from ai21.models.chat.tool_defintions import ToolDefinition +from ai21.models.chat.tool_parameters import ToolParameters _DUMMY_API_KEY = "dummy_api_key" @@ -8,7 +19,7 @@ def test_chat_create__when_bad_import_to_chat_message__raise_error(): with pytest.raises(ValueError) as e: AI21Client(api_key=_DUMMY_API_KEY).chat.completions.create( - model="jamba-instruct-preview", + model="jamba-1.5", messages=[ChatMessage(role=RoleType.USER, text="Hello")], system="System Test", ) @@ -23,7 +34,7 @@ def test_chat_create__when_bad_import_to_chat_message__raise_error(): async def test_async_chat_create__when_bad_import_to_chat_message__raise_error(): with pytest.raises(ValueError) as e: await AsyncAI21Client(api_key=_DUMMY_API_KEY).chat.completions.create( - model="jamba-instruct-preview", + model="jamba-1.5", messages=[ChatMessage(role=RoleType.USER, text="Hello")], system="System Test", ) @@ -38,7 +49,151 @@ def test__when_model_and_model_id__raise_error(): client = AI21Client() with pytest.raises(ValueError): client.chat.completions.create( - model="jamba-instruct", + model="jamba-1.5", model_id="jamba-instruct", messages=[ChatMessage(role=RoleType.USER, text="Hello")], ) + + +def 
test_chat_completion_basic_happy_flow(mocker: MockerFixture) -> None:
+    response_json = {
+        "id": "chat-cc8ce5c05f1d4ed9b722123ac4a0f267",
+        "choices": [
+            {
+                "index": 0,
+                "message": {"role": "assistant", "content": " Hello! How can I assist you today?"},
+                "finish_reason": "stop",
+            }
+        ],
+        "usage": {"prompt_tokens": 144, "completion_tokens": 14, "total_tokens": 158},
+        "meta": {"requestDurationMillis": 236},
+    }
+    mocked_client = mocker.Mock(spec=httpx.Client)
+    mocked_client.send.return_value = httpx.Response(status_code=200, json=response_json)
+    client = AI21Client(api_key=_DUMMY_API_KEY, http_client=mocked_client)
+    response: ChatCompletionResponse = client.chat.completions.create(
+        model="jamba-1.5", messages=[UserMessage(role="user", content="Hello")]
+    )
+    assert response.choices[0].message.content == " Hello! How can I assist you today?"
+
+
+def test_chat_completion_with_tool_calls_happy_flow(mocker: MockerFixture) -> None:
+    response_json = {
+        "id": "chat-cc8ce5c05f1d4ed9b722123ac4a0f267",
+        "choices": [
+            {
+                "index": 0,
+                "message": {
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [
+                        {
+                            "id": "call_62136354",
+                            "type": "function",
+                            "function": {"name": "get_delivery_date", "arguments": '{"order_id":"order_12345"}'},
+                        }
+                    ],
+                },
+                "finish_reason": "tool_calls",
+            }
+        ],
+        "usage": {"prompt_tokens": 144, "completion_tokens": 14, "total_tokens": 158},
+        "meta": {"requestDurationMillis": 236},
+    }
+    mocked_client = mocker.Mock(spec=httpx.Client)
+    mocked_client.send.return_value = httpx.Response(status_code=200, json=response_json)
+    client = AI21Client(api_key=_DUMMY_API_KEY, http_client=mocked_client)
+
+    messages = [
+        SystemMessage(
+            role="system",
+            content="You are a helpful customer support assistant. Use the supplied tools to assist the user.",
+        ),
+        UserMessage(role="user", content="Hi, can you tell me the delivery date for my order?"),
+        AssistantMessage(
+            role="assistant", content="Hi there! I can help with that. Can you please provide your order ID?"
+ ), + UserMessage(role="user", content="i think it is order_12345"), + ] + + tools = [ + ToolDefinition( + type="function", + function=FunctionToolDefinition( + name="get_delivery_date", + description="Get the delivery date for a given order ID", + parameters=ToolParameters( + type="object", + properties={"order_id": {"type": "string", "description": "The customer's order ID."}}, + required=["order_id"], + ), + ), + ) + ] + + response = client.chat.completions.create(model="jamba-1.5", messages=messages, tools=tools) + assert response.choices[0].message.tool_calls[0].function.name == "get_delivery_date" + assert response.choices[0].message.tool_calls[0].function.arguments == '{"order_id":"order_12345"}' + + +def test_chat_completion_with_documents_happy_flow(mocker: MockerFixture) -> None: + response_json = { + "id": "chat-cc8ce5c05f1d4ed9b722123ac4a0f267", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Shnokel.", + }, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 144, "completion_tokens": 14, "total_tokens": 158}, + "meta": {"requestDurationMillis": 236}, + } + mocked_client = mocker.Mock(spec=httpx.Client) + mocked_client.send.return_value = httpx.Response(status_code=200, json=response_json) + client = AI21Client(api_key=_DUMMY_API_KEY, http_client=mocked_client) + + schnoodel = DocumentSchema( + id=str(uuid.uuid4()), + content="Schnoodel Inc. Annual Report - 2024. Schnoodel Inc., a leader in innovative culinary technology, " + "saw a 15% revenue growth this year, reaching $120 million. The launch of SchnoodelChef Pro has significantly " + "contributed, making up 35% of total sales. We've expanded into the Asian market, notably Japan, " + "and increased our global presence. Committed to sustainability, we reduced our carbon footprint " + "by 20%. Looking ahead, we plan to integrate more advanced machine learning features and expand " + "into South America.", + metadata={"topic": "revenue"}, + ) + shnokel = DocumentSchema( + id=str(uuid.uuid4()), + content="Shnokel Corp. TL;DR Annual Report - 2024. Shnokel Corp., a pioneer in renewable energy solutions, " + "reported a 20% increase in revenue this year, reaching $200 million. The successful deployment of " + "our advanced solar panels, SolarFlex, accounted for 40% of our sales. We entered new markets in Europe " + "and have plans to develop wind energy projects next year. Our commitment to reducing environmental " + "impact saw a 25% decrease in operational emissions. Upcoming initiatives include a significant " + "investment in R&D for sustainable technologies.", + metadata={"topic": "revenue"}, + ) + + documents = [schnoodel, shnokel] + + messages = [ + SystemMessage( + role="system", + content="You are a helpful assistant that receives revenue documents and answers related questions", + ), + UserMessage(role="user", content="Hi, which company earned more during 2024 - Schnoodel or Shnokel?"), + ] + + response = client.chat.completions.create(model="jamba-1.5", messages=messages, documents=documents) + assert response.choices[0].message.content == "Shnokel." 
diff --git a/tests/unittests/clients/studio/resources/conftest.py b/tests/unittests/clients/studio/resources/conftest.py index 7c73a479..a26d2514 100644 --- a/tests/unittests/clients/studio/resources/conftest.py +++ b/tests/unittests/clients/studio/resources/conftest.py @@ -150,7 +150,7 @@ def get_chat_completions(is_async: bool = False): "index": 0, "message": { "content": "Hello, I need help with a signup process.", - "role": "user", + "role": "assistant", }, "finish_reason": "dummy_reason", "logprobs": None, diff --git a/tests/unittests/models/response_mocks.py b/tests/unittests/models/response_mocks.py index 78961a75..4a7b2245 100644 --- a/tests/unittests/models/response_mocks.py +++ b/tests/unittests/models/response_mocks.py @@ -23,7 +23,8 @@ SegmentSummary, Highlight, ) -from ai21.models.chat import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage +from ai21.models.chat import ChatCompletionResponse, ChatCompletionResponseChoice +from ai21.models.chat.chat_message import AssistantMessage from ai21.models.responses.segmentation_response import Segment from ai21.models.usage_info import UsageInfo @@ -77,6 +78,7 @@ def get_chat_completions_response(): "more information about the issue you're facing? For example, are you receiving an " "error message when you try to sign up with your Google account? If so, what does the " "error message say?", + "tool_calls": None, }, "logprobs": None, "finish_reason": "stop", @@ -87,8 +89,8 @@ def get_chat_completions_response(): choice = ChatCompletionResponseChoice( index=0, - message=ChatMessage( - role=RoleType.ASSISTANT, + message=AssistantMessage( + role="assistant", content="I apologize for any inconvenience you're experiencing. Can you please provide me with more " "information about the issue you're facing? For example, are you receiving an error message when " "you try to sign up with your Google account? If so, what does the error message say?",
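
The diff above threads `tools`, `documents`, and `response_format` through both the sync and async `create` methods, but the bundled examples only exercise them on the sync `AI21Client`. Below is a minimal sketch (not part of the diff) of the new `documents` parameter on the async client, which this change wires through `AsyncChatCompletions.create` and serializes in `BaseChatCompletions._create_body`. It assumes an `AI21_API_KEY` environment variable is configured; the document id and content are made up for illustration, and the model name follows the examples above.

```python
import asyncio

from ai21 import AsyncAI21Client
from ai21.models.chat import ChatMessage
from ai21.models.chat.document_schema import DocumentSchema

# A hypothetical document to ground the model's answer on.
documents = [
    DocumentSchema(
        id="doc-1",
        content="Acme Corp. 2024 annual report: revenue grew 12% to $10 million.",
        metadata={"topic": "revenue"},
    )
]

messages = [
    ChatMessage(role="system", content="Answer questions using only the attached documents."),
    ChatMessage(role="user", content="What was Acme's revenue in 2024?"),
]


async def main():
    client = AsyncAI21Client()
    # Each DocumentSchema is converted to a dict by _create_body via _to_dict
    # unless the argument is left as NOT_GIVEN, in which case it is stripped
    # from the request body by remove_not_given.
    response = await client.chat.completions.create(
        messages=messages,
        model="jamba-1.5-mini",
        documents=documents,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```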