From 4d80d265974932222f8e8d2fd2cafaf56b4a5128 Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 04:47:00 +0000 Subject: [PATCH 01/12] refactor: remove default arguments --- adalflow/adalflow/components/model_client/bedrock_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index d25b48bc..16032bc1 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -95,8 +95,8 @@ class BedrockAPIClient(ModelClient): def __init__( self, - aws_profile_name="default", - aws_region_name="us-west-2", # Use a supported default region + aws_profile_name=None, + aws_region_name=None, aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, From ead370d040a7c0f480f654506c66859861313a57 Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 05:04:54 +0000 Subject: [PATCH 02/12] fix: list_foundation_models() fetch function --- .../adalflow/components/model_client/bedrock_client.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 16032bc1..b0ebc35c 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -180,12 +180,14 @@ def list_models(self): try: response = self._client.list_foundation_models() - models = response.get("models", []) + models = response.get("modelSummaries", []) for model in models: print(f"Model ID: {model['modelId']}") - print(f" Name: {model['name']}") - print(f" Description: {model['description']}") - print(f" Provider: {model['provider']}") + print(f" Name: {model['modelName']}") + print(f" Provider: {model['providerName']}") + print(f" Input: 
{model['inputModalities']}") + print(f" Output: {model['outputModalities']}") + print(f" InferenceTypesSupported: {model['inferenceTypesSupported']}") print("") except Exception as e: print(f"Error listing models: {e}") From be953d9f36cc3ebcb5ba12d9349573b9ec5a0deb Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 11:49:00 +0000 Subject: [PATCH 03/12] fix: list_models() issue, modelId param, docs: add aws bedrock integration docs --- .../components/model_client/bedrock_client.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index b0ebc35c..051e6852 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -118,6 +118,12 @@ def __init__( self.chat_completion_parser = ( chat_completion_parser or get_first_message_content ) + self.inference_parameters = [ + "maxTokens", + "temperature", + "topP", + "stopSequences", + ] def init_sync_client(self): """ @@ -192,6 +198,47 @@ def list_models(self): except Exception as e: print(f"Error listing models: {e}") + def _validate_and_process_model_id(self, api_kwargs: Dict): + """ + Validate and process the model ID in API kwargs. + + :param api_kwargs: Dictionary of API keyword arguments + :raises KeyError: If 'model' key is missing + """ + if "model" in api_kwargs: + api_kwargs["modelId"] = api_kwargs.pop("model") + else: + raise KeyError("The required key 'model' is missing in model_kwargs.") + + return api_kwargs + + def _separate_parameters(self, api_kwargs: Dict) -> tuple: + """ + Separate inference configuration and additional model request fields. 
+ + :param api_kwargs: Dictionary of API keyword arguments + :return: Tuple of (inference_config, additional_model_request_fields) + """ + inference_config = {} + additional_model_request_fields = {} + keys_to_remove = set() + excluded_keys = {"modelId"} + + # Categorize parameters + for key, value in list(api_kwargs.items()): + if key in self.inference_parameters: + inference_config[key] = value + keys_to_remove.add(key) + elif key not in excluded_keys: + additional_model_request_fields[key] = value + keys_to_remove.add(key) + + # Remove categorized keys from api_kwargs + for key in keys_to_remove: + api_kwargs.pop(key, None) + + return api_kwargs, inference_config, additional_model_request_fields + def convert_inputs_to_api_kwargs( self, input: Optional[Any] = None, @@ -204,9 +251,17 @@ def convert_inputs_to_api_kwargs( """ api_kwargs = model_kwargs.copy() if model_type == ModelType.LLM: + # Validate and process model ID + api_kwargs = self._validate_and_process_model_id(api_kwargs) + + # Separate inference config and additional model request fields + api_kwargs, inference_config, additional_model_request_fields = self._separate_parameters(api_kwargs) + api_kwargs["messages"] = [ {"role": "user", "content": [{"text": input}]}, ] + api_kwargs["inferenceConfig"] = inference_config + api_kwargs["additionalModelRequestFields"] = additional_model_request_fields else: raise ValueError(f"Model type {model_type} not supported") return api_kwargs From 03a48d3cc3cada87dfcfe32d697395a849f35d2b Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 12:12:35 +0000 Subject: [PATCH 04/12] feat: accept kwargs in list_model(), fix: max_tokens parameter --- .../components/model_client/bedrock_client.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 051e6852..84ddd374 100644 --- 
a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -181,15 +181,16 @@ def track_completion_usage(self, completion: Dict) -> CompletionUsage: total_tokens=usage["totalTokens"], ) - def list_models(self): + def list_models(self, **kwargs): # Initialize Bedrock client (not runtime) try: - response = self._client.list_foundation_models() + response = self._client.list_foundation_models(**kwargs) models = response.get("modelSummaries", []) for model in models: print(f"Model ID: {model['modelId']}") print(f" Name: {model['modelName']}") + print(f" Model ARN: {model['modelArn']}") print(f" Provider: {model['providerName']}") print(f" Input: {model['inputModalities']}") print(f" Output: {model['outputModalities']}") @@ -198,7 +199,7 @@ def list_models(self): except Exception as e: print(f"Error listing models: {e}") - def _validate_and_process_model_id(self, api_kwargs: Dict): + def _validate_and_process_config_keys(self, api_kwargs: Dict): """ Validate and process the model ID in API kwargs. 
@@ -210,6 +211,10 @@ def _validate_and_process_model_id(self, api_kwargs: Dict): else: raise KeyError("The required key 'model' is missing in model_kwargs.") + # In .converse() `maxTokens`` is the key for maximum tokens limit + if "max_tokens" in api_kwargs: + api_kwargs["maxTokens"] = api_kwargs.pop("max_tokens") + return api_kwargs def _separate_parameters(self, api_kwargs: Dict) -> tuple: @@ -252,7 +257,7 @@ def convert_inputs_to_api_kwargs( api_kwargs = model_kwargs.copy() if model_type == ModelType.LLM: # Validate and process model ID - api_kwargs = self._validate_and_process_model_id(api_kwargs) + api_kwargs = self._validate_and_process_config_keys(api_kwargs) # Separate inference config and additional model request fields api_kwargs, inference_config, additional_model_request_fields = self._separate_parameters(api_kwargs) From 7baa43b3a2912ad782b1ef92c34344fd794e742c Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 12:47:12 +0000 Subject: [PATCH 05/12] feat: add usage tutorial for bedrock client, docs: add documentation for aws bedrock integration --- docs/source/integrations/aws_bedrock.rst | 55 +++++++++++++++++++++++ tutorials/bedrock_client_simple_qa.py | 56 ++++++++++++++++++++++++ 2 files changed, 111 insertions(+) create mode 100644 docs/source/integrations/aws_bedrock.rst create mode 100644 tutorials/bedrock_client_simple_qa.py diff --git a/docs/source/integrations/aws_bedrock.rst b/docs/source/integrations/aws_bedrock.rst new file mode 100644 index 00000000..8c4cfe89 --- /dev/null +++ b/docs/source/integrations/aws_bedrock.rst @@ -0,0 +1,55 @@ +.. _integration-aws-bedrock: + +AWS Bedrock API Client +======================= + +.. admonition:: Author + :class: highlight + + `Ajith Kumar `_ + +Getting Credentials +------------------- + +You need to have an AWS account and an access key and secret key to use AWS Bedrock services. 
Moreover, the account associated with the access key must have +the necessary permissions to access Bedrock services. Refer to the `AWS documentation `_ for more information on obtaining credentials. + +Enabling Foundation Models +-------------------------- + +AWS Bedrock offers several foundation models from providers like "Meta," "Amazon," "Cohere," "Anthropic," and "Microsoft." To access these models, you need to enable them first. Note that each AWS region supports a specific set of models. Not all foundation models are available in every region, and pricing varies by region. + +Pricing information: `AWS Bedrock Pricing `_ + +Steps for enabling model access: + +1. Select the desired region in the AWS Console (e.g., `us-east-1 (N. Virginia)`). +2. Navigate to the `Bedrock services home page `_. +3. On the left sidebar, under "Bedrock Configuration," click "Model Access." + + You will be redirected to a page where you can select the models to enable. + +Note: + +1. Avoid enabling high-cost models to prevent accidental high charges due to incorrect usage. +2. As of Nov 2024, a cost-effective option is the Llama-3.2 1B model, with model ID: ``meta.llama3-2-1b-instruct-v1:0`` in the ``us-east-1`` region. +3. AWS tags certain models with `inferenceTypesSupported` = `INFERENCE_PROFILE` and in UI it might appear with a tooltip as `This model can only be used through an inference profile.` In such cases you may need to use the Model ARN: ``arn:aws:bedrock:us-east-1:306093656765:inference-profile/us.meta.llama3-2-1b-instruct-v1:0`` in the model ID field when using Adalflow. +4. Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the ``.env`` file. Mention exact key names in ``.env`` file for example access key id is ``AWS_ACCESS_KEY_ID`` + +.. 
code-block:: python + + import adalflow as adal + import os + + # Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the .env file + adal.setup_env() + model_client = adal.BedrockAPIClient() + model_client.list_models() + +Which ever profile is tagged with ``INFERENCE_PROFILE`` you might need to provide ``Model ARN`` in ``model`` filed of ``model_kwargs`` + +References +---------- + +1. You can refer to Model IDs or Model ARNs `here `_. Clicking on a model card provides additional information. +2. Internally, Adalflow's AWS client uses the `Converse API `_ for each conversation. diff --git a/tutorials/bedrock_client_simple_qa.py b/tutorials/bedrock_client_simple_qa.py new file mode 100644 index 00000000..271c7b03 --- /dev/null +++ b/tutorials/bedrock_client_simple_qa.py @@ -0,0 +1,56 @@ +import os + +from adalflow.components.model_client import BedrockAPIClient +from adalflow.core.types import ModelType +from adalflow.utils import setup_env + + +def list_models(): + # For list of models + model_client = BedrockAPIClient() + model_client.list_models(byProvider="meta") + + +def bedrock_chat_conversation(): + # Initialize the Bedrock client for API interactions + awsbedrock_client = BedrockAPIClient() + query = "What is the capital of France?" + + # Embed the prompt in Llama 3's instruction format. 
+ formatted_prompt = f""" + <|begin_of_text|><|start_header_id|>user<|end_header_id|> + {query} + <|eot_id|> + <|start_header_id|>assistant<|end_header_id|> + """ + + # Set the model type to Large Language Model (LLM) + model_type = ModelType.LLM + + # Configure model parameters: + # - model: Specifies Llama-3-2 1B as the model to use + # - temperature: Controls randomness (0.5 = balanced between deterministic and creative) + # - max_tokens: Limits the response length to 100 tokens + + # Using Model ARN since its has inference_profile in us-east-1 region + # https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=meta.llama3-2-1b-instruct-v1:0 + model_id = "arn:aws:bedrock:us-east-1:306093656765:inference-profile/us.meta.llama3-2-1b-instruct-v1:0" + model_kwargs = {"model": model_id, "temperature": 0.5, "max_tokens": 100} + + # Convert the inputs into the format required by BedRock's API + api_kwargs = awsbedrock_client.convert_inputs_to_api_kwargs( + input=formatted_prompt, model_kwargs=model_kwargs, model_type=model_type + ) + print(f"api_kwargs: {api_kwargs}") + + response = awsbedrock_client.call(api_kwargs=api_kwargs, model_type=model_type) + + # Extract the text from the chat completion response + response_text = awsbedrock_client.parse_chat_completion(response) + print(f"response_text: {response_text}") + + +if __name__ == "__main__": + setup_env() + list_models() + bedrock_chat_conversation() From 1d9217bfc088e492181c7e78d3eb291ecc6da2f1 Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 12:53:48 +0000 Subject: [PATCH 06/12] fix: aws bedrock client usage --- tutorials/generator_all_providers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tutorials/generator_all_providers.py b/tutorials/generator_all_providers.py index 21d61d08..8de0abce 100644 --- a/tutorials/generator_all_providers.py +++ b/tutorials/generator_all_providers.py @@ -26,10 +26,10 @@ def use_all_providers(): ) # 
need to run ollama pull llama3.2:1b first to use this model - # aws_bedrock_llm = adal.Generator( - # model_client=adal.BedrockAPIClient(), - # model_kwargs={"modelId": "amazon.mistral.instruct-7b"}, - # ) + aws_bedrock_llm = adal.Generator( + model_client=adal.BedrockAPIClient(), + model_kwargs={"model": "mistral.mistral-7b-instruct-v0:2"}, + ) prompt_kwargs = {"input_str": "What is the meaning of life in one sentence?"} @@ -38,14 +38,14 @@ def use_all_providers(): anthropic_response = anthropic_llm(prompt_kwargs) google_gen_ai_response = google_gen_ai_llm(prompt_kwargs) ollama_response = ollama_llm(prompt_kwargs) - # aws_bedrock_llm_response = aws_bedrock_llm(prompt_kwargs) + aws_bedrock_llm_response = aws_bedrock_llm(prompt_kwargs) print(f"OpenAI: {openai_response}\n") print(f"Groq: {groq_response}\n") print(f"Anthropic: {anthropic_response}\n") print(f"Google GenAI: {google_gen_ai_response}\n") print(f"Ollama: {ollama_response}\n") - # print(f"AWS Bedrock: {aws_bedrock_llm_response}\n") + print(f"AWS Bedrock: {aws_bedrock_llm_response}\n") if __name__ == "__main__": From c428df1452348d43a597b88824a656b17bcf09e6 Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 13:13:09 +0000 Subject: [PATCH 07/12] docs: add comments and links --- .../adalflow/components/model_client/bedrock_client.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 84ddd374..48b8b3f7 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -50,6 +50,7 @@ class BedrockAPIClient(ModelClient): Setup: 1. Install boto3: `pip install boto3` 2. Ensure you have the AWS credentials set up. 
There are four variables you can optionally set: + Either AWS_PROFILE_NAME or (AWS_REGION_NAME and AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY) are needed - AWS_PROFILE_NAME: The name of the AWS profile to use. - AWS_REGION_NAME: The name of the AWS region to use. - AWS_ACCESS_KEY_ID: The AWS access key ID. @@ -76,10 +77,9 @@ class BedrockAPIClient(ModelClient): self.generator = Generator( model_client=BedrockAPIClient(), model_kwargs={ - "modelId": "anthropic.claude-3-sonnet-20240229-v1:0", - "inferenceConfig": { - "temperature": 0.8 - } + "model": "mistral.mistral-7b-instruct-v0:2", + "temperature": 0.8, + "max_tokens": 100 }, template=template ) @@ -183,6 +183,7 @@ def track_completion_usage(self, completion: Dict) -> CompletionUsage: def list_models(self, **kwargs): # Initialize Bedrock client (not runtime) + # Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_ListFoundationModels.html try: response = self._client.list_foundation_models(**kwargs) From 7e98938a4ed1999cd40f028d5e802908e97ab670 Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 13:31:57 +0000 Subject: [PATCH 08/12] fix: formatting issues --- adalflow/adalflow/components/model_client/bedrock_client.py | 4 +++- docs/source/integrations/aws_bedrock.rst | 6 +++--- tutorials/bedrock_client_simple_qa.py | 2 -- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/adalflow/adalflow/components/model_client/bedrock_client.py b/adalflow/adalflow/components/model_client/bedrock_client.py index 48b8b3f7..fa368aae 100644 --- a/adalflow/adalflow/components/model_client/bedrock_client.py +++ b/adalflow/adalflow/components/model_client/bedrock_client.py @@ -261,7 +261,9 @@ def convert_inputs_to_api_kwargs( api_kwargs = self._validate_and_process_config_keys(api_kwargs) # Separate inference config and additional model request fields - api_kwargs, inference_config, additional_model_request_fields = self._separate_parameters(api_kwargs) + api_kwargs, inference_config, 
additional_model_request_fields = ( + self._separate_parameters(api_kwargs) + ) api_kwargs["messages"] = [ {"role": "user", "content": [{"text": input}]}, diff --git a/docs/source/integrations/aws_bedrock.rst b/docs/source/integrations/aws_bedrock.rst index 8c4cfe89..2a691a7f 100644 --- a/docs/source/integrations/aws_bedrock.rst +++ b/docs/source/integrations/aws_bedrock.rst @@ -11,7 +11,7 @@ AWS Bedrock API Client Getting Credentials ------------------- -You need to have an AWS account and an access key and secret key to use AWS Bedrock services. Moreover, the account associated with the access key must have +You need to have an AWS account and an access key and secret key to use AWS Bedrock services. Moreover, the account associated with the access key must have the necessary permissions to access Bedrock services. Refer to the `AWS documentation `_ for more information on obtaining credentials. Enabling Foundation Models @@ -32,9 +32,9 @@ Steps for enabling model access: Note: 1. Avoid enabling high-cost models to prevent accidental high charges due to incorrect usage. -2. As of Nov 2024, a cost-effective option is the Llama-3.2 1B model, with model ID: ``meta.llama3-2-1b-instruct-v1:0`` in the ``us-east-1`` region. +2. As of Nov 2024, a cost-effective option is the Llama-3.2 1B model, with model ID: ``meta.llama3-2-1b-instruct-v1:0`` in the ``us-east-1`` region. 3. AWS tags certain models with `inferenceTypesSupported` = `INFERENCE_PROFILE` and in UI it might appear with a tooltip as `This model can only be used through an inference profile.` In such cases you may need to use the Model ARN: ``arn:aws:bedrock:us-east-1:306093656765:inference-profile/us.meta.llama3-2-1b-instruct-v1:0`` in the model ID field when using Adalflow. -4. Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the ``.env`` file. Mention exact key names in ``.env`` file for example access key id is ``AWS_ACCESS_KEY_ID`` +4. 
Ensure (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME) or AWS_DEFAULT_PROFILE is set in the ``.env`` file. Mention exact key names in ``.env`` file for example access key id is ``AWS_ACCESS_KEY_ID`` .. code-block:: python diff --git a/tutorials/bedrock_client_simple_qa.py b/tutorials/bedrock_client_simple_qa.py index 271c7b03..7d75c3fa 100644 --- a/tutorials/bedrock_client_simple_qa.py +++ b/tutorials/bedrock_client_simple_qa.py @@ -1,5 +1,3 @@ -import os - from adalflow.components.model_client import BedrockAPIClient from adalflow.core.types import ModelType from adalflow.utils import setup_env From 8f23132e71c729bac17b6f92ce72bc243594155f Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 14:42:56 +0000 Subject: [PATCH 09/12] test: add aws bedrock unit test --- adalflow/tests/test_aws_bedrock_client.py | 82 +++++++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 adalflow/tests/test_aws_bedrock_client.py diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py new file mode 100644 index 00000000..fa67c16b --- /dev/null +++ b/adalflow/tests/test_aws_bedrock_client.py @@ -0,0 +1,82 @@ +import unittest +from unittest.mock import patch, Mock + +# use the openai for mocking standard data types +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion + +from adalflow.core.types import ModelType, GeneratorOutput +from adalflow.components.model_client import BedrockAPIClient + + +def getenv_side_effect(key): + # This dictionary can hold more keys and values as needed + env_vars = { + "AWS_ACCESS_KEY_ID": "fake_api_key", + "AWS_SECRET_ACCESS_KEY": "fake_api_key", + "AWS_REGION_NAME": "fake_api_key", + } + return env_vars.get(key, None) # Returns None if key is not found + + +# modified from test_openai_client.py +class TestBedrockClient(unittest.TestCase): + def setUp(self): + self.client = BedrockAPIClient() + self.mock_response = { + "ResponseMetadata": 
{ + "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "date": "Sat, 30 Nov 2024 14:27:44 GMT", + "content-type": "application/json", + "content-length": "273", + "connection": "keep-alive", + "x-amzn-requestid": "43aec10a-9780-4bd5-abcc-857d12460569", + }, + "RetryAttempts": 0, + }, + "output": { + "message": {"role": "assistant", "content": [{"text": "Hello, world!"}]} + }, + "stopReason": "end_turn", + "usage": {"inputTokens": 20, "outputTokens": 10, "totalTokens": 30}, + "metrics": {"latencyMs": 430}, + } + + self.api_kwargs = { + "messages": [{"role": "user", "content": "Hello"}], + "model": "gpt-3.5-turbo", + } + + @patch.object(BedrockAPIClient, "init_sync_client") + @patch("adalflow.components.model_client.bedrock_client.boto3") + def test_call(self, MockBedrock, mock_init_sync_client): + mock_sync_client = Mock() + MockBedrock.return_value = mock_sync_client + mock_init_sync_client.return_value = mock_sync_client + + # Mock the client's api: converse + mock_sync_client.converse = Mock(return_value=self.mock_response) + + # Set the sync client + self.client.sync_client = mock_sync_client + + # Call the call method + result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM) + + # Assertions + mock_sync_client.converse.assert_called_once_with(**self.api_kwargs) + self.assertEqual(result, self.mock_response) + + # test parse_chat_completion + output = self.client.parse_chat_completion(completion=self.mock_response) + self.assertTrue(isinstance(output, GeneratorOutput)) + self.assertEqual(output.raw_response, "Hello, world!") + self.assertEqual(output.usage.prompt_tokens, 20) + self.assertEqual(output.usage.completion_tokens, 10) + self.assertEqual(output.usage.total_tokens, 30) + + +if __name__ == "__main__": + unittest.main() From 6efabade19ba8c764af19652c6af36ad14f1cf1b Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 14:54:52 +0000 Subject: [PATCH 10/12] fix: boto3 to 
test dep, remove unnecesary imports in bedrock test --- adalflow/pyproject.toml | 1 + adalflow/tests/test_aws_bedrock_client.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/adalflow/pyproject.toml b/adalflow/pyproject.toml index 8b1a68f6..a732e23f 100644 --- a/adalflow/pyproject.toml +++ b/adalflow/pyproject.toml @@ -78,6 +78,7 @@ openai = "^1.12.0" groq = "^0.9.0" google-generativeai = "^0.7.2" anthropic = "^0.31.1" +boto3 = "^1.35.19" [tool.poetry.group.typing.dependencies] mypy = "^1" diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py index fa67c16b..d5d8e9e1 100644 --- a/adalflow/tests/test_aws_bedrock_client.py +++ b/adalflow/tests/test_aws_bedrock_client.py @@ -1,10 +1,6 @@ import unittest from unittest.mock import patch, Mock -# use the openai for mocking standard data types -from openai.types import CompletionUsage -from openai.types.chat import ChatCompletion - from adalflow.core.types import ModelType, GeneratorOutput from adalflow.components.model_client import BedrockAPIClient From 6edab4cc75a51f3916682de542a5542d56ed22f0 Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 16:05:57 +0000 Subject: [PATCH 11/12] fix: poetry lock file, bedrock test client --- adalflow/poetry.lock | 22 +++++++++++----------- adalflow/tests/test_aws_bedrock_client.py | 19 +++++++++++++++---- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/adalflow/poetry.lock b/adalflow/poetry.lock index d2ee1af5..fe4f0aad 100644 --- a/adalflow/poetry.lock +++ b/adalflow/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. 
[[package]] name = "absl-py" @@ -286,17 +286,17 @@ files = [ [[package]] name = "boto3" -version = "1.35.34" +version = "1.35.71" description = "The AWS SDK for Python" -optional = true +optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.34-py3-none-any.whl", hash = "sha256:291e7b97a34967ed93297e6171f1bebb8529e64633dd48426760e3fdef1cdea8"}, - {file = "boto3-1.35.34.tar.gz", hash = "sha256:57e6ee8504e7929bc094bb2afc879943906064179a1e88c23b4812e2c6f61532"}, + {file = "boto3-1.35.71-py3-none-any.whl", hash = "sha256:e2969a246bb3208122b3c349c49cc6604c6fc3fc2b2f65d99d3e8ccd745b0c16"}, + {file = "boto3-1.35.71.tar.gz", hash = "sha256:3ed7172b3d4fceb6218bb0ec3668c4d40c03690939c2fca4f22bb875d741a07f"}, ] [package.dependencies] -botocore = ">=1.35.34,<1.36.0" +botocore = ">=1.35.71,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -305,13 +305,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.34" +version = "1.35.71" description = "Low-level, data-driven core of boto 3." 
optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "botocore-1.35.34-py3-none-any.whl", hash = "sha256:ccb0fe397b11b81c9abc0c87029d17298e17bf658d8db5c0c5a551a12a207e7a"},
-    {file = "botocore-1.35.34.tar.gz", hash = "sha256:789b6501a3bb4a9591c1fe10da200cc315c1fa5df5ada19c720d8ef06439b3e3"},
+    {file = "botocore-1.35.71-py3-none-any.whl", hash = "sha256:fc46e7ab1df3cef66dfba1633f4da77c75e07365b36f03bd64a3793634be8fc1"},
+    {file = "botocore-1.35.71.tar.gz", hash = "sha256:f9fa058e0393660c3fe53c1e044751beb64b586def0bd2212448a7c328b0cbba"},
 ]
 
 [package.dependencies]
@@ -3455,7 +3455,7 @@ pyasn1 = ">=0.1.3"
 name = "s3transfer"
 version = "0.10.2"
 description = "An Amazon S3 Transfer Manager"
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
     {file = "s3transfer-0.10.2-py3-none-any.whl", hash = "sha256:eca1c20de70a39daee580aef4986996620f365c4e0fda6a86100231d62f1bf69"},
@@ -4376,4 +4376,4 @@ torch = ["torch"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9, <4.0"
-content-hash = "0a58b82476e4794adbc5768d38911e1935a9c3941cc40499001199383fc8c8ec"
+content-hash = "bcf8b09220dafd0d401432f998acab2135350af664d1ca30b5ccfda364d7491b"
diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py
index d5d8e9e1..20cd3328 100644
--- a/adalflow/tests/test_aws_bedrock_client.py
+++ b/adalflow/tests/test_aws_bedrock_client.py
@@ -1,16 +1,18 @@
+import os
 import unittest
 from unittest.mock import patch, Mock
 
 from adalflow.core.types import ModelType, GeneratorOutput
 from adalflow.components.model_client import BedrockAPIClient
+from adalflow import setup_env
 
 
 def getenv_side_effect(key):
     # This dictionary can hold more keys and values as needed
     env_vars = {
-        "AWS_ACCESS_KEY_ID": "fake_api_key",
-        "AWS_SECRET_ACCESS_KEY": "fake_api_key",
-        "AWS_REGION_NAME": "fake_api_key",
+        "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE",
+        "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
+        "AWS_REGION_NAME": 
"us-east-1", } return env_vars.get(key, None) # Returns None if key is not found @@ -18,7 +20,16 @@ def getenv_side_effect(key): # modified from test_openai_client.py class TestBedrockClient(unittest.TestCase): def setUp(self): - self.client = BedrockAPIClient() + # Patch os.environ to ensure all environment variables are set + with patch.dict(os.environ, { + "AWS_ACCESS_KEY_ID": "fake_api_key", + "AWS_SECRET_ACCESS_KEY": "fake_api_key", + "AWS_REGION_NAME": "fake_api_key", + "AWS_PROFILE_NAME": "fake_profile" # Adding additional profile if needed + }): + # Now patch os.getenv to return mocked environment variable values + with patch("os.getenv", side_effect=getenv_side_effect): + self.client = BedrockAPIClient() self.mock_response = { "ResponseMetadata": { "RequestId": "43aec10a-9780-4bd5-abcc-857d12460569", From 0ba9498e886fa3dad21b99fa1337de34aafa0d18 Mon Sep 17 00:00:00 2001 From: ajithvcoder Date: Sat, 30 Nov 2024 16:20:47 +0000 Subject: [PATCH 12/12] fix: formatting in bedrock test --- adalflow/tests/test_aws_bedrock_client.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/adalflow/tests/test_aws_bedrock_client.py b/adalflow/tests/test_aws_bedrock_client.py index 20cd3328..1ab857e0 100644 --- a/adalflow/tests/test_aws_bedrock_client.py +++ b/adalflow/tests/test_aws_bedrock_client.py @@ -4,7 +4,6 @@ from adalflow.core.types import ModelType, GeneratorOutput from adalflow.components.model_client import BedrockAPIClient -from adalflow import setup_env def getenv_side_effect(key): @@ -21,12 +20,15 @@ def getenv_side_effect(key): class TestBedrockClient(unittest.TestCase): def setUp(self): # Patch os.environ to ensure all environment variables are set - with patch.dict(os.environ, { - "AWS_ACCESS_KEY_ID": "fake_api_key", - "AWS_SECRET_ACCESS_KEY": "fake_api_key", - "AWS_REGION_NAME": "fake_api_key", - "AWS_PROFILE_NAME": "fake_profile" # Adding additional profile if needed - }): + with patch.dict( + os.environ, + { + 
"AWS_ACCESS_KEY_ID": "fake_api_key", + "AWS_SECRET_ACCESS_KEY": "fake_api_key", + "AWS_REGION_NAME": "fake_api_key", + "AWS_PROFILE_NAME": "fake_profile", # Adding additional profile if needed + }, + ): # Now patch os.getenv to return mocked environment variable values with patch("os.getenv", side_effect=getenv_side_effect): self.client = BedrockAPIClient()