Skip to content

Commit

Permalink
chore: merge recent
Browse files Browse the repository at this point in the history
  • Loading branch information
miri-bar committed Jul 1, 2024
2 parents e0cc96a + e24fd13 commit 0e407e7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
31 changes: 28 additions & 3 deletions ai21/clients/bedrock/ai21_bedrock_client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import warnings
from typing import Optional, Dict, Any

import boto3

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.clients.aws.utils import get_aws_region
from ai21.clients.bedrock.resources.chat.bedrock_chat import BedrockChat, AsyncBedrockChat
from ai21.clients.bedrock.resources.bedrock_completion import BedrockCompletion, AsyncBedrockCompletion
from ai21.http_client.http_client import HttpClient
from ai21.http_client.async_http_client import AsyncHttpClient


def _get_aws_region(
env_config: _AI21EnvConfig,
session: Optional[boto3.Session] = None,
region: Optional[str] = None,
) -> str:
if session is not None:
return session.region_name

return region or env_config.aws_region


class AI21BedrockClient:
def __init__(
self,
Expand All @@ -22,7 +33,14 @@ def __init__(
session: Optional[boto3.Session] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
region = get_aws_region(env_config=env_config, session=session, region=region)
if model_id is not None:
warnings.warn(
"Please consider using the 'model_id' parameter in the "
"'create' method calls instead of the constructor.",
DeprecationWarning,
)

region = _get_aws_region(env_config=env_config, session=session, region=region)

self._http_client = http_client or HttpClient(
timeout_sec=timeout_sec,
Expand All @@ -48,7 +66,14 @@ def __init__(
session: Optional[boto3.Session] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
region = get_aws_region(env_config=env_config, session=session, region=region)
if model_id is not None:
warnings.warn(
"Please consider using the 'model_id' parameter in the "
"'create' method calls instead of the constructor.",
DeprecationWarning,
)

region = _get_aws_region(env_config=env_config, session=session, region=region)

self._http_client = http_client or AsyncHttpClient(
timeout_sec=timeout_sec,
Expand Down
9 changes: 9 additions & 0 deletions ai21/clients/bedrock/resources/bedrock_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ai21.models import Penalty, CompletionsResponse
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given
from ai21.logger import logger


class BedrockCompletion(BedrockResource):
Expand All @@ -25,6 +26,10 @@ def create(
count_penalty: Penalty | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
stream_field = kwargs.get("stream", NOT_GIVEN)
if stream_field is not NOT_GIVEN:
logger.warning("Field stream is not supported. Ignoring it.")

body = remove_not_given(
{
"prompt": prompt,
Expand Down Expand Up @@ -67,6 +72,10 @@ async def create(
count_penalty: Penalty | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
stream_field = kwargs.get("stream", NOT_GIVEN)
if stream_field is not NOT_GIVEN:
logger.warning("Field stream is not supported. Ignoring it.")

body = remove_not_given(
{
"prompt": prompt,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from typing import List, Any

from ai21.logger import logger
from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource, AsyncBedrockResource
from ai21.models.chat import ChatMessage, ChatCompletionResponse
from ai21.types import NotGiven, NOT_GIVEN
Expand All @@ -19,6 +20,10 @@ def create(
n: int | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
stream_field = kwargs.get("stream", NOT_GIVEN)
if stream_field is not NOT_GIVEN:
logger.warning("Field stream is not supported. Ignoring it.")

body = remove_not_given(
{
"messages": [message.to_dict() for message in messages],
Expand Down Expand Up @@ -52,6 +57,10 @@ async def create(
n: int | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
stream_field = kwargs.get("stream", NOT_GIVEN)
if stream_field is not NOT_GIVEN:
logger.warning("Field stream is not supported. Ignoring it.")

body = remove_not_given(
{
"messages": [message.to_dict() for message in messages],
Expand Down

0 comments on commit 0e407e7

Please sign in to comment.