Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve AWS bedrock integration #289

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
89 changes: 77 additions & 12 deletions adalflow/adalflow/components/model_client/bedrock_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class BedrockAPIClient(ModelClient):
Setup:
1. Install boto3: `pip install boto3`
2. Ensure you have the AWS credentials set up. There are four variables you can optionally set:
Either AWS_PROFILE_NAME or (AWS_REGION_NAME and AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY) are needed
- AWS_PROFILE_NAME: The name of the AWS profile to use.
- AWS_REGION_NAME: The name of the AWS region to use.
- AWS_ACCESS_KEY_ID: The AWS access key ID.
Expand All @@ -76,10 +77,9 @@ class BedrockAPIClient(ModelClient):
self.generator = Generator(
model_client=BedrockAPIClient(),
model_kwargs={
"modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
"inferenceConfig": {
"temperature": 0.8
}
"model": "mistral.mistral-7b-instruct-v0:2",
"temperature": 0.8,
"max_tokens": 100
}, template=template
)

Expand All @@ -95,8 +95,8 @@ class BedrockAPIClient(ModelClient):

def __init__(
self,
aws_profile_name="default",
aws_region_name="us-west-2", # Use a supported default region
aws_profile_name=None,
aws_region_name=None,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
Expand All @@ -118,6 +118,12 @@ def __init__(
self.chat_completion_parser = (
chat_completion_parser or get_first_message_content
)
self.inference_parameters = [
"maxTokens",
"temperature",
"topP",
"stopSequences",
]

def init_sync_client(self):
"""
Expand Down Expand Up @@ -175,21 +181,70 @@ def track_completion_usage(self, completion: Dict) -> CompletionUsage:
total_tokens=usage["totalTokens"],
)

def list_models(self):
def list_models(self, **kwargs):
# Initialize Bedrock client (not runtime)
# Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_ListFoundationModels.html

try:
response = self._client.list_foundation_models()
models = response.get("models", [])
response = self._client.list_foundation_models(**kwargs)
models = response.get("modelSummaries", [])
for model in models:
print(f"Model ID: {model['modelId']}")
print(f" Name: {model['name']}")
print(f" Description: {model['description']}")
print(f" Provider: {model['provider']}")
print(f" Name: {model['modelName']}")
print(f" Model ARN: {model['modelArn']}")
print(f" Provider: {model['providerName']}")
print(f" Input: {model['inputModalities']}")
print(f" Output: {model['outputModalities']}")
print(f" InferenceTypesSupported: {model['inferenceTypesSupported']}")
print("")
except Exception as e:
print(f"Error listing models: {e}")

def _validate_and_process_config_keys(self, api_kwargs: Dict):
"""
Validate and process the model ID in API kwargs.

:param api_kwargs: Dictionary of API keyword arguments
:raises KeyError: If 'model' key is missing
"""
if "model" in api_kwargs:
api_kwargs["modelId"] = api_kwargs.pop("model")
else:
raise KeyError("The required key 'model' is missing in model_kwargs.")

# In .converse() `maxTokens`` is the key for maximum tokens limit
if "max_tokens" in api_kwargs:
api_kwargs["maxTokens"] = api_kwargs.pop("max_tokens")

return api_kwargs

def _separate_parameters(self, api_kwargs: Dict) -> tuple:
"""
Separate inference configuration and additional model request fields.

:param api_kwargs: Dictionary of API keyword arguments
:return: Tuple of (inference_config, additional_model_request_fields)
"""
inference_config = {}
additional_model_request_fields = {}
keys_to_remove = set()
excluded_keys = {"modelId"}

# Categorize parameters
for key, value in list(api_kwargs.items()):
if key in self.inference_parameters:
inference_config[key] = value
keys_to_remove.add(key)
elif key not in excluded_keys:
additional_model_request_fields[key] = value
keys_to_remove.add(key)

# Remove categorized keys from api_kwargs
for key in keys_to_remove:
api_kwargs.pop(key, None)

return api_kwargs, inference_config, additional_model_request_fields

def convert_inputs_to_api_kwargs(
self,
input: Optional[Any] = None,
Expand All @@ -202,9 +257,19 @@ def convert_inputs_to_api_kwargs(
"""
api_kwargs = model_kwargs.copy()
if model_type == ModelType.LLM:
# Validate and process model ID
api_kwargs = self._validate_and_process_config_keys(api_kwargs)

# Separate inference config and additional model request fields
api_kwargs, inference_config, additional_model_request_fields = (
self._separate_parameters(api_kwargs)
)

api_kwargs["messages"] = [
{"role": "user", "content": [{"text": input}]},
]
api_kwargs["inferenceConfig"] = inference_config
api_kwargs["additionalModelRequestFields"] = additional_model_request_fields
else:
raise ValueError(f"Model type {model_type} not supported")
return api_kwargs
Expand Down
22 changes: 11 additions & 11 deletions adalflow/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions adalflow/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ openai = "^1.12.0"
groq = "^0.9.0"
google-generativeai = "^0.7.2"
anthropic = "^0.31.1"
boto3 = "^1.35.19"

[tool.poetry.group.typing.dependencies]
mypy = "^1"
Expand Down
91 changes: 91 additions & 0 deletions adalflow/tests/test_aws_bedrock_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
import unittest
from unittest.mock import patch, Mock

from adalflow.core.types import ModelType, GeneratorOutput
from adalflow.components.model_client import BedrockAPIClient


def getenv_side_effect(key):
# This dictionary can hold more keys and values as needed
env_vars = {
"AWS_ACCESS_KEY_ID": "WQIWQORE2RK63VGZJKGF",
"AWS_SECRET_ACCESS_KEY": "aW1dDWQKdR/Sx3fI39N6ycoAYTjj3vsPSuN44ebU",
"AWS_REGION_NAME": "us-east-1",
}
return env_vars.get(key, None) # Returns None if key is not found


# modified from test_openai_client.py
class TestBedrockClient(unittest.TestCase):
def setUp(self):
# Patch os.environ to ensure all environment variables are set
with patch.dict(
os.environ,
{
"AWS_ACCESS_KEY_ID": "fake_api_key",
"AWS_SECRET_ACCESS_KEY": "fake_api_key",
"AWS_REGION_NAME": "fake_api_key",
"AWS_PROFILE_NAME": "fake_profile", # Adding additional profile if needed
},
):
# Now patch os.getenv to return mocked environment variable values
with patch("os.getenv", side_effect=getenv_side_effect):
self.client = BedrockAPIClient()
self.mock_response = {
"ResponseMetadata": {
"RequestId": "43aec10a-9780-4bd5-abcc-857d12460569",
"HTTPStatusCode": 200,
"HTTPHeaders": {
"date": "Sat, 30 Nov 2024 14:27:44 GMT",
"content-type": "application/json",
"content-length": "273",
"connection": "keep-alive",
"x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569",
},
"RetryAttempts": 0,
},
"output": {
"message": {"role": "assistant", "content": [{"text": "Hello, world!"}]}
},
"stopReason": "end_turn",
"usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30},
"metrics": {"latencyMs": 430},
}

self.api_kwargs = {
"messages": [{"role": "user", "content": "Hello"}],
"model": "gpt-3.5-turbo",
}

@patch.object(BedrockAPIClient, "init_sync_client")
@patch("adalflow.components.model_client.bedrock_client.boto3")
def test_call(self, MockBedrock, mock_init_sync_client):
mock_sync_client = Mock()
MockBedrock.return_value = mock_sync_client
mock_init_sync_client.return_value = mock_sync_client

# Mock the client's api: converse
mock_sync_client.converse = Mock(return_value=self.mock_response)

# Set the sync client
self.client.sync_client = mock_sync_client

# Call the call method
result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)

# Assertions
mock_sync_client.converse.assert_called_once_with(**self.api_kwargs)
self.assertEqual(result, self.mock_response)

# test parse_chat_completion
output = self.client.parse_chat_completion(completion=self.mock_response)
self.assertTrue(isinstance(output, GeneratorOutput))
self.assertEqual(output.raw_response, "Hello, world!")
self.assertEqual(output.usage.prompt_tokens, 20)
self.assertEqual(output.usage.completion_tokens, 10)
self.assertEqual(output.usage.total_tokens, 30)


if __name__ == "__main__":
unittest.main()
55 changes: 55 additions & 0 deletions docs/source/integrations/aws_bedrock.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
.. _integration-aws-bedrock:

AWS Bedrock API Client
=======================

.. admonition:: Author
:class: highlight

`Ajith Kumar <https://github.com/ajithvcoder>`_

Getting Credentials
-------------------

You need to have an AWS account and an access key and secret key to use AWS Bedrock services. Moreover, the account associated with the access key must have
the necessary permissions to access Bedrock services. Refer to the `AWS documentation <https://docs.aws.amazon.com/singlesignon/latest/userguide/howtogetcredentials.html>`_ for more information on obtaining credentials.

Enabling Foundation Models
--------------------------

AWS Bedrock offers several foundation models from providers like "Meta," "Amazon," "Cohere," "Anthropic," and "Microsoft." To access these models, you need to enable them first. Note that each AWS region supports a specific set of models. Not all foundation models are available in every region, and pricing varies by region.

Pricing information: `AWS Bedrock Pricing <https://aws.amazon.com/bedrock/pricing/>`_

Steps for enabling model access:

1. Select the desired region in the AWS Console (e.g., `us-east-1 (N. Virginia)`).
2. Navigate to the `Bedrock services home page <https://console.aws.amazon.com/bedrock/home>`_.
3. On the left sidebar, under "Bedrock Configuration," click "Model Access."

You will be redirected to a page where you can select the models to enable.

Note:

1. Avoid enabling high-cost models to prevent accidental high charges due to incorrect usage.
2. As of Nov 2024, a cost-effective option is the Llama-3.2 1B model, with model ID: ``meta.llama3-2-1b-instruct-v1:0`` in the ``us-east-1`` region.
3. AWS tags certain models with `inferenceTypesSupported` = `INFERENCE_PROFILE` and in UI it might appear with a tooltip as `This model can only be used through an inference profile.` In such cases you may need to use the Model ARN: ``arn:aws:bedrock:us-east-1:306093656765:inference-profile/us.meta.llama3-2-1b-instruct-v1:0`` in the model ID field when using Adalflow.
4. Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the ``.env`` file. Mention exact key names in ``.env`` file for example access key id is ``AWS_ACCESS_KEY_ID``

.. code-block:: python

import adalflow as adal
import os

# Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the .env file
adal.setup_env()
model_client = adal.BedrockAPIClient()
model_client.list_models()

Which ever profile is tagged with ``INFERENCE_PROFILE`` you might need to provide ``Model ARN`` in ``model`` filed of ``model_kwargs``

References
----------

1. You can refer to Model IDs or Model ARNs `here <https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/models>`_. Clicking on a model card provides additional information.
2. Internally, Adalflow's AWS client uses the `Converse API <https://boto3.amazonaws.com/v1/documentation/api/1.35.8/reference/services/bedrock-runtime/client/converse.html>`_ for each conversation.
Loading
Loading