Skip to content

Commit

Permalink
chore: add async to sm resources
Browse files Browse the repository at this point in the history
  • Loading branch information
miri-bar committed Jun 27, 2024
1 parent 6107f45 commit 26fe85b
Show file tree
Hide file tree
Showing 10 changed files with 300 additions and 20 deletions.
10 changes: 10 additions & 0 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def _import_async_bedrock_client():
return AsyncAI21BedrockClient


def _import_async_sagemaker_client():
from ai21.clients.sagemaker.ai21_sagemaker_client import AsyncAI21SageMakerClient

return AsyncAI21SageMakerClient


def __getattr__(name: str) -> Any:
try:
if name == "AI21BedrockClient":
Expand All @@ -58,6 +64,9 @@ def __getattr__(name: str) -> Any:

if name == "AsyncAI21BedrockClient":
return _import_async_bedrock_client()

if name == "AsyncAI21SageMakerClient":
return _import_async_sagemaker_client()
except ImportError as e:
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e

Expand All @@ -79,4 +88,5 @@ def __getattr__(name: str) -> Any:
"AI21AzureClient",
"AsyncAI21AzureClient",
"AsyncAI21BedrockClient",
"AsyncAI21SageMakerClient",
]
75 changes: 65 additions & 10 deletions ai21/clients/sagemaker/ai21_sagemaker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig

from ai21.clients.aws.utils import get_aws_region
from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer
from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion
from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC
from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase
from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize
from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer, AsyncSageMakerAnswer
from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion, AsyncSageMakerCompletion
from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC, AsyncSageMakerGEC
from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase, AsyncSageMakerParaphrase
from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize, AsyncSageMakerSummarize
from ai21.http_client.async_http_client import AsyncHttpClient
from ai21.http_client.http_client import HttpClient


Expand Down Expand Up @@ -38,8 +39,62 @@ def __init__(
timeout_sec=timeout_sec,
num_retries=num_retries,
)
self.completion = SageMakerCompletion(endpoint_name=endpoint_name, region=region, client=self._http_client)
self.paraphrase = SageMakerParaphrase(endpoint_name=endpoint_name, region=region, client=self._http_client)
self.answer = SageMakerAnswer(endpoint_name=endpoint_name, region=region, client=self._http_client)
self.gec = SageMakerGEC(endpoint_name=endpoint_name, region=region, client=self._http_client)
self.summarize = SageMakerSummarize(endpoint_name=endpoint_name, region=region, client=self._http_client)

self.completion = SageMakerCompletion(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.paraphrase = SageMakerParaphrase(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.answer = SageMakerAnswer(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.gec = SageMakerGEC(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.summarize = SageMakerSummarize(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)


class AsyncAI21SageMakerClient:
"""
:param endpoint_name: The name of the endpoint to use for the client.
:param region: The AWS region of the endpoint.
:param session: An optional boto3 session to use for the client.
"""

def __init__(
self,
endpoint_name: str,
region: Optional[str] = None,
session: Optional["boto3.Session"] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
headers: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
http_client: Optional[AsyncHttpClient] = None,
**kwargs,
):
region = get_aws_region(env_config=env_config, session=session, region=region)
self._http_client = http_client or AsyncHttpClient(
headers=headers,
timeout_sec=timeout_sec,
num_retries=num_retries,
)

self.completion = AsyncSageMakerCompletion(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.paraphrase = AsyncSageMakerParaphrase(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.answer = AsyncSageMakerAnswer(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.gec = AsyncSageMakerGEC(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.summarize = AsyncSageMakerSummarize(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
15 changes: 14 additions & 1 deletion ai21/clients/sagemaker/resources/sagemaker_answer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ai21.clients.common.answer_base import Answer
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import AnswerResponse


Expand All @@ -14,3 +14,16 @@ def create(
response = self._post(body)

return self._json_to_response(response.json())


class AsyncSageMakerAnswer(AsyncSageMakerResource, Answer):
async def create(
self,
context: str,
question: str,
**kwargs,
) -> AnswerResponse:
body = self._create_body(context=context, question=question)
response = await self._post(body)

return self._json_to_response(response.json())
60 changes: 59 additions & 1 deletion ai21/clients/sagemaker/resources/sagemaker_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import List, Dict

from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import Penalty, CompletionsResponse
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given
Expand Down Expand Up @@ -64,3 +64,61 @@ def create(
raw_response = self._post(body=body)

return CompletionsResponse.from_dict(raw_response.json())


class AsyncSageMakerCompletion(AsyncSageMakerResource):
async def create(
self,
prompt: str,
*,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
top_k_return: int | NotGiven = NOT_GIVEN,
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
"""
:param prompt: Text for model to complete
:param max_tokens: The maximum number of tokens to generate per result
:param num_results: Number of completions to sample and return.
:param min_tokens: The minimum number of tokens to generate per result.
:param temperature: A value controlling the "creativity" of the model's responses.
:param top_p: A value controlling the diversity of the model's responses.
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
:param stop_sequences: Stops decoding if any of the strings is generated
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text
representations of the tokens and the floats are the biases themselves. A positive bias increases generation
probability for a given token and a negative bias decreases it.
:param kwargs:
:return:
"""
body = remove_not_given(
{
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
"logitBias": logit_bias,
}
)

raw_response = await self._post(body=body)

return CompletionsResponse.from_dict(raw_response.json())
11 changes: 10 additions & 1 deletion ai21/clients/sagemaker/resources/sagemaker_gec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ai21.clients.common.gec_base import GEC
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import GECResponse


Expand All @@ -10,3 +10,12 @@ def create(self, text: str, **kwargs) -> GECResponse:
response = self._post(body)

return self._json_to_response(response.json())


class AsyncSageMakerGEC(AsyncSageMakerResource, GEC):
async def create(self, text: str, **kwargs) -> GECResponse:
body = self._create_body(text=text)

response = await self._post(body)

return self._json_to_response(response.json())
23 changes: 22 additions & 1 deletion ai21/clients/sagemaker/resources/sagemaker_paraphrase.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from ai21.clients.common.paraphrase_base import Paraphrase
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import ParaphraseStyleType, ParaphraseResponse


Expand All @@ -24,3 +24,24 @@ def create(
response = self._post(body=body)

return self._json_to_response(response.json())


class AsyncSageMakerParaphrase(AsyncSageMakerResource, Paraphrase):
async def create(
self,
text: str,
*,
style: Optional[ParaphraseStyleType] = None,
start_index: Optional[int] = 0,
end_index: Optional[int] = None,
**kwargs,
) -> ParaphraseResponse:
body = self._create_body(
text=text,
style=style,
start_index=start_index,
end_index=end_index,
)
response = await self._post(body=body)

return self._json_to_response(response.json())
76 changes: 71 additions & 5 deletions ai21/clients/sagemaker/resources/sagemaker_resource.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,92 @@
from __future__ import annotations

import json
from abc import ABC
from typing import Any, Dict
from typing import Any, Dict, Optional

import boto3
import httpx

from ai21.clients.aws_http_client.aws_http_client import AWSHttpClient
from ai21 import AI21APIError
from ai21.clients.aws.aws_authorization import AWSAuthorization
from ai21.errors import AccessDenied, NotFound, APITimeoutError, ModelErrorException, InternalDependencyException
from ai21.http_client.async_http_client import AsyncHttpClient
from ai21.http_client.http_client import HttpClient


def _handle_sagemaker_error(aws_error: AI21APIError) -> None:
status_code = aws_error.status_code
if status_code == 403:
raise AccessDenied(details=aws_error.details)

if status_code == 404:
raise NotFound(details=aws_error.details)

if status_code == 408:
raise APITimeoutError(details=aws_error.details)

if status_code == 424:
raise ModelErrorException(details=aws_error.details)

if status_code == 530:
raise InternalDependencyException(details=aws_error.details)

raise aws_error


class SageMakerResource(ABC):
def __init__(
self,
endpoint_name: str,
region: str,
client: AWSHttpClient,
client: HttpClient,
aws_session: Optional[boto3.Session] = None,
):
self._client = client

self._aws_session = aws_session or boto3.Session(region_name=region)
self._url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations"
self._aws_auth = AWSAuthorization(aws_session=self._aws_session)

def _post(
self,
body: Dict[str, Any],
) -> httpx.Response:
return self._client.execute_http_request(url=self._url, body=body, method="POST", service_name="sagemaker")
auth_headers = self._aws_auth.get_auth_headers(
service_name="sagemaker", url=self._url, method="POST", data=json.dumps(body)
)

try:
return self._client.execute_http_request(
url=self._url, body=body, method="POST", extra_headers=auth_headers
)
except AI21APIError as aws_error:
_handle_sagemaker_error(aws_error)


class AsyncSageMakerResource(ABC):
def __init__(
self,
endpoint_name: str,
region: str,
client: AsyncHttpClient,
aws_session: Optional[boto3.Session] = None,
):
self._client = client
self._aws_session = aws_session or boto3.Session(region_name=region)
self._url = f"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{endpoint_name}/invocations"
self._aws_auth = AWSAuthorization(aws_session=self._aws_session)

async def _post(
self,
body: Dict[str, Any],
) -> httpx.Response:
auth_headers = self._aws_auth.get_auth_headers(
service_name="sagemaker", url=self._url, method="POST", data=json.dumps(body)
)

try:
return await self._client.execute_http_request(
url=self._url, body=body, method="POST", extra_headers=auth_headers
)
except AI21APIError as aws_error:
_handle_sagemaker_error(aws_error)
24 changes: 23 additions & 1 deletion ai21/clients/sagemaker/resources/sagemaker_summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

from ai21.clients.common.summarize_base import Summarize
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import SummarizeResponse, SummaryMethod


Expand All @@ -27,3 +27,25 @@ def create(
response = self._post(body)

return self._json_to_response(response.json())


class AsyncSageMakerSummarize(AsyncSageMakerResource, Summarize):
async def create(
self,
source: str,
source_type: str,
*,
focus: Optional[str] = None,
summary_method: Optional[SummaryMethod] = None,
**kwargs,
) -> SummarizeResponse:
body = self._create_body(
source=source,
source_type=source_type,
focus=focus,
summary_method=summary_method,
)

response = await self._post(body)

return self._json_to_response(response.json())
Loading

0 comments on commit 26fe85b

Please sign in to comment.