diff --git a/src/cohere/manually_maintained/cohere_aws/client.py b/src/cohere/manually_maintained/cohere_aws/client.py index 31a9b2ec8..bacb3a97e 100644 --- a/src/cohere/manually_maintained/cohere_aws/client.py +++ b/src/cohere/manually_maintained/cohere_aws/client.py @@ -34,7 +34,9 @@ def __init__( """ self._client = boto3.client("sagemaker-runtime", region_name=aws_region) self._service_client = boto3.client("sagemaker", region_name=aws_region) - self._sess = sage.Session(boto3.session.Session()) + if os.environ.get('AWS_DEFAULT_REGION') is None: + os.environ['AWS_DEFAULT_REGION'] = aws_region + self._sess = sage.Session(sagemaker_client=self._service_client) self.mode = Mode.SAGEMAKER