Update openai to v1 (#59)
jackmpcollins authored Nov 14, 2023
1 parent 29e2a67 commit ac39bd2
Showing 6 changed files with 278 additions and 227 deletions.
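
For context on the migration itself (an illustrative sketch, not part of the diff): the openai v1 release replaced the pre-1.0 module-level calls such as openai.ChatCompletion.create with a client object, and streamed chunks became typed objects rather than plain dicts. Assuming openai >= 1.0 and an OPENAI_API_KEY in the environment:

    from openai import OpenAI  # openai >= 1.0

    client = OpenAI()  # reads OPENAI_API_KEY from the environment

    # v1 call style; pre-1.0 code used openai.ChatCompletion.create(...)
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    for chunk in stream:
        # Chunks are typed ChatCompletionChunk objects in v1, not dicts
        content = chunk.choices[0].delta.content
        if content is not None:
            print(content, end="")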
7 changes: 7 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,7 @@
+{
+    "python.testing.pytestArgs": [
+        "--no-cov" // Breaks debugger. https://code.visualstudio.com/docs/python/testing#_pytest-configuration-settings
+    ],
+    "python.testing.pytestEnabled": true,
+    "python.testing.unittestEnabled": false
+}
298 changes: 183 additions & 115 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -11,6 +11,7 @@ strict = true
 [[tool.mypy.overrides]]
 module = [
     "litellm",
+    "litellm.utils",
 ]
 ignore_missing_imports = true

@@ -25,8 +26,8 @@ repository = "https://github.com/jackmpcollins/magentic"

 [tool.poetry.dependencies]
 python = ">=3.10,<4.0"
-litellm = {version = ">=0.8.4", optional = true}
-openai = ">=0.27,<1.0"
+litellm = {version = ">=1.0.0", optional = true}
+openai = ">=1.0"
 pydantic = ">=2.0.0"
 pydantic-settings = ">=2.0.0"
57 changes: 31 additions & 26 deletions src/magentic/chat_model/litellm_chat_model.py
@@ -1,12 +1,11 @@
-from collections.abc import AsyncIterator, Callable, Iterable, Iterator
+from collections.abc import AsyncIterator, Callable, Iterable
 from itertools import chain
 from typing import Any, Literal, TypeVar, cast, overload

-from magentic.chat_model.openai_chat_model import (
-    OpenaiChatCompletionChoiceMessage,
-    OpenaiChatCompletionChunk,
-    message_to_openai_message,
-)
+from litellm.utils import CustomStreamWrapper, ModelResponse
+from openai.types.chat import ChatCompletionMessageParam
+
+from magentic.chat_model.openai_chat_model import message_to_openai_message

 try:
     import litellm
@@ -36,13 +35,13 @@

 def litellm_completion(
     model: str,
-    messages: Iterable[OpenaiChatCompletionChoiceMessage],
+    messages: list[ChatCompletionMessageParam],
     api_base: str | None = None,
     max_tokens: int | None = None,
     temperature: float | None = None,
     functions: list[dict[str, Any]] | None = None,
     function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
-) -> Iterator[OpenaiChatCompletionChunk]:
+) -> CustomStreamWrapper:
     """Type-annotated version of `litellm.completion`."""
     # `litellm.completion` doesn't accept `None`
     # so only pass args with values
@@ -58,24 +57,24 @@ def litellm_completion(
     if temperature is not None:
         kwargs["temperature"] = temperature

-    response: Iterator[dict[str, Any]] = litellm.completion(  # type: ignore[no-untyped-call,unused-ignore]
+    response: CustomStreamWrapper = litellm.completion(  # type: ignore[no-untyped-call,unused-ignore]
         model=model,
-        messages=[m.model_dump(mode="json", exclude_unset=True) for m in messages],
+        messages=messages,
         stream=True,
         **kwargs,
     )
-    return (OpenaiChatCompletionChunk.model_validate(chunk) for chunk in response)
+    return response


 async def litellm_acompletion(
     model: str,
-    messages: Iterable[OpenaiChatCompletionChoiceMessage],
+    messages: list[ChatCompletionMessageParam],
     api_base: str | None = None,
     max_tokens: int | None = None,
     temperature: float | None = None,
     functions: list[dict[str, Any]] | None = None,
     function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
-) -> AsyncIterator[OpenaiChatCompletionChunk]:
+) -> AsyncIterator[ModelResponse]:
     """Type-annotated version of `litellm.acompletion`."""
     # `litellm.acompletion` doesn't accept `None`
     # so only pass args with values
@@ -91,13 +90,13 @@ async def litellm_acompletion(
     if temperature is not None:
         kwargs["temperature"] = temperature

-    response: AsyncIterator[dict[str, Any]] = await litellm.acompletion(  # type: ignore[no-untyped-call,unused-ignore]
+    response: AsyncIterator[ModelResponse] = await litellm.acompletion(  # type: ignore[no-untyped-call,unused-ignore]
         model=model,
-        messages=[m.model_dump(mode="json", exclude_unset=True) for m in messages],
+        messages=messages,
         stream=True,
         **kwargs,
     )
-    return (OpenaiChatCompletionChunk.model_validate(chunk) async for chunk in response)
+    return response


 R = TypeVar("R")
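
Both wrappers now forward messages untouched: openai v1 types chat messages as plain TypedDict-style dicts, so the old model_dump() round-trip is no longer needed. A minimal sketch of building such a list (illustrative, not part of the diff; assumes openai >= 1.0):

    from openai.types.chat import ChatCompletionMessageParam

    # ChatCompletionMessageParam is a union of TypedDicts, so conforming
    # plain dicts type-check without any pydantic model_dump() conversion.
    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]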
@@ -217,19 +216,22 @@ def complete(
         response = chain([first_chunk], response)  # Replace first chunk
         first_chunk_delta = first_chunk.choices[0].delta

-        if first_chunk_delta.function_call:
+        if function_call := first_chunk_delta.get("function_call", None):
             function_schema_by_name = {
                 function_schema.name: function_schema
                 for function_schema in function_schemas
             }
-            function_name = first_chunk_delta.function_call.get_name_or_raise()
-            function_schema = function_schema_by_name[function_name]
+            if function_call.name is None:
+                msg = "OpenAI function call name is None"
+                raise ValueError(msg)
+            function_schema = function_schema_by_name[function_call.name]
             try:
                 return AssistantMessage(
                     function_schema.parse_args(
                         chunk.choices[0].delta.function_call.arguments
                         for chunk in response
-                        if chunk.choices[0].delta.function_call
+                        if chunk.choices[0].delta.function_call.arguments is not None
                     )
                 )
             except ValidationError as e:
@@ -246,9 +248,9 @@ def complete(
             )
             raise ValueError(msg)
         streamed_str = StreamedStr(
-            chunk.choices[0].delta.content
+            chunk.choices[0].delta.get("content", None)
             for chunk in response
-            if chunk.choices[0].delta.content is not None
+            if chunk.choices[0].delta.get("content", None) is not None
         )
         if streamed_str_in_output_types:
             return cast(AssistantMessage[R], AssistantMessage(streamed_str))
@@ -335,19 +337,22 @@ async def acomplete(
         response = achain(async_iter([first_chunk]), response)  # Replace first chunk
         first_chunk_delta = first_chunk.choices[0].delta

-        if first_chunk_delta.function_call:
+        if function_call := first_chunk_delta.get("function_call", None):
             function_schema_by_name = {
                 function_schema.name: function_schema
                 for function_schema in function_schemas
             }
-            function_name = first_chunk_delta.function_call.get_name_or_raise()
-            function_schema = function_schema_by_name[function_name]
+            if function_call.name is None:
+                msg = "OpenAI function call name is None"
+                raise ValueError(msg)
+            function_schema = function_schema_by_name[function_call.name]
             try:
                 return AssistantMessage(
                     await function_schema.aparse_args(
                         chunk.choices[0].delta.function_call.arguments
                         async for chunk in response
-                        if chunk.choices[0].delta.function_call
+                        if chunk.choices[0].delta.function_call.arguments is not None
                     )
                 )
             except ValidationError as e:
@@ -364,9 +369,9 @@ async def acomplete(
             )
             raise ValueError(msg)
         async_streamed_str = AsyncStreamedStr(
-            chunk.choices[0].delta.content
+            chunk.choices[0].delta.get("content", None)
             async for chunk in response
-            if chunk.choices[0].delta.content is not None
+            if chunk.choices[0].delta.get("content", None) is not None
         )
         if async_streamed_str_in_output_types:
             return cast(AssistantMessage[R], AssistantMessage(async_streamed_str))
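In complete and acomplete above, attribute access on streamed deltas is replaced with dict-style .get(), matching how litellm's streamed delta objects are queried. A rough sketch of consuming a stream under this pattern (illustrative, not part of the diff; assumes litellm >= 1.0 and a provider API key configured):

    import litellm

    # litellm.completion with stream=True returns a CustomStreamWrapper
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
    )
    for chunk in response:
        # Query deltas with .get(), as the updated magentic code does
        content = chunk.choices[0].delta.get("content", None)
        if content is not None:
            print(content, end="")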
