From cf1a576a3deb4e7f2dbe7d36688d916ef1dca723 Mon Sep 17 00:00:00 2001 From: miri-bar <160584887+miri-bar@users.noreply.github.com> Date: Thu, 4 Jul 2024 10:24:11 +0300 Subject: [PATCH] Add sagemaker async support (#155) * refactor: migrate from boto3 to custom http client * refactor: create an aws http client, and switch bedrock client to use it * test: rename test, add async tests * feat: add async client, remove bedrock_session class * docs: Azure README (#139) * feat: Add async tokenizer, add detokenize method (#144) * feat: add detokenize method, add async tokenizer * chore: update pyproject and poetry.lock * fix: fix tokenizer name in examples and readme, add example * test: adjust unittest + add unittest for async * chore: cache -> lru_cache to support python 3.8 * test: fix test_imports test * chore: add env_config to bedrock client to avoid breaking changes * refactor: sagemaker client, boto->aws http client * refactor: export aws auth logic to new class * refactor: remove aws_http_client, use http_client instead, add aws auth test * test: fix tests * refactor: remove aws_http_client * chore(deps-dev): bump authlib from 1.3.0 to 1.3.1 (#131) Bumps [authlib](https://github.com/lepture/authlib) from 1.3.0 to 1.3.1. - [Release notes](https://github.com/lepture/authlib/releases) - [Changelog](https://github.com/lepture/authlib/blob/master/docs/changelog.rst) - [Commits](https://github.com/lepture/authlib/compare/v1.3.0...v1.3.1) --- updated-dependencies: - dependency-name: authlib dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> * chore(deps): bump pypa/gh-action-pypi-publish from 1.8.14 to 1.9.0 (#138) Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.8.14 to 1.9.0. - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/81e9d935c883d0b210363ab89cf05f3894778450...ec4db0b4ddc65acdf4bff5fa45ac92d78b56bdf0) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> * chore: rebase code * refactor: chat + chat completions - migrate to new client * refactor: cr comments * chore: add async to new bedrock components * refactor: rename aws folder * chore: refactor to use http client * chore: fix typo on file name bedrock_chat_completions * fix: fix errors * chore: fix typo * chore: add async to sm resources * test: fix imports test * fix: Added deprecation warning * fix: Added deprecation warning to async * chore: add log for ignoring stream * chore: fix lint * refactor: export get_aws_region, add async sm to readme, add async examples * test: add async test files to test_sagemaker.py * refactor: remove get_aws_region func --------- Signed-off-by: dependabot[bot] Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- README.md | 18 ++++ ai21/__init__.py | 10 ++ ai21/clients/bedrock/ai21_bedrock_client.py | 19 +--- .../sagemaker/ai21_sagemaker_client.py | 92 +++++++++++++++---- .../sagemaker/resources/sagemaker_answer.py | 19 +++- .../resources/sagemaker_completion.py | 64 ++++++++++++- .../sagemaker/resources/sagemaker_gec.py | 15 ++- .../resources/sagemaker_paraphrase.py | 27 +++++- .../sagemaker/resources/sagemaker_resource.py | 89 ++++++++++++++++-- .../resources/sagemaker_summarize.py | 28 +++++- ai21/errors.py | 5 + examples/sagemaker/async_answer.py | 21 +++++ examples/sagemaker/async_completion.py | 47 ++++++++++ examples/sagemaker/async_gec.py | 14 +++ examples/sagemaker/async_paraphrase.py | 17 ++++ examples/sagemaker/async_summarization.py | 23 +++++ .../clients/test_sagemaker.py | 10 ++ tests/unittests/test_imports.py | 1 + 18 files changed, 465 insertions(+), 54 deletions(-) create mode 100644 examples/sagemaker/async_answer.py create mode 100644 examples/sagemaker/async_completion.py create mode 100644 examples/sagemaker/async_gec.py create mode 100644 examples/sagemaker/async_paraphrase.py create mode 100644 examples/sagemaker/async_summarization.py diff --git a/README.md b/README.md index c835bbae..00d51c93 100644 --- a/README.md +++ b/README.md @@ -589,6 +589,24 @@ response = client.summarize.create( print(response.summary) ``` +#### Async + +```python +import asyncio +from ai21 import AsyncAI21SageMakerClient + +client = AsyncAI21SageMakerClient(endpoint_name="j2-endpoint-name") + +async def main(): + response = await client.summarize.create( + source="Text to summarize", + source_type="TEXT", + ) + print(response.summary) + +asyncio.run(main()) +``` + ### With Boto3 Session ```python diff --git a/ai21/__init__.py b/ai21/__init__.py index ea26e843..380a0409 100644 --- a/ai21/__init__.py +++ b/ai21/__init__.py @@ -45,6 +45,12 @@ def _import_async_bedrock_client(): return AsyncAI21BedrockClient +def _import_async_sagemaker_client(): + from ai21.clients.sagemaker.ai21_sagemaker_client import AsyncAI21SageMakerClient + + return AsyncAI21SageMakerClient + + def __getattr__(name: str) -> Any: try: if name == "AI21BedrockClient": @@ -58,6 +64,9 @@ def __getattr__(name: str) -> Any: if name == "AsyncAI21BedrockClient": return _import_async_bedrock_client() + + if name == "AsyncAI21SageMakerClient": + return _import_async_sagemaker_client() except ImportError as e: raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e @@ -79,4 +88,5 @@ def __getattr__(name: str) -> Any: "AI21AzureClient", "AsyncAI21AzureClient", "AsyncAI21BedrockClient", + "AsyncAI21SageMakerClient", ] diff --git a/ai21/clients/bedrock/ai21_bedrock_client.py b/ai21/clients/bedrock/ai21_bedrock_client.py index 12cefab3..97aaca5e 100644 --- a/ai21/clients/bedrock/ai21_bedrock_client.py +++ b/ai21/clients/bedrock/ai21_bedrock_client.py @@ -3,24 +3,13 @@ import boto3 -from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig +from ai21.ai21_env_config import AI21EnvConfig from ai21.clients.bedrock.resources.chat.bedrock_chat import BedrockChat, AsyncBedrockChat from ai21.clients.bedrock.resources.bedrock_completion import BedrockCompletion, AsyncBedrockCompletion from ai21.http_client.http_client import HttpClient from ai21.http_client.async_http_client import AsyncHttpClient -def _get_aws_region( - env_config: _AI21EnvConfig, - session: Optional[boto3.Session] = None, - region: Optional[str] = None, -) -> str: - if session is not None: - return session.region_name - - return region or env_config.aws_region - - class AI21BedrockClient: def __init__( self, @@ -31,7 +20,6 @@ def __init__( num_retries: Optional[int] = None, http_client: Optional[HttpClient] = None, session: Optional[boto3.Session] = None, - env_config: _AI21EnvConfig = AI21EnvConfig, ): if model_id is not None: warnings.warn( @@ -40,7 +28,7 @@ def __init__( DeprecationWarning, ) - region = _get_aws_region(env_config=env_config, session=session, region=region) + region = region or AI21EnvConfig.aws_region self._http_client = http_client or HttpClient( timeout_sec=timeout_sec, @@ -64,7 +52,6 @@ def __init__( num_retries: Optional[int] = None, http_client: Optional[AsyncHttpClient] = None, session: Optional[boto3.Session] = None, - env_config: _AI21EnvConfig = AI21EnvConfig, ): if model_id is not None: warnings.warn( @@ -73,7 +60,7 @@ def __init__( DeprecationWarning, ) - region = _get_aws_region(env_config=env_config, session=session, region=region) + region = region or AI21EnvConfig.aws_region self._http_client = http_client or AsyncHttpClient( timeout_sec=timeout_sec, diff --git a/ai21/clients/sagemaker/ai21_sagemaker_client.py b/ai21/clients/sagemaker/ai21_sagemaker_client.py index 57de75ad..701a9918 100644 --- a/ai21/clients/sagemaker/ai21_sagemaker_client.py +++ b/ai21/clients/sagemaker/ai21_sagemaker_client.py @@ -1,14 +1,16 @@ -from typing import Optional +from typing import Optional, Dict, Any import boto3 from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig -from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer -from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion -from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC -from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase -from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize -from ai21.clients.sagemaker.sagemaker_session import SageMakerSession + +from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer, AsyncSageMakerAnswer +from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion, AsyncSageMakerCompletion +from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC, AsyncSageMakerGEC +from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase, AsyncSageMakerParaphrase +from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize, AsyncSageMakerSummarize +from ai21.http_client.async_http_client import AsyncHttpClient +from ai21.http_client.http_client import HttpClient class AI21SageMakerClient: @@ -18,22 +20,80 @@ class AI21SageMakerClient: :param session: An optional boto3 session to use for the client. """ + def __init__( + self, + endpoint_name: str, + region: Optional[str] = None, + session: Optional["boto3.Session"] = None, + headers: Optional[Dict[str, Any]] = None, + timeout_sec: Optional[float] = None, + num_retries: Optional[int] = None, + http_client: Optional[HttpClient] = None, + **kwargs, + ): + region = region or AI21EnvConfig.aws_region + self._http_client = http_client or HttpClient( + headers=headers, + timeout_sec=timeout_sec, + num_retries=num_retries, + ) + + self.completion = SageMakerCompletion( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.paraphrase = SageMakerParaphrase( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.answer = SageMakerAnswer( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.gec = SageMakerGEC( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.summarize = SageMakerSummarize( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + + +class AsyncAI21SageMakerClient: + """ + :param endpoint_name: The name of the endpoint to use for the client. + :param region: The AWS region of the endpoint. + :param session: An optional boto3 session to use for the client. + """ + def __init__( self, endpoint_name: str, region: Optional[str] = None, session: Optional["boto3.Session"] = None, env_config: _AI21EnvConfig = AI21EnvConfig, + headers: Optional[Dict[str, Any]] = None, + timeout_sec: Optional[float] = None, + num_retries: Optional[int] = None, + http_client: Optional[AsyncHttpClient] = None, **kwargs, ): + region = region or AI21EnvConfig.aws_region + + self._http_client = http_client or AsyncHttpClient( + headers=headers, + timeout_sec=timeout_sec, + num_retries=num_retries, + ) - self._env_config = env_config - _session = () - self._session = SageMakerSession( - session=session, region=region or self._env_config.aws_region, endpoint_name=endpoint_name + self.completion = AsyncSageMakerCompletion( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.paraphrase = AsyncSageMakerParaphrase( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.answer = AsyncSageMakerAnswer( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.gec = AsyncSageMakerGEC( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session + ) + self.summarize = AsyncSageMakerSummarize( + endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session ) - self.completion = SageMakerCompletion(self._session) - self.paraphrase = SageMakerParaphrase(self._session) - self.answer = SageMakerAnswer(self._session) - self.gec = SageMakerGEC(self._session) - self.summarize = SageMakerSummarize(self._session) diff --git a/ai21/clients/sagemaker/resources/sagemaker_answer.py b/ai21/clients/sagemaker/resources/sagemaker_answer.py index d4a6ceb5..586e1ce8 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_answer.py +++ b/ai21/clients/sagemaker/resources/sagemaker_answer.py @@ -1,5 +1,5 @@ from ai21.clients.common.answer_base import Answer -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource from ai21.models import AnswerResponse @@ -11,6 +11,19 @@ def create( **kwargs, ) -> AnswerResponse: body = self._create_body(context=context, question=question) - response = self._invoke(body) + response = self._post(body) - return self._json_to_response(response) + return self._json_to_response(response.json()) + + +class AsyncSageMakerAnswer(AsyncSageMakerResource, Answer): + async def create( + self, + context: str, + question: str, + **kwargs, + ) -> AnswerResponse: + body = self._create_body(context=context, question=question) + response = await self._post(body) + + return self._json_to_response(response.json()) diff --git a/ai21/clients/sagemaker/resources/sagemaker_completion.py b/ai21/clients/sagemaker/resources/sagemaker_completion.py index 97a9682a..d2082d21 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_completion.py +++ b/ai21/clients/sagemaker/resources/sagemaker_completion.py @@ -2,7 +2,7 @@ from typing import List, Dict -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource from ai21.models import Penalty, CompletionsResponse from ai21.types import NotGiven, NOT_GIVEN from ai21.utils.typing import remove_not_given @@ -61,6 +61,64 @@ def create( } ) - raw_response = self._invoke(body) + raw_response = self._post(body=body) - return CompletionsResponse.from_dict(raw_response) + return CompletionsResponse.from_dict(raw_response.json()) + + +class AsyncSageMakerCompletion(AsyncSageMakerResource): + async def create( + self, + prompt: str, + *, + max_tokens: int | NotGiven = NOT_GIVEN, + num_results: int | NotGiven = NOT_GIVEN, + min_tokens: int | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + top_k_return: int | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + frequency_penalty: Penalty | NotGiven = NOT_GIVEN, + presence_penalty: Penalty | NotGiven = NOT_GIVEN, + count_penalty: Penalty | NotGiven = NOT_GIVEN, + logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, + **kwargs, + ) -> CompletionsResponse: + """ + :param prompt: Text for model to complete + :param max_tokens: The maximum number of tokens to generate per result + :param num_results: Number of completions to sample and return. + :param min_tokens: The minimum number of tokens to generate per result. + :param temperature: A value controlling the "creativity" of the model's responses. + :param top_p: A value controlling the diversity of the model's responses. + :param top_k_return: The number of top-scoring tokens to consider for each generation step. + :param stop_sequences: Stops decoding if any of the strings is generated + :param frequency_penalty: A penalty applied to tokens that are frequently generated. + :param presence_penalty: A penalty applied to tokens that are already present in the prompt. + :param count_penalty: A penalty applied to tokens based on their frequency in the generated responses + :param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text + representations of the tokens and the floats are the biases themselves. A positive bias increases generation + probability for a given token and a negative bias decreases it. + :param kwargs: + :return: + """ + body = remove_not_given( + { + "prompt": prompt, + "maxTokens": max_tokens, + "numResults": num_results, + "minTokens": min_tokens, + "temperature": temperature, + "topP": top_p, + "topKReturn": top_k_return, + "stopSequences": stop_sequences or [], + "frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty, + "presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty, + "countPenalty": count_penalty.to_dict() if count_penalty else count_penalty, + "logitBias": logit_bias, + } + ) + + raw_response = await self._post(body=body) + + return CompletionsResponse.from_dict(raw_response.json()) diff --git a/ai21/clients/sagemaker/resources/sagemaker_gec.py b/ai21/clients/sagemaker/resources/sagemaker_gec.py index 138ac0bf..e493eff4 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_gec.py +++ b/ai21/clients/sagemaker/resources/sagemaker_gec.py @@ -1,5 +1,5 @@ from ai21.clients.common.gec_base import GEC -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource from ai21.models import GECResponse @@ -7,6 +7,15 @@ class SageMakerGEC(SageMakerResource, GEC): def create(self, text: str, **kwargs) -> GECResponse: body = self._create_body(text=text) - response = self._invoke(body) + response = self._post(body) - return self._json_to_response(response) + return self._json_to_response(response.json()) + + +class AsyncSageMakerGEC(AsyncSageMakerResource, GEC): + async def create(self, text: str, **kwargs) -> GECResponse: + body = self._create_body(text=text) + + response = await self._post(body) + + return self._json_to_response(response.json()) diff --git a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py index 40251d4c..b277f40f 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py +++ b/ai21/clients/sagemaker/resources/sagemaker_paraphrase.py @@ -1,7 +1,7 @@ from typing import Optional from ai21.clients.common.paraphrase_base import Paraphrase -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource from ai21.models import ParaphraseStyleType, ParaphraseResponse @@ -21,6 +21,27 @@ def create( start_index=start_index, end_index=end_index, ) - response = self._invoke(body=body) + response = self._post(body=body) - return self._json_to_response(response) + return self._json_to_response(response.json()) + + +class AsyncSageMakerParaphrase(AsyncSageMakerResource, Paraphrase): + async def create( + self, + text: str, + *, + style: Optional[ParaphraseStyleType] = None, + start_index: Optional[int] = 0, + end_index: Optional[int] = None, + **kwargs, + ) -> ParaphraseResponse: + body = self._create_body( + text=text, + style=style, + start_index=start_index, + end_index=end_index, + ) + response = await self._post(body=body) + + return self._json_to_response(response.json()) diff --git a/ai21/clients/sagemaker/resources/sagemaker_resource.py b/ai21/clients/sagemaker/resources/sagemaker_resource.py index 976c9a70..b3ab6b00 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_resource.py +++ b/ai21/clients/sagemaker/resources/sagemaker_resource.py @@ -2,16 +2,91 @@ import json from abc import ABC -from typing import Any, Dict +from typing import Any, Dict, Optional -from ai21.clients.sagemaker.sagemaker_session import SageMakerSession +import boto3 +import httpx + +from ai21 import AI21APIError +from ai21.clients.aws.aws_authorization import AWSAuthorization +from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException, InternalDependencyException +from ai21.http_client.async_http_client import AsyncHttpClient +from ai21.http_client.http_client import HttpClient + + +def _handle_sagemaker_error(aws_error: AI21APIError) -> None: + status_code = aws_error.status_code + if status_code == 403: + raise AccessDenied(details=aws_error.details) + + if status_code == 404: + raise NotFound(details=aws_error.details) + + if status_code == 408: + raise APITimeoutError(details=aws_error.details) + + if status_code == 424: + raise ModelErrorException(details=aws_error.details) + + if status_code == 530: + raise InternalDependencyException(details=aws_error.details) + + raise aws_error class SageMakerResource(ABC): - def __init__(self, sagemaker_session: SageMakerSession): - self._sagemaker_session = sagemaker_session + def __init__( + self, + endpoint_name: str, + region: str, + client: HttpClient, + aws_session: Optional[boto3.Session] = None, + ): + self._client = client + self._aws_session = aws_session or boto3.Session(region_name=region) + self._url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations" + self._aws_auth = AWSAuthorization(aws_session=self._aws_session) - def _invoke(self, body: Dict[str, Any]) -> Dict[str, Any]: - return self._sagemaker_session.invoke_endpoint( - input_json=json.dumps(body), + def _post( + self, + body: Dict[str, Any], + ) -> httpx.Response: + auth_headers = self._aws_auth.get_auth_headers( + service_name="sagemaker", url=self._url, method="POST", data=json.dumps(body) ) + + try: + return self._client.execute_http_request( + url=self._url, body=body, method="POST", extra_headers=auth_headers + ) + except AI21APIError as aws_error: + _handle_sagemaker_error(aws_error) + + +class AsyncSageMakerResource(ABC): + def __init__( + self, + endpoint_name: str, + region: str, + client: AsyncHttpClient, + aws_session: Optional[boto3.Session] = None, + ): + self._client = client + self._aws_session = aws_session or boto3.Session(region_name=region) + self._url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations" + self._aws_auth = AWSAuthorization(aws_session=self._aws_session) + + async def _post( + self, + body: Dict[str, Any], + ) -> httpx.Response: + auth_headers = self._aws_auth.get_auth_headers( + service_name="sagemaker", url=self._url, method="POST", data=json.dumps(body) + ) + + try: + return await self._client.execute_http_request( + url=self._url, body=body, method="POST", extra_headers=auth_headers + ) + except AI21APIError as aws_error: + _handle_sagemaker_error(aws_error) diff --git a/ai21/clients/sagemaker/resources/sagemaker_summarize.py b/ai21/clients/sagemaker/resources/sagemaker_summarize.py index b8f52e0b..a3336292 100644 --- a/ai21/clients/sagemaker/resources/sagemaker_summarize.py +++ b/ai21/clients/sagemaker/resources/sagemaker_summarize.py @@ -3,7 +3,7 @@ from typing import Optional from ai21.clients.common.summarize_base import Summarize -from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource +from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource from ai21.models import SummarizeResponse, SummaryMethod @@ -24,6 +24,28 @@ def create( summary_method=summary_method, ) - response = self._invoke(body) + response = self._post(body) - return self._json_to_response(response) + return self._json_to_response(response.json()) + + +class AsyncSageMakerSummarize(AsyncSageMakerResource, Summarize): + async def create( + self, + source: str, + source_type: str, + *, + focus: Optional[str] = None, + summary_method: Optional[SummaryMethod] = None, + **kwargs, + ) -> SummarizeResponse: + body = self._create_body( + source=source, + source_type=source_type, + focus=focus, + summary_method=summary_method, + ) + + response = await self._post(body) + + return self._json_to_response(response.json()) diff --git a/ai21/errors.py b/ai21/errors.py index 20e0606e..83091ed3 100644 --- a/ai21/errors.py +++ b/ai21/errors.py @@ -98,3 +98,8 @@ class StreamingDecodeError(AI21Error): def __init__(self, chunk: str): message = f"Failed to decode chunk: {chunk} in stream. Please check the stream format" super().__init__(message) + + +class InternalDependencyException(AI21APIError): + def __init__(self, details: Optional[str] = None): + super().__init__(530, details) diff --git a/examples/sagemaker/async_answer.py b/examples/sagemaker/async_answer.py new file mode 100644 index 00000000..ceeee832 --- /dev/null +++ b/examples/sagemaker/async_answer.py @@ -0,0 +1,21 @@ +import asyncio + +from ai21 import AsyncAI21SageMakerClient + +client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") + + +async def main(): + response = await client.answer.create( + context="Holland is a geographical region[2] and former province on the western coast" + " of the Netherlands.[2] From the 10th to the 16th century, Holland proper was a unified political region " + "within the Holy Roman Empire as a county ruled by the counts of Holland. By the 17th century, the province " + "of Holland had risen to become a maritime and economic power, dominating the other provinces of the newly " + "independent Dutch Republic.", + question="When did Holland become an economic power?", + ) + + print(response.answer) + + +asyncio.run(main()) diff --git a/examples/sagemaker/async_completion.py b/examples/sagemaker/async_completion.py new file mode 100644 index 00000000..a9d445b5 --- /dev/null +++ b/examples/sagemaker/async_completion.py @@ -0,0 +1,47 @@ +import asyncio + +from ai21 import AsyncAI21SageMakerClient + +prompt = ( + "The following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- " + "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:" + " Male.\n\nConversation:\nUser: Hi, had a question\nMax: " + "Hi there, happy to help!\nUser: Is there no way to return a product?" + " I got your blue T-Shirt size small but it doesn't fit.\n" + "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n" + "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n" + "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation" + " associate called Max. Max is very kind and keen to help. The following are important points about" + " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n" + 'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" ' + "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me" + " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: anc@gmail.com\n" + "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between" + " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help." + " The following are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it" + " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working" + " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n" + "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n" + "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n" + "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an" + " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following" + " are important points about the business policies:\n- Delivery takes up to 5 days\n" + "- There is no return option\n\nUser gender: Female.\n\nConversation:\n" + "User: Hi, I have a question for you" +) + +client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") + + +async def main(): + response = await client.completion.create(prompt=prompt, max_tokens=2) + + print(response.completions[0].data.text) + + print(response.prompt.tokens[0]["textRange"]["start"]) + + +asyncio.run(main()) diff --git a/examples/sagemaker/async_gec.py b/examples/sagemaker/async_gec.py new file mode 100644 index 00000000..526498f3 --- /dev/null +++ b/examples/sagemaker/async_gec.py @@ -0,0 +1,14 @@ +import asyncio + +from ai21 import AsyncAI21SageMakerClient + +client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") + + +async def main(): + response = client.gec.create(text="roc and rolle") + + print(response.corrections[0].suggestion) + + +asyncio.run(main()) diff --git a/examples/sagemaker/async_paraphrase.py b/examples/sagemaker/async_paraphrase.py new file mode 100644 index 00000000..4bff2c49 --- /dev/null +++ b/examples/sagemaker/async_paraphrase.py @@ -0,0 +1,17 @@ +import asyncio + +from ai21 import AsyncAI21SageMakerClient + +client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") + + +async def main(): + response = await client.paraphrase.create( + text="What's the difference between Scottish Fold and British?", + style="formal", + ) + + print(response.suggestions[0].text) + + +asyncio.run(main()) diff --git a/examples/sagemaker/async_summarization.py b/examples/sagemaker/async_summarization.py new file mode 100644 index 00000000..44cb4bff --- /dev/null +++ b/examples/sagemaker/async_summarization.py @@ -0,0 +1,23 @@ +import asyncio + +from ai21 import AsyncAI21SageMakerClient + +client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name") + + +async def main(): + response = await client.summarize.create( + source="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2]" + " From the 10th to the 16th century, " + "Holland proper was a unified political region within the Holy Roman Empire as a" + " county ruled by the counts of Holland. By the 17th century, " + "the province of Holland had risen to become a maritime and economic power," + " dominating the other provinces of the newly independent Dutch " + "Republic.", + source_type="TEXT", + ) + + print(response.summary) + + +asyncio.run(main()) diff --git a/tests/integration_tests/clients/test_sagemaker.py b/tests/integration_tests/clients/test_sagemaker.py index 263760ef..c1d8f3d8 100644 --- a/tests/integration_tests/clients/test_sagemaker.py +++ b/tests/integration_tests/clients/test_sagemaker.py @@ -13,17 +13,27 @@ argnames=["test_file_name"], argvalues=[ ("answer.py",), + ("async_answer.py",), ("completion.py",), + ("async_completion.py",), ("gec.py",), + ("async_gec.py",), ("paraphrase.py",), + ("async_paraphrase.py",), ("summarization.py",), + ("async_summarization.py",), ], ids=[ "when_answer__should_return_ok", + "when_async_answer__should_return_ok", "when_completion__should_return_ok", + "when_async_completion__should_return_ok", "when_gec__should_return_ok", + "when_async_gec__should_return_ok", "when_paraphrase__should_return_ok", + "when_async_paraphrase__should_return_ok", "when_summarization__should_return_ok", + "when_async_summarization__should_return_ok", ], ) def test_sagemaker(test_file_name: str): diff --git a/tests/unittests/test_imports.py b/tests/unittests/test_imports.py index 3bed0630..4ea6b7bf 100644 --- a/tests/unittests/test_imports.py +++ b/tests/unittests/test_imports.py @@ -20,6 +20,7 @@ "SageMaker", "TooManyRequestsError", "AsyncAI21BedrockClient", + "AsyncAI21SageMakerClient", ]