From 55f0aa603e19e052e4e33aec19f99610e609c76e Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 1 Jul 2024 11:33:39 +0300 Subject: [PATCH 1/3] fix: Added deprecation warning --- ai21/clients/bedrock/ai21_bedrock_client.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index 03eeb0be..77f353ba 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional, Dict, Any import boto3 @@ -32,6 +33,13 @@ def __init__( session: Optional[boto3.Session] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): + if model_id is not None: + warnings.warn( + "Please consider using the 'model_id' parameter in the " + "'create' method calls instead of the constructor.", + DeprecationWarning, + ) + region = _get_aws_region(env_config=env_config, session=session, region=region) self._http_client = http_client or HttpClient( From 2e2ab2515b58a0dd1f5e246d96822236c3740905 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 1 Jul 2024 11:35:25 +0300 Subject: [PATCH 2/3] fix: Added deprecation warning to async --- ai21/clients/bedrock/ai21_bedrock_client.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index 77f353ba..12cefab3 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -66,6 +66,13 @@ def __init__( session: Optional[boto3.Session] = None, env_config: _AI21EnvConfig = AI21EnvConfig, ): + if model_id is not None: + warnings.warn( + "Please consider using the 'model_id' parameter in the " + "'create' method calls instead of the constructor.", + DeprecationWarning, + ) + region = _get_aws_region(env_config=env_config, session=session, region=region) self._http_client = http_client or AsyncHttpClient( From e24fd139cebd15fc5a44b41e167dfda7b42ef9e7 Mon Sep 17 00:00:00 2001 From: Miri Bar Date: Mon, 1 Jul 2024 12:14:29 +0300 Subject: [PATCH 3/3] chore: add log for ignoring stream --- ai21/clients/bedrock/resources/bedrock_completion.py | 9 +++++++++ .../bedrock/resources/chat/bedrock_chat_completions.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/ai21/clients/bedrock/resources/bedrock_completion.py b/ai21/clients/bedrock/resources/bedrock_completion.py index 9bca6623..00ec8334 100644 --- a/ai21/clients/bedrock/resources/bedrock_completion.py +++ b/ai21/clients/bedrock/resources/bedrock_completion.py @@ -6,6 +6,7 @@ from ai21.models import Penalty, CompletionsResponse from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given +from ai21.logger import logger class BedrockCompletion(BedrockResource): @@ -25,6 +26,10 @@ def create( count_penalty: Penalty | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: + stream_field = kwargs.get("stream", NOT_GIVEN) + if stream_field is not NOT_GIVEN: + logger.warning("Field stream is not supported. Ignoring it.") + body = remove_not_given( { "prompt": prompt, @@ -67,6 +72,10 @@ async def create( count_penalty: Penalty | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: + stream_field = kwargs.get("stream", NOT_GIVEN) + if stream_field is not NOT_GIVEN: + logger.warning("Field stream is not supported. Ignoring it.") + body = remove_not_given( { "prompt": prompt, diff --git a/ai21/clients/bedrock/resources/chat/bedrock_chat_completions.py b/ai21/clients/bedrock/resources/chat/bedrock_chat_completions.py index aca24274..dbe83bc3 100644 --- a/ai21/clients/bedrock/resources/chat/bedrock_chat_completions.py +++ b/ai21/clients/bedrock/resources/chat/bedrock_chat_completions.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import List, Any +from ai21.logger import logger from ai21.clients.bedrock.resources.bedrock_resource import BedrockResource, AsyncBedrockResource from ai21.models.chat import ChatMessage, ChatCompletionResponse from ai21.types import NotGiven, NOT_GIVEN @@ -19,6 +20,10 @@ def create( n: int | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> ChatCompletionResponse: + stream_field = kwargs.get("stream", NOT_GIVEN) + if stream_field is not NOT_GIVEN: + logger.warning("Field stream is not supported. Ignoring it.") + body = remove_not_given( { "messages": [message.to_dict() for message in messages], @@ -52,6 +57,10 @@ async def create( n: int | NotGiven = NOT_GIVEN, **kwargs: Any, ) -> ChatCompletionResponse: + stream_field = kwargs.get("stream", NOT_GIVEN) + if stream_field is not NOT_GIVEN: + logger.warning("Field stream is not supported. Ignoring it.") + body = remove_not_given( { "messages": [message.to_dict() for message in messages],