Skip to content

Commit

Permalink
refactor: export get_aws_region, add async sm to readme, add async ex…
Browse files Browse the repository at this point in the history
…amples
  • Loading branch information
miri-bar committed Jul 1, 2024
1 parent 5d78367 commit 1ac7804
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 13 deletions.
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,24 @@ response = client.summarize.create(
print(response.summary)
```

#### Async

```python
import asyncio
from ai21 import AsyncAI21SageMakerClient

client = AsyncAI21SageMakerClient(endpoint_name="j2-endpoint-name")

async def main():
response = await client.summarize.create(
source="Text to summarize",
source_type="TEXT",
)
print(response.summary)

asyncio.run(main())
```

### With Boto3 Session

```python
Expand Down
16 changes: 3 additions & 13 deletions ai21/clients/bedrock/ai21_bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,13 @@
import boto3

from ai21.ai21_env_config import _AI21EnvConfig, AI21EnvConfig
from ai21.clients.aws.utils import get_aws_region
from ai21.clients.bedrock.resources.chat.bedrock_chat import BedrockChat, AsyncBedrockChat
from ai21.clients.bedrock.resources.bedrock_completion import BedrockCompletion, AsyncBedrockCompletion
from ai21.http_client.http_client import HttpClient
from ai21.http_client.async_http_client import AsyncHttpClient


def _get_aws_region(
env_config: _AI21EnvConfig,
session: Optional[boto3.Session] = None,
region: Optional[str] = None,
) -> str:
if session is not None:
return session.region_name

return region or env_config.aws_region


class AI21BedrockClient:
def __init__(
self,
Expand All @@ -40,7 +30,7 @@ def __init__(
DeprecationWarning,
)

region = _get_aws_region(env_config=env_config, session=session, region=region)
region = get_aws_region(env_config=env_config, session=session, region=region)

self._http_client = http_client or HttpClient(
timeout_sec=timeout_sec,
Expand Down Expand Up @@ -73,7 +63,7 @@ def __init__(
DeprecationWarning,
)

region = _get_aws_region(env_config=env_config, session=session, region=region)
region = get_aws_region(env_config=env_config, session=session, region=region)

self._http_client = http_client or AsyncHttpClient(
timeout_sec=timeout_sec,
Expand Down
47 changes: 47 additions & 0 deletions examples/sagemaker/async_completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import asyncio

from ai21 import AsyncAI21SageMakerClient

prompt = (
"The following is a conversation between a user of an eCommerce store and a user operation"
" associate called Max. Max is very kind and keen to help."
" The following are important points about the business policies:\n- "
"Delivery takes up to 5 days\n- There is no return option\n\nUser gender:"
" Male.\n\nConversation:\nUser: Hi, had a question\nMax: "
"Hi there, happy to help!\nUser: Is there no way to return a product?"
" I got your blue T-Shirt size small but it doesn't fit.\n"
"Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n"
"User: That's a shame. \nMax: Is there anything else i can do for you?\n\n"
"##\n\nThe following is a conversation between a user of an eCommerce store and a user operation"
" associate called Max. Max is very kind and keen to help. The following are important points about"
" the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n"
'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" '
"t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me"
" to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: [email protected]\n"
"Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between"
" a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help."
" The following are important points about the business policies:\n- Delivery takes up to 5 days\n"
"- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it"
" take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working"
" days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n"
"Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n"
"Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n"
"Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an"
" eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following"
" are important points about the business policies:\n- Delivery takes up to 5 days\n"
"- There is no return option\n\nUser gender: Female.\n\nConversation:\n"
"User: Hi, I have a question for you"
)

client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name")


async def main():
response = await client.completion.create(prompt=prompt, max_tokens=2)

print(response.completions[0].data.text)

print(response.prompt.tokens[0]["textRange"]["start"])


asyncio.run(main())
14 changes: 14 additions & 0 deletions examples/sagemaker/async_gec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import asyncio

from ai21 import AsyncAI21SageMakerClient

client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name")


async def main():
response = client.gec.create(text="roc and rolle")

print(response.corrections[0].suggestion)


asyncio.run(main())
17 changes: 17 additions & 0 deletions examples/sagemaker/async_paraphrase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import asyncio

from ai21 import AsyncAI21SageMakerClient

client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name")


async def main():
response = await client.paraphrase.create(
text="What's the difference between Scottish Fold and British?",
style="formal",
)

print(response.suggestions[0].text)


asyncio.run(main())
23 changes: 23 additions & 0 deletions examples/sagemaker/async_summarization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import asyncio

from ai21 import AsyncAI21SageMakerClient

client = AsyncAI21SageMakerClient(endpoint_name="sm_endpoint_name")


async def main():
response = await client.summarize.create(
source="Holland is a geographical region[2] and former province on the western coast of the Netherlands.[2]"
" From the 10th to the 16th century, "
"Holland proper was a unified political region within the Holy Roman Empire as a"
" county ruled by the counts of Holland. By the 17th century, "
"the province of Holland had risen to become a maritime and economic power,"
" dominating the other provinces of the newly independent Dutch "
"Republic.",
source_type="TEXT",
)

print(response.summary)


asyncio.run(main())

0 comments on commit 1ac7804

Please sign in to comment.