Skip to content

Commit

Permalink
Add sagemaker async support (#155)
Browse files Browse the repository at this point in the history
* refactor: migrate from boto3 to custom http client

* refactor: create an aws http client, and switch bedrock client to use it

* test: rename test, add async tests

* feat: add async client, remove bedrock_session class

* docs: Azure README (#139)

* feat: Add async tokenizer, add detokenize method (#144)

* feat: add detokenize method, add async tokenizer

* chore: update pyproject and poetry.lock

* fix: fix tokenizer name in examples and readme, add example

* test: adjust unittest + add unittest for async

* chore: cache -> lru_cache to support python 3.8

* test: fix test_imports test

* chore: add env_config to bedrock client to avoid breaking changes

* refactor: sagemaker client, boto->aws http client

* refactor: export aws auth logic to new class

* refactor: remove aws_http_client, use http_client instead, add aws auth test

* test: fix tests

* refactor: remove aws_http_client

* chore(deps-dev): bump authlib from 1.3.0 to 1.3.1 (#131)

Bumps [authlib](https://github.com/lepture/authlib) from 1.3.0 to 1.3.1.
- [Release notes](https://github.com/lepture/authlib/releases)
- [Changelog](https://github.com/lepture/authlib/blob/master/docs/changelog.rst)
- [Commits](lepture/authlib@v1.3.0...v1.3.1)

---
updated-dependencies:
- dependency-name: authlib
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Asaf Joseph Gardin <[email protected]>

* chore(deps): bump pypa/gh-action-pypi-publish from 1.8.14 to 1.9.0 (#138)

Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.8.14 to 1.9.0.
- [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases)
- [Commits](pypa/gh-action-pypi-publish@81e9d93...ec4db0b)

---
updated-dependencies:
- dependency-name: pypa/gh-action-pypi-publish
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Asaf Joseph Gardin <[email protected]>

* chore: rebase code

* refactor: chat + chat completions - migrate to new client

* refactor: cr comments

* chore: add async to new bedrock components

* refactor: rename aws folder

* chore: refactor to use http client

* chore: fix typo on file name bedrock_chat_completions

* fix: fix errors

* chore: fix typo

* chore: add async to sm resources

* test: fix imports test

* fix: Added deprecation warning

* fix: Added deprecation warning to async

* chore: add log for ignoring stream

* chore: fix lint

* refactor: export get_aws_region, add async sm to readme, add async examples

* test: add async test files to test_sagemaker.py

* refactor: remove get_aws_region func

---------

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: Asaf Joseph Gardin <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 4, 2024
1 parent f117c3f commit cf1a576
Show file tree
Hide file tree
Showing 18 changed files with 465 additions and 54 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,24 @@ response = client.summarize.create(
print(response.summary)
```

#### Async

```python
import asyncio
from ai21 import AsyncAI21SageMakerClient

client = AsyncAI21SageMakerClient(endpoint_name="j2-endpoint-name")

async def main():
response = await client.summarize.create(
source="Text to summarize",
source_type="TEXT",
)
print(response.summary)

asyncio.run(main())
```

### With Boto3 Session

```python
Expand Down
10 changes: 10 additions & 0 deletions ai21/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ def _import_async_bedrock_client():
return AsyncAI21BedrockClient


def _import_async_sagemaker_client():
from ai21.clients.sagemaker.ai21_sagemaker_client import AsyncAI21SageMakerClient

return AsyncAI21SageMakerClient


def __getattr__(name: str) -> Any:
try:
if name == "AI21BedrockClient":
Expand All @@ -58,6 +64,9 @@ def __getattr__(name: str) -> Any:

if name == "AsyncAI21BedrockClient":
return _import_async_bedrock_client()

if name == "AsyncAI21SageMakerClient":
return _import_async_sagemaker_client()
except ImportError as e:
raise ImportError(f'Please install "ai21[AWS]" in order to use {name}') from e

Expand All @@ -79,4 +88,5 @@ def __getattr__(name: str) -> Any:
"AI21AzureClient",
"AsyncAI21AzureClient",
"AsyncAI21BedrockClient",
"AsyncAI21SageMakerClient",
]
19 changes: 3 additions & 16 deletions ai21/clients/bedrock/ai21_bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,13 @@

import boto3

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.ai21_env_config import AI21EnvConfig
from ai21.clients.bedrock.resources.chat.bedrock_chat import BedrockChat, AsyncBedrockChat
from ai21.clients.bedrock.resources.bedrock_completion import BedrockCompletion, AsyncBedrockCompletion
from ai21.http_client.http_client import HttpClient
from ai21.http_client.async_http_client import AsyncHttpClient


def _get_aws_region(
env_config: _AI21EnvConfig,
session: Optional[boto3.Session] = None,
region: Optional[str] = None,
) -> str:
if session is not None:
return session.region_name

return region or env_config.aws_region


class AI21BedrockClient:
def __init__(
self,
Expand All @@ -31,7 +20,6 @@ def __init__(
num_retries: Optional[int] = None,
http_client: Optional[HttpClient] = None,
session: Optional[boto3.Session] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
if model_id is not None:
warnings.warn(
Expand All @@ -40,7 +28,7 @@ def __init__(
DeprecationWarning,
)

region = _get_aws_region(env_config=env_config, session=session, region=region)
region = region or AI21EnvConfig.aws_region

self._http_client = http_client or HttpClient(
timeout_sec=timeout_sec,
Expand All @@ -64,7 +52,6 @@ def __init__(
num_retries: Optional[int] = None,
http_client: Optional[AsyncHttpClient] = None,
session: Optional[boto3.Session] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
):
if model_id is not None:
warnings.warn(
Expand All @@ -73,7 +60,7 @@ def __init__(
DeprecationWarning,
)

region = _get_aws_region(env_config=env_config, session=session, region=region)
region = region or AI21EnvConfig.aws_region

self._http_client = http_client or AsyncHttpClient(
timeout_sec=timeout_sec,
Expand Down
92 changes: 76 additions & 16 deletions ai21/clients/sagemaker/ai21_sagemaker_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from typing import Optional
from typing import Optional, Dict, Any

import boto3

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer
from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion
from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC
from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase
from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize
from ai21.clients.sagemaker.sagemaker_session import SageMakerSession

from ai21.clients.sagemaker.resources.sagemaker_answer import SageMakerAnswer, AsyncSageMakerAnswer
from ai21.clients.sagemaker.resources.sagemaker_completion import SageMakerCompletion, AsyncSageMakerCompletion
from ai21.clients.sagemaker.resources.sagemaker_gec import SageMakerGEC, AsyncSageMakerGEC
from ai21.clients.sagemaker.resources.sagemaker_paraphrase import SageMakerParaphrase, AsyncSageMakerParaphrase
from ai21.clients.sagemaker.resources.sagemaker_summarize import SageMakerSummarize, AsyncSageMakerSummarize
from ai21.http_client.async_http_client import AsyncHttpClient
from ai21.http_client.http_client import HttpClient


class AI21SageMakerClient:
Expand All @@ -18,22 +20,80 @@ class AI21SageMakerClient:
:param session: An optional boto3 session to use for the client.
"""

def __init__(
self,
endpoint_name: str,
region: Optional[str] = None,
session: Optional["boto3.Session"] = None,
headers: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
http_client: Optional[HttpClient] = None,
**kwargs,
):
region = region or AI21EnvConfig.aws_region
self._http_client = http_client or HttpClient(
headers=headers,
timeout_sec=timeout_sec,
num_retries=num_retries,
)

self.completion = SageMakerCompletion(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.paraphrase = SageMakerParaphrase(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.answer = SageMakerAnswer(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.gec = SageMakerGEC(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.summarize = SageMakerSummarize(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)


class AsyncAI21SageMakerClient:
"""
:param endpoint_name: The name of the endpoint to use for the client.
:param region: The AWS region of the endpoint.
:param session: An optional boto3 session to use for the client.
"""

def __init__(
self,
endpoint_name: str,
region: Optional[str] = None,
session: Optional["boto3.Session"] = None,
env_config: _AI21EnvConfig = AI21EnvConfig,
headers: Optional[Dict[str, Any]] = None,
timeout_sec: Optional[float] = None,
num_retries: Optional[int] = None,
http_client: Optional[AsyncHttpClient] = None,
**kwargs,
):
region = region or AI21EnvConfig.aws_region

self._http_client = http_client or AsyncHttpClient(
headers=headers,
timeout_sec=timeout_sec,
num_retries=num_retries,
)

self._env_config = env_config
_session = ()
self._session = SageMakerSession(
session=session, region=region or self._env_config.aws_region, endpoint_name=endpoint_name
self.completion = AsyncSageMakerCompletion(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.paraphrase = AsyncSageMakerParaphrase(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.answer = AsyncSageMakerAnswer(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.gec = AsyncSageMakerGEC(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.summarize = AsyncSageMakerSummarize(
endpoint_name=endpoint_name, region=region, client=self._http_client, aws_session=session
)
self.completion = SageMakerCompletion(self._session)
self.paraphrase = SageMakerParaphrase(self._session)
self.answer = SageMakerAnswer(self._session)
self.gec = SageMakerGEC(self._session)
self.summarize = SageMakerSummarize(self._session)
19 changes: 16 additions & 3 deletions ai21/clients/sagemaker/resources/sagemaker_answer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ai21.clients.common.answer_base import Answer
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import AnswerResponse


Expand All @@ -11,6 +11,19 @@ def create(
**kwargs,
) -> AnswerResponse:
body = self._create_body(context=context, question=question)
response = self._invoke(body)
response = self._post(body)

return self._json_to_response(response)
return self._json_to_response(response.json())


class AsyncSageMakerAnswer(AsyncSageMakerResource, Answer):
async def create(
self,
context: str,
question: str,
**kwargs,
) -> AnswerResponse:
body = self._create_body(context=context, question=question)
response = await self._post(body)

return self._json_to_response(response.json())
64 changes: 61 additions & 3 deletions ai21/clients/sagemaker/resources/sagemaker_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import List, Dict

from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import Penalty, CompletionsResponse
from ai21.types import NotGiven, NOT_GIVEN
from ai21.utils.typing import remove_not_given
Expand Down Expand Up @@ -61,6 +61,64 @@ def create(
}
)

raw_response = self._invoke(body)
raw_response = self._post(body=body)

return CompletionsResponse.from_dict(raw_response)
return CompletionsResponse.from_dict(raw_response.json())


class AsyncSageMakerCompletion(AsyncSageMakerResource):
async def create(
self,
prompt: str,
*,
max_tokens: int | NotGiven = NOT_GIVEN,
num_results: int | NotGiven = NOT_GIVEN,
min_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
top_k_return: int | NotGiven = NOT_GIVEN,
stop_sequences: List[str] | NotGiven = NOT_GIVEN,
frequency_penalty: Penalty | NotGiven = NOT_GIVEN,
presence_penalty: Penalty | NotGiven = NOT_GIVEN,
count_penalty: Penalty | NotGiven = NOT_GIVEN,
logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN,
**kwargs,
) -> CompletionsResponse:
"""
:param prompt: Text for model to complete
:param max_tokens: The maximum number of tokens to generate per result
:param num_results: Number of completions to sample and return.
:param min_tokens: The minimum number of tokens to generate per result.
:param temperature: A value controlling the "creativity" of the model's responses.
:param top_p: A value controlling the diversity of the model's responses.
:param top_k_return: The number of top-scoring tokens to consider for each generation step.
:param stop_sequences: Stops decoding if any of the strings is generated
:param frequency_penalty: A penalty applied to tokens that are frequently generated.
:param presence_penalty: A penalty applied to tokens that are already present in the prompt.
:param count_penalty: A penalty applied to tokens based on their frequency in the generated responses
:param logit_bias: A dictionary which contains mapping from strings to floats, where the strings are text
representations of the tokens and the floats are the biases themselves. A positive bias increases generation
probability for a given token and a negative bias decreases it.
:param kwargs:
:return:
"""
body = remove_not_given(
{
"prompt": prompt,
"maxTokens": max_tokens,
"numResults": num_results,
"minTokens": min_tokens,
"temperature": temperature,
"topP": top_p,
"topKReturn": top_k_return,
"stopSequences": stop_sequences or [],
"frequencyPenalty": frequency_penalty.to_dict() if frequency_penalty else frequency_penalty,
"presencePenalty": presence_penalty.to_dict() if presence_penalty else presence_penalty,
"countPenalty": count_penalty.to_dict() if count_penalty else count_penalty,
"logitBias": logit_bias,
}
)

raw_response = await self._post(body=body)

return CompletionsResponse.from_dict(raw_response.json())
15 changes: 12 additions & 3 deletions ai21/clients/sagemaker/resources/sagemaker_gec.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from ai21.clients.common.gec_base import GEC
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import GECResponse


class SageMakerGEC(SageMakerResource, GEC):
def create(self, text: str, **kwargs) -> GECResponse:
body = self._create_body(text=text)

response = self._invoke(body)
response = self._post(body)

return self._json_to_response(response)
return self._json_to_response(response.json())


class AsyncSageMakerGEC(AsyncSageMakerResource, GEC):
async def create(self, text: str, **kwargs) -> GECResponse:
body = self._create_body(text=text)

response = await self._post(body)

return self._json_to_response(response.json())
27 changes: 24 additions & 3 deletions ai21/clients/sagemaker/resources/sagemaker_paraphrase.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from ai21.clients.common.paraphrase_base import Paraphrase
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource
from ai21.clients.sagemaker.resources.sagemaker_resource import SageMakerResource, AsyncSageMakerResource
from ai21.models import ParaphraseStyleType, ParaphraseResponse


Expand All @@ -21,6 +21,27 @@ def create(
start_index=start_index,
end_index=end_index,
)
response = self._invoke(body=body)
response = self._post(body=body)

return self._json_to_response(response)
return self._json_to_response(response.json())


class AsyncSageMakerParaphrase(AsyncSageMakerResource, Paraphrase):
async def create(
self,
text: str,
*,
style: Optional[ParaphraseStyleType] = None,
start_index: Optional[int] = 0,
end_index: Optional[int] = None,
**kwargs,
) -> ParaphraseResponse:
body = self._create_body(
text=text,
style=style,
start_index=start_index,
end_index=end_index,
)
response = await self._post(body=body)

return self._json_to_response(response.json())
Loading

0 comments on commit cf1a576

Please sign in to comment.