diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 54c709c0..5f4281df 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -160,13 +160,13 @@ def validate(self, chat_request: ChatRequest):
                 detail=error,
             )
 
-    def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
+    def _invoke_bedrock(self, chat_request: ChatRequest, headers: dict, stream=False):
         """Common logic for invoke bedrock models"""
         if DEBUG:
             logger.info("Raw request: " + chat_request.model_dump_json())
 
         # convert OpenAI chat request to Bedrock SDK request
-        args = self._parse_request(chat_request)
+        args = self._parse_request(chat_request, headers)
         if DEBUG:
             logger.info("Bedrock request: " + json.dumps(str(args)))
 
@@ -183,11 +183,11 @@ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
             raise HTTPException(status_code=500, detail=str(e))
         return response
 
-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+    def chat(self, chat_request: ChatRequest, headers: dict) -> ChatResponse:
         """Default implementation for Chat API."""
 
         message_id = self.generate_message_id()
-        response = self._invoke_bedrock(chat_request)
+        response = self._invoke_bedrock(chat_request, headers)
 
         output_message = response["output"]["message"]
         input_tokens = response["usage"]["inputTokens"]
@@ -206,9 +206,9 @@ def chat(self, chat_request: ChatRequest) -> ChatResponse:
             logger.info("Proxy response :" + chat_response.model_dump_json())
         return chat_response
 
-    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+    def chat_stream(self, chat_request: ChatRequest, headers: dict) -> AsyncIterable[bytes]:
         """Default implementation for Chat Stream API"""
-        response = self._invoke_bedrock(chat_request, stream=True)
+        response = self._invoke_bedrock(chat_request, headers=headers, stream=True)
 
         message_id = self.generate_message_id()
         stream = response.get("stream")
@@ -390,7 +390,7 @@ def _reframe_multi_payloard(self, messages: list) -> list:
 
         return reformatted_messages
 
-    def _parse_request(self, chat_request: ChatRequest) -> dict:
+    def _parse_request(self, chat_request: ChatRequest, headers: dict) -> dict:
         """Create default converse request body.
 
         Also perform validations to tool call etc.
@@ -420,6 +420,13 @@ def _parse_request(self, chat_request: ChatRequest) -> dict:
             "system": system_prompts,
             "inferenceConfig": inference_config,
         }
+
+        # Pass headers through as metadata
+        args["requestMetadata"] = {
+            "X-Header-Value-0": headers.get("X-Header-Value-0", "unknown"),
+            "X-Header-Value-1": headers.get("X-Header-Value-1", "unknown"),
+        }
+
         # add tool config
         if chat_request.tools:
             args["toolConfig"] = {
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 1e48a483..75e96b05 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -1,4 +1,5 @@
 from typing import Annotated
+import logging
 
-from fastapi import APIRouter, Depends, Body
+from fastapi import APIRouter, Depends, Body, Request
 from fastapi.responses import StreamingResponse
@@ -8,6 +9,8 @@
 from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
 from api.setting import DEFAULT_MODEL
 
+logger = logging.getLogger(__name__)
+
 router = APIRouter(
     prefix="/chat",
     dependencies=[Depends(api_key_auth)],
@@ -17,6 +20,7 @@
 
 @router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
 async def chat_completions(
+    request: Request,
     chat_request: Annotated[
         ChatRequest,
         Body(
@@ -32,6 +36,10 @@ async def chat_completions(
         ),
     ]
 ):
+    # Log headers for security and analytics
+    headers = request.headers
+    logger.info(f"Headers: {headers}")
+
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL
 
@@ -40,6 +48,6 @@ async def chat_completions(
     model.validate(chat_request)
     if chat_request.stream:
         return StreamingResponse(
-            content=model.chat_stream(chat_request), media_type="text/event-stream"
+            content=model.chat_stream(chat_request, headers), media_type="text/event-stream"
         )
-    return model.chat(chat_request)
+    return model.chat(chat_request, headers)
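For context, a minimal client-side sketch of how a caller could exercise this change. It assumes the gateway is reachable at a hypothetical http://localhost:8000/api/v1 base URL with a placeholder API key, and that the OpenAI Python SDK is used as the client; the header names mirror the placeholder names used in the diff, and extra_headers is the SDK's standard per-request header option.

```python
from openai import OpenAI

# Hypothetical gateway endpoint and API key; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="bedrock")

# The two custom headers are read by the proxy and forwarded into the
# Bedrock Converse request as requestMetadata (see _parse_request above).
response = client.chat.completions.create(
    model="anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_headers={
        "X-Header-Value-0": "tenant-a",
        "X-Header-Value-1": "session-123",
    },
)
print(response.choices[0].message.content)
```

Headers not present on the request fall back to "unknown" in the metadata, per the .get() defaults in _parse_request.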