
Support for Bedrock Agents #89


Status: Open. This pull request wants to merge 4 commits into base: main.

1 change: 1 addition & 0 deletions src/api/app.py
@@ -42,6 +42,7 @@ async def health():
return {"status": "OK"}



@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return PlainTextResponse(str(exc), status_code=400)
132 changes: 69 additions & 63 deletions src/api/models/bedrock.py
@@ -5,6 +5,7 @@
import time
from abc import ABC
from typing import AsyncIterable, Iterable, Literal
from api.models.model_manager import ModelManager

import boto3
import numpy as np
@@ -75,83 +76,88 @@ def get_inference_region_prefix():

ENCODER = tiktoken.get_encoding("cl100k_base")

# Initialize the model list.
#bedrock_model_list = list_bedrock_models()

def list_bedrock_models() -> dict:
"""Automatically getting a list of supported models.
Returns a model list combines:
- ON_DEMAND models.
- Cross-Region Inference Profiles (if enabled via Env)
"""
model_list = {}
try:
profile_list = []
if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
response = bedrock_client.list_inference_profiles(
maxResults=1000,
typeEquals='SYSTEM_DEFINED'
)
profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]

# List foundation models, only cares about text outputs here.
response = bedrock_client.list_foundation_models(
byOutputModality='TEXT'
)

for model in response['modelSummaries']:
model_id = model.get('modelId', 'N/A')
stream_supported = model.get('responseStreamingSupported', True)
status = model['modelLifecycle'].get('status', 'ACTIVE')

# currently, use this to filter out rerank models and legacy models
if not stream_supported or status != "ACTIVE":
continue

inference_types = model.get('inferenceTypesSupported', [])
input_modalities = model['inputModalities']
# Add on-demand model list
if 'ON_DEMAND' in inference_types:
model_list[model_id] = {
'modalities': input_modalities
}

# Add cross-region inference model list.
profile_id = cr_inference_prefix + '.' + model_id
if profile_id in profile_list:
model_list[profile_id] = {
'modalities': input_modalities
}

except Exception as e:
logger.error(f"Unable to list models: {str(e)}")
class BedrockModel(BaseChatModel):

if not model_list:
# In case stack not updated.
model_list[DEFAULT_MODEL] = {
'modalities': ["TEXT", "IMAGE"]
}
#bedrock_model_list = None
model_manager = None
def __init__(self):
super().__init__()
self.model_manager = ModelManager()

return model_list
def list_bedrock_models(self) -> dict:
"""Automatically getting a list of supported models.
Returns a model list combines:
- ON_DEMAND models.
- Cross-Region Inference Profiles (if enabled via Env)
"""
#model_list = {}
try:
profile_list = []
if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
response = bedrock_client.list_inference_profiles(
maxResults=1000,
typeEquals='SYSTEM_DEFINED'
)
profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]

# Initialize the model list.
bedrock_model_list = list_bedrock_models()
# List foundation models, only cares about text outputs here.
response = bedrock_client.list_foundation_models(
byOutputModality='TEXT'
)

for model in response['modelSummaries']:
model_id = model.get('modelId', 'N/A')
stream_supported = model.get('responseStreamingSupported', True)
status = model['modelLifecycle'].get('status', 'ACTIVE')

# currently, use this to filter out rerank models and legacy models
if not stream_supported or status != "ACTIVE":
continue

inference_types = model.get('inferenceTypesSupported', [])
input_modalities = model['inputModalities']
# Add on-demand model list
if 'ON_DEMAND' in inference_types:
model[model_id] = {
'modalities': input_modalities
}
self.model_manager.add_model(model)
# model_list[model_id] = {
# 'modalities': input_modalities
# }

# Add cross-region inference model list.
profile_id = cr_inference_prefix + '.' + model_id
if profile_id in profile_list:
model[profile_id] = {
'modalities': input_modalities
}
self.model_manager.add_model(model)

class BedrockModel(BaseChatModel):
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))

def list_models(self) -> list[str]:
"""Always refresh the latest model list"""
global bedrock_model_list
bedrock_model_list = list_bedrock_models()
return list(bedrock_model_list.keys())
#global bedrock_model_list
self.list_bedrock_models()
return list(self.model_manager.get_all_models().keys())

def validate(self, chat_request: ChatRequest):
"""Perform basic validation on requests"""

error = ""

###### TODO - failing here as kb and agents are not in the bedrock_model_list
# check if model is supported
if chat_request.model not in bedrock_model_list.keys():
if chat_request.model not in self.model_manager.get_all_models().keys():
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"

if error:
@@ -659,7 +665,7 @@ def _parse_content_parts(

@staticmethod
def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
model = bedrock_model_list.get(model_id)
        model = ModelManager().get_all_models().get(model_id)
modalities = model.get('modalities', [])
if modality in modalities:
return True
@@ -851,4 +857,4 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
raise HTTPException(
status_code=400,
detail="Unsupported embedding model id " + model_id,
)
)
391 changes: 391 additions & 0 deletions src/api/models/bedrock_agents.py
@@ -0,0 +1,391 @@
import base64
import json
import logging
import re
import time
from abc import ABC
from typing import AsyncIterable

import boto3
from botocore.config import Config
import numpy as np
import requests
import tiktoken
from fastapi import HTTPException
from api.models.model_manager import ModelManager

from api.models.bedrock import (
BedrockModel,
bedrock_client,
bedrock_runtime)

from api.schema import (
ChatResponse,
ChatRequest,
ChatResponseMessage,
ChatStreamResponse,
ChoiceDelta
)

from api.setting import (DEBUG, AWS_REGION, DEFAULT_KB_MODEL, KB_PREFIX, AGENT_PREFIX)

logger = logging.getLogger(__name__)
config = Config(connect_timeout=1, read_timeout=120, retries={"max_attempts": 1})

bedrock_agent = boto3.client(
service_name="bedrock-agent",
region_name=AWS_REGION,
config=config,
)

bedrock_agent_runtime = boto3.client(
service_name="bedrock-agent-runtime",
region_name=AWS_REGION,
config=config,
)


class BedrockAgents(BedrockModel):

#bedrock_model_list = None
def __init__(self):
super().__init__()
model_manager = ModelManager()

def list_models(self) -> list[str]:
"""Always refresh the latest model list"""
super().list_models()
self.get_kbs()
self.get_agents()
return list(self.model_manager.get_all_models().keys())

# get list of active knowledge bases
def get_kbs(self):

# List knowledge bases
response = bedrock_agent.list_knowledge_bases(maxResults=100)

# Print knowledge base information
for kb in response['knowledgeBaseSummaries']:
name = f"{KB_PREFIX}{kb['name']}"
val = {
"system": True, # Supports system prompts for context setting
"multimodal": True, # Capable of processing both text and images
"tool_call": True,
"stream_tool_call": True,
"kb_id": kb['knowledgeBaseId'],
"model_id": DEFAULT_KB_MODEL
}
#self.model_manager.get_all_models()[name] = val
model = {}
model[name]=val
self.model_manager.add_model(model)

def get_latest_agent_alias(self, client, agent_id):

# List all aliases for the agent
response = client.list_agent_aliases(
agentId=agent_id,
maxResults=100 # Adjust based on your needs
)

if not response.get('agentAliasSummaries'):
return None

# Sort aliases by creation time to get the latest one
aliases = response['agentAliasSummaries']
latest_alias = None
latest_creation_time = None

for alias in aliases:
# Only consider aliases that are in PREPARED state
if alias['agentAliasStatus'] == 'PREPARED':
creation_time = alias.get('creationDateTime')
if latest_creation_time is None or creation_time > latest_creation_time:
latest_creation_time = creation_time
latest_alias = alias

if latest_alias:
return latest_alias['agentAliasId']

return None

def get_agents(self):
bedrock_ag = boto3.client(
service_name="bedrock-agent",
region_name=AWS_REGION,
config=config,
)
# List Agents
response = bedrock_agent.list_agents(maxResults=100)

# Prepare agent for display
for agent in response['agentSummaries']:

if (agent['agentStatus'] != 'PREPARED'):
continue

name = f"{AGENT_PREFIX}{agent['agentName']}"
agentId = agent['agentId']

aliasId = self.get_latest_agent_alias(bedrock_ag, agentId)
if (aliasId is None):
continue

val = {
"system": False, # Supports system prompts for context setting. These are already set in Bedrock Agent configuration
"multimodal": True, # Capable of processing both text and images
"tool_call": False, # Tool Use not required for Agents
"stream_tool_call": False,
"agent_id": agentId,
"alias_id": aliasId
}
#self.model_manager.get_all_models()[name] = val
model = {}
model[name]=val
self.model_manager.add_model(model)


def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
"""Common logic for invoke bedrock models"""

# convert OpenAI chat request to Bedrock SDK request
args = self._parse_request(chat_request)
if DEBUG:
logger.info("Bedrock request: " + json.dumps(str(args)))

try:

if stream:
response = bedrock_runtime.converse_stream(**args)
else:
response = bedrock_runtime.converse(**args)


except bedrock_client.exceptions.ValidationException as e:
logger.error("Validation Error: " + str(e))
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
return response

def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Default implementation for Chat API."""
#chat: {chat_request}")

message_id = self.generate_message_id()
response = self._invoke_bedrock(chat_request)

output_message = response["output"]["message"]
input_tokens = response["usage"]["inputTokens"]
output_tokens = response["usage"]["outputTokens"]
finish_reason = response["stopReason"]

chat_response = self._create_response(
model=chat_request.model,
message_id=message_id,
content=output_message["content"],
finish_reason=finish_reason,
input_tokens=input_tokens,
output_tokens=output_tokens,
)
if DEBUG:
logger.info("Proxy response :" + chat_response.model_dump_json())
return chat_response

def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:

"""Default implementation for Chat Stream API"""

response = ''
message_id = self.generate_message_id()

if (chat_request.model.startswith(KB_PREFIX)):
response = self._invoke_kb(chat_request, stream=True)
elif (chat_request.model.startswith(AGENT_PREFIX)):
response = self._invoke_agent(chat_request, stream=True)

_event_stream = response["completion"]

chunk_count = 1
message = ChatResponseMessage(
role="assistant",
content="",
)
stream_response = ChatStreamResponse(
id=message_id,
model=chat_request.model,
choices=[
ChoiceDelta(
index=0,
delta=message,
logprobs=None,
finish_reason=None,
)
],
usage=None,
)
yield self.stream_response_to_bytes(stream_response)

for event in _event_stream:
chunk_count += 1
if "chunk" in event:
_data = event["chunk"]["bytes"].decode("utf8")
message = ChatResponseMessage(content=_data)

stream_response = ChatStreamResponse(
id=message_id,
model=chat_request.model,
choices=[
ChoiceDelta(
index=0,
delta=message,
logprobs=None,
finish_reason=None,
)
],
usage=None,
)
yield self.stream_response_to_bytes(stream_response)

#message = self._make_fully_cited_answer(_data, event, False, 0)

# return an [DONE] message at the end.
yield self.stream_response_to_bytes()
return None
else:
response = self._invoke_bedrock(chat_request, stream=True)

stream = response.get("stream")
for chunk in stream:
stream_response = self._create_response_stream(
model_id=chat_request.model, message_id=message_id, chunk=chunk
)
if not stream_response:
continue
if DEBUG:
logger.info("Proxy response :" + stream_response.model_dump_json())
if stream_response.choices:
yield self.stream_response_to_bytes(stream_response)
elif (
chat_request.stream_options
and chat_request.stream_options.include_usage
):
# An empty choices for Usage as per OpenAI doc below:
# if you set stream_options: {"include_usage": true}.
# an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire request,
# and the choices field will always be an empty array.
# All other chunks will also include a usage field, but with a null value.
yield self.stream_response_to_bytes(stream_response)

# return an [DONE] message at the end.
yield self.stream_response_to_bytes()



    # This function invokes the knowledge base
def _invoke_kb(self, chat_request: ChatRequest, stream=False):
"""Common logic for invoke kb with default model"""
if DEBUG:
logger.info("BedrockAgents._invoke_kb: Raw request: " + chat_request.model_dump_json())

# convert OpenAI chat request to Bedrock SDK request
args = self._parse_request(chat_request)


if DEBUG:
logger.info("Bedrock request: " + json.dumps(str(args)))

model = self.model_manager.get_all_models()[chat_request.model]
args['modelId'] = model['model_id']


################

try:
query = args['messages'][0]['content'][0]['text']
messages = args['messages']
query = messages[len(messages)-1]['content'][0]['text']

# Step 1 - Retrieve Context
retrieval_request_body = {
"retrievalQuery": {
"text": query
},
"retrievalConfiguration": {
"vectorSearchConfiguration": {
"numberOfResults": 2
}
}
}

# Make the retrieve request
response = bedrock_agent_runtime.retrieve(knowledgeBaseId=model['kb_id'], **retrieval_request_body)

# Extract and return the results
context = ''
if "retrievalResults" in response:
for result in response["retrievalResults"]:
result = result["content"]["text"]
context = f"{context}\n{result}"


# Step 2 - Append context in the prompt
args['messages'][0]['content'][0]['text'] = f"Context: {context} \n\n {query}"

# Step 3 - Make the converse request
if stream:
response = bedrock_runtime.converse_stream(**args)
else:
response = bedrock_runtime.converse(**args)

except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))

###############
return response

    # This function invokes the agent
def _invoke_agent(self, chat_request: ChatRequest, stream=False):
"""Common logic for invoke agent """
if DEBUG:
logger.info("BedrockAgents._invoke_agent: Raw request: " + chat_request.model_dump_json())

# convert OpenAI chat request to Bedrock SDK request
args = self._parse_request(chat_request)


if DEBUG:
logger.info("Bedrock request: " + json.dumps(str(args)))

model = self.model_manager.get_all_models()[chat_request.model]

################

try:
query = args['messages'][0]['content'][0]['text']
messages = args['messages']
query = messages[len(messages)-1]['content'][0]['text']


# Step 1 - Retrieve Context
request_params = {
'agentId': model['agent_id'],
'agentAliasId': model['alias_id'],
'sessionId': 'unique-session-id', # Generate a unique session ID
'inputText': query
}

# Make the retrieve request
# Invoke the agent
response = bedrock_agent_runtime.invoke_agent(**request_params)
return response

except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
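
To illustrate how the new model types are meant to be consumed, here is a sketch of a client call through the gateway's OpenAI-compatible endpoint. The base URL, API key, and the model names kb-docs / ag-support are assumptions for illustration; streaming is used because chat_stream() is where the KB_PREFIX / AGENT_PREFIX branches above are wired in.

from openai import OpenAI

# Assumed deployment values; adjust to your gateway URL and API key.
client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="bedrock")

# "kb-docs" assumes a knowledge base named "docs" picked up by get_kbs();
# "ag-support" would be an agent named "support" picked up by get_agents().
stream = client.chat.completions.create(
    model="kb-docs",
    messages=[{"role": "user", "content": "What does the runbook say about failover?"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")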


35 changes: 35 additions & 0 deletions src/api/models/model_manager.py
@@ -0,0 +1,35 @@
# This is a singleton class to maintain the list of models
class ModelManager:
_instance = None
_models = None

def __new__(cls, *args, **kwargs):
# Ensure that only one instance of ModelManager is created
if cls._instance is None:
cls._instance = super(ModelManager, cls).__new__(cls, *args, **kwargs)
cls._instance._models = {} # Initialize the list of models

return cls._instance

def get_all_models(self):
return self._models

def add_model(self, model):
"""Add a model to the list."""
if (self._models is None):
self._models = {}
self._models.update(model)


def get_model_by_name(self, model_name: str):
"""Get the list of models."""
return self._models

def clear_models(self):
"""Clear the list of models."""
self._models.clear()
self._models = {}

def __repr__(self):
return f"ModelManager(models={self._models})"

17 changes: 11 additions & 6 deletions src/api/routers/chat.py
@@ -2,12 +2,15 @@

from fastapi import APIRouter, Depends, Body
from fastapi.responses import StreamingResponse

import logging
from api.auth import api_key_auth
from api.models.bedrock import BedrockModel
#from api.models.bedrock import BedrockModel
from api.models.bedrock_agents import BedrockAgents
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
from api.setting import DEFAULT_MODEL

logger = logging.getLogger(__name__)

router = APIRouter(
prefix="/chat",
dependencies=[Depends(api_key_auth)],
@@ -32,14 +35,16 @@ async def chat_completions(
),
]
):
# this method gets called by front-end

if chat_request.model.lower().startswith("gpt-"):
chat_request.model = DEFAULT_MODEL

# Exception will be raised if model not supported.
model = BedrockModel()
model = BedrockAgents()
model.validate(chat_request)
if chat_request.stream:
return StreamingResponse(
content=model.chat_stream(chat_request), media_type="text/event-stream"
)
response = StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
return response

return model.chat(chat_request)
2 changes: 1 addition & 1 deletion src/api/routers/embeddings.py
@@ -3,7 +3,7 @@
from fastapi import APIRouter, Depends, Body

from api.auth import api_key_auth
from api.models.bedrock import get_embeddings_model
#from api.models.bedrock import get_embeddings_model
from api.schema import EmbeddingsRequest, EmbeddingsResponse
from api.setting import DEFAULT_EMBEDDING_MODEL

14 changes: 10 additions & 4 deletions src/api/routers/model.py
@@ -1,27 +1,32 @@
from typing import Annotated

import logging
from fastapi import APIRouter, Depends, HTTPException, Path

from api.auth import api_key_auth
from api.models.bedrock import BedrockModel
#from api.models.bedrock import BedrockModel
from api.models.bedrock_agents import BedrockAgents
from api.schema import Models, Model
logger = logging.getLogger(__name__)


router = APIRouter(
prefix="/models",
dependencies=[Depends(api_key_auth)],
# responses={404: {"description": "Not found"}},
)

chat_model = BedrockModel()

#chat_model = BedrockModel()
chat_model = BedrockAgents()

async def validate_model_id(model_id: str):
logger.info(f"validate_model_id: {model_id}")
if model_id not in chat_model.list_models():
raise HTTPException(status_code=500, detail="Unsupported Model Id")


@router.get("", response_model=Models)
async def list_models():

model_list = [
Model(id=model_id) for model_id in chat_model.list_models()
]
@@ -38,5 +43,6 @@ async def get_model(
Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
]
):
logger.info(f"get_model: {model_id}")
await validate_model_id(model_id)
return Model(id=model_id)
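
For completeness, a small sketch of exercising the updated /models routes; the host, port, and API key are placeholders for whatever the deployment uses.

import requests

BASE_URL = "http://localhost:8000/api/v1"      # placeholder gateway endpoint
HEADERS = {"Authorization": "Bearer bedrock"}  # placeholder API key

# After this PR the listing should also contain kb-* and ag-* entries.
resp = requests.get(f"{BASE_URL}/models", headers=HEADERS)
print(resp.status_code, resp.json())

# Fetch a single model by id.
resp = requests.get(f"{BASE_URL}/models/anthropic.claude-3-haiku-20240307-v1:0", headers=HEADERS)
print(resp.status_code, resp.json())
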
10 changes: 10 additions & 0 deletions src/api/setting.py
@@ -20,3 +20,13 @@
"DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
)
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"

KB_PREFIX = 'kb-'
AGENT_PREFIX = 'ag-'

DEFAULT_KB_MODEL = os.environ.get(
"DEFAULT_KB_MODEL", "anthropic.claude-3-haiku-20240307-v1:0"
)


DEFAULT_KB_MODEL_ARN = f'arn:aws:bedrock:{AWS_REGION}::foundation-model/{DEFAULT_KB_MODEL}'

Review comment on src/api/setting.py:

Can you allow this to be overridden? Hardcoding arn:aws: is an issue for us in other partitions.
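
One way this could be addressed (a sketch, not part of the PR): read the full ARN from an environment variable and only fall back to constructing it, with the partition itself configurable. The names AWS_PARTITION and DEFAULT_KB_MODEL_ARN used below are suggestions.

# In src/api/setting.py (sketch)
AWS_PARTITION = os.environ.get("AWS_PARTITION", "aws")  # e.g. "aws-us-gov" or "aws-cn"
DEFAULT_KB_MODEL_ARN = os.environ.get(
    "DEFAULT_KB_MODEL_ARN",
    f"arn:{AWS_PARTITION}:bedrock:{AWS_REGION}::foundation-model/{DEFAULT_KB_MODEL}",
)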