Skip to content

Commit

Permalink
fix(utils): get_bedrock_client region and profile
Browse files Browse the repository at this point in the history
Update utils.bedrock.get_bedrock_client() function so that an
explicitly passed `region` parameter takes precedence over defaults
from the environment variables - in keeping with CLI/boto3 norms.
Also fix the function so it actually uses the `AWS_PROFILE`, and
creates the initial `boto3.Session` with the same region and profile
as the final client.

Add standard license header, docstring comments, and lint.
  • Loading branch information
athewsey committed Jul 26, 2023
1 parent 47cc886 commit e626481
Showing 1 changed file with 53 additions and 29 deletions.
82 changes: 53 additions & 29 deletions utils/bedrock.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,81 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""Helper utilities for working with Amazon Bedrock from Python notebooks"""
# Python Built-Ins:
import json
import boto3
import os
from typing import Any, Dict, List, Optional
from pydantic import root_validator
from time import sleep
from enum import Enum
from typing import Dict, Optional

# External Dependencies:
import boto3
from pydantic import root_validator
from enum import Enum
from botocore.config import Config

def get_bedrock_client(assumed_role=None, region='us-east-1', url_override = None):
boto3_kwargs = {}
session = boto3.Session()

target_region = os.environ.get('AWS_DEFAULT_REGION',region)
def get_bedrock_client(
assumed_role: Optional[str] = None,
region: Optional[str] = None,
url_override: Optional[str] = None,
):
"""Create a boto3 client for Amazon Bedrock, with optional configuration overrides
Parameters
----------
assumed_role :
Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not
specified, the current active credentials will be used.
region :
Optional name of the AWS Region in which the service should be called (e.g. "us-east-1").
If not specified, AWS_REGION or AWS_DEFAULT_REGION environment variable will be used.
url_override :
Optional override for the Bedrock service API Endpoint. If setting this, it should usually
include the protocol i.e. "https://..."
"""
if region is None:
target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION"))
else:
target_region = region

print(f"Create new client\n Using region: {target_region}")
if 'AWS_PROFILE' in os.environ:
print(f" Using profile: {os.environ['AWS_PROFILE']}")
boto3_kwargs = {"region_name": target_region}

profile_name = os.environ.get("AWS_PROFILE")
if profile_name:
print(f" Using profile: {profile_name}")
boto3_kwargs["profile_name"] = profile_name

retry_config = Config(
region_name = target_region,
retries = {
'max_attempts': 10,
'mode': 'standard'
}
region_name=target_region,
retries={
"max_attempts": 10,
"mode": "standard",
},
)

boto3_kwargs = {}
session = boto3.Session(**boto3_kwargs)

if assumed_role:
print(f" Using role: {assumed_role}", end='')
sts = session.client("sts")
response = sts.assume_role(
RoleArn=str(assumed_role), #
RoleArn=str(assumed_role),
RoleSessionName="langchain-llm-1"
)
print(" ... successful!")
boto3_kwargs['aws_access_key_id']=response['Credentials']['AccessKeyId']
boto3_kwargs['aws_secret_access_key']=response['Credentials']['SecretAccessKey']
boto3_kwargs['aws_session_token']=response['Credentials']['SessionToken']
boto3_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"]
boto3_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"]
boto3_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"]

if url_override:
boto3_kwargs['endpoint_url']=url_override
boto3_kwargs["endpoint_url"] = url_override

bedrock_client = session.client(
service_name='bedrock',
service_name="bedrock",
config=retry_config,
region_name= target_region,
**boto3_kwargs
)
)

print("boto3 Bedrock client successfully created!")
print(bedrock_client._endpoint)
return bedrock_client
Expand Down Expand Up @@ -82,7 +108,7 @@ def validate_environment(cls, values: Dict) -> Dict:
bedrock_client = get_bedrock_client(assumed_role=None) #boto3.client("bedrock")
values["client"] = bedrock_client
return values

def generate_image(self, prompt: str, init_image: Optional[str] = None, **kwargs):
"""
Invoke Bedrock model to generate embeddings.
Expand Down Expand Up @@ -181,5 +207,3 @@ def _invoke_model(self, model_id: BedrockModel, body_string: str):
sleep(self.__RETRY_BACKOFF_SEC)
continue
return response


0 comments on commit e626481

Please sign in to comment.