
Commit

feat: added predefined code to truncate prompt errors (#142)
adubovik authored Sep 4, 2024
1 parent 72de392 commit 664101b
Showing 16 changed files with 192 additions and 263 deletions.
27 changes: 10 additions & 17 deletions aidial_adapter_openai/app.py
@@ -2,10 +2,11 @@
from typing import Awaitable, TypeVar

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.telemetry.init import init_telemetry
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from fastapi.responses import Response
from openai import APIConnectionError, APIStatusError, APITimeoutError

from aidial_adapter_openai.completions import chat_completion as completion
@@ -34,7 +35,6 @@
chat_completion as mistral_chat_completion,
)
from aidial_adapter_openai.utils.auth import get_credentials
from aidial_adapter_openai.utils.errors import dial_exception_to_json_error
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.parsers import (
@@ -75,16 +75,16 @@ async def handle_exceptions(call: Awaitable[T]) -> T | Response:
)
except APITimeoutError:
raise DialException(
"Request timed out",
504,
"timeout",
status_code=504,
type="timeout",
message="Request timed out",
display_message="Request timed out. Please try again later.",
)
except APIConnectionError:
raise DialException(
"Error communicating with OpenAI",
502,
"connection",
status_code=502,
type="connection",
message="Error communicating with OpenAI",
display_message="OpenAI server is not responsive. Please try again later.",
)

@@ -94,11 +94,7 @@ def get_api_version(request: Request):
api_version = API_VERSIONS_MAPPING.get(api_version, api_version)

if api_version == "":
raise DialException(
"api-version is a required query parameter",
400,
"invalid_request_error",
)
raise InvalidRequestError("api-version is a required query parameter")

return api_version

@@ -220,10 +216,7 @@ async def embedding(deployment_id: str, request: Request):

@app.exception_handler(DialException)
def exception_handler(request: Request, exc: DialException):
return JSONResponse(
status_code=exc.status_code,
content=dial_exception_to_json_error(exc),
)
return exc.to_fastapi_response()


@app.get("/health")
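The app.py changes swap hand-rolled DialException constructions for the SDK's predefined exception classes and let the FastAPI handler delegate serialization to the exception itself. Below is a minimal sketch of the before/after pattern, assuming the aidial-sdk package exporting these classes is installed and that InvalidRequestError fixes the 400 status and "invalid_request_error" type that the replaced call spelled out by hand:

    from aidial_sdk.exceptions import HTTPException as DialException
    from aidial_sdk.exceptions import InvalidRequestError

    # Before: status code and error type repeated at every call site.
    legacy = DialException(
        "api-version is a required query parameter",
        400,
        "invalid_request_error",
    )

    # After: the predefined class carries the status code and type, so the
    # call site only supplies the message.
    modern = InvalidRequestError("api-version is a required query parameter")

With the exception handler returning exc.to_fastapi_response(), the adapter no longer needs its own JSON converter, which is why the dial_exception_to_json_error import and the utils/errors.py module disappear in this commit.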
14 changes: 3 additions & 11 deletions aidial_adapter_openai/completions.py
@@ -1,6 +1,6 @@
from typing import Any, Dict

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from openai import AsyncStream
from openai.types import Completion
@@ -51,21 +51,13 @@ async def chat_completion(
) -> Any:

if data.get("n", 1) > 1: # type: ignore
raise DialException(
status_code=422,
message="The deployment doesn't support n > 1",
type="invalid_request_error",
)
raise RequestValidationError("The deployment doesn't support n > 1")

client = endpoint.get_client({**creds, "api_version": api_version})

messages = data.get("messages", [])
if len(messages) == 0:
raise DialException(
status_code=422,
message="The request doesn't contain any messages",
type="invalid_request_error",
)
raise RequestValidationError("The request doesn't contain any messages")

prompt = messages[-1].get("content") or ""

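For the 422-style validation failures, completions.py now raises RequestValidationError. A small sketch of the resulting guards, assuming the class keeps the 422 / "invalid_request_error" mapping of the code it replaces; the function name and payloads below are illustrative only:

    from typing import Any, Dict

    from aidial_sdk.exceptions import RequestValidationError


    def validate_completions_request(data: Dict[str, Any]) -> None:
        # Mirrors the two checks in completions.py after the change.
        if data.get("n", 1) > 1:
            raise RequestValidationError("The deployment doesn't support n > 1")
        if not data.get("messages"):
            raise RequestValidationError("The request doesn't contain any messages")


    validate_completions_request({"messages": [{"role": "user", "content": "hi"}]})  # passes
    # validate_completions_request({"n": 2, "messages": []})  # raises RequestValidationError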
46 changes: 19 additions & 27 deletions aidial_adapter_openai/dalle3.py
@@ -1,8 +1,8 @@
from typing import Any, AsyncIterator, Optional

import aiohttp
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.utils.errors import json_error
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
@@ -42,15 +42,13 @@ async def generate_image(
]:
error["code"] = "content_filter"

return JSONResponse(
content=json_error(
message=error.get("message"),
type=error.get("type"),
param=error.get("param"),
code=error.get("code"),
),
return DIALException(
status_code=status_code,
)
message=error.get("message"),
type=error.get("type"),
param=error.get("param"),
code=error.get("code"),
).to_fastapi_response()
else:
return JSONResponse(content=data, status_code=status_code)

@@ -75,18 +73,16 @@ async def generate_stream(
yield build_chunk(id, "stop", {}, created, True, usage=IMG_USAGE)


def get_user_prompt(data: Any):
if (
"messages" not in data
or len(data["messages"]) == 0
or "content" not in data["messages"][-1]
or not data["messages"][-1]
):
raise DialException(
"Your request is invalid", 400, "invalid_request_error"
)

return data["messages"][-1]["content"]
def get_user_prompt(data: Any) -> str:
try:
prompt = data["messages"][-1]["content"]
if not isinstance(prompt, str):
raise ValueError("Content isn't a string")
return prompt
except Exception as e:
raise RequestValidationError(
"Invalid request. Expected a string at path 'messages[-1].content'."
) from e


async def move_attachments_data_to_storage(
@@ -117,11 +113,7 @@ async def chat_completion(
api_version: str,
) -> Response:
if data.get("n", 1) > 1:
raise DialException(
status_code=422,
message="The deployment doesn't support n > 1",
type="invalid_request_error",
)
raise RequestValidationError("The deployment doesn't support n > 1")

api_url = f"{upstream_endpoint}?api-version={api_version}"
user_prompt = get_user_prompt(data)
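In dalle3.py the upstream error payload is no longer serialized by hand with json_error; it is repackaged as a DIALException and rendered through to_fastapi_response(). The same file rewrites get_user_prompt() so that a missing key, an empty message list, or non-string content all fail with one RequestValidationError. A sketch of the error-mapping path, with a made-up upstream payload:

    from aidial_sdk.exceptions import HTTPException as DIALException

    # Illustrative upstream error body; the real values come from the DALL-E
    # endpoint's JSON response.
    error = {
        "message": "Your prompt was flagged by the content filter",
        "type": "invalid_request_error",
        "code": "content_filter",
    }

    response = DIALException(
        status_code=400,
        message=error.get("message"),
        type=error.get("type"),
        param=error.get("param"),
        code=error.get("code"),
    ).to_fastapi_response()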
14 changes: 5 additions & 9 deletions aidial_adapter_openai/gpt.py
@@ -1,4 +1,4 @@
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from fastapi.responses import StreamingResponse
from openai import AsyncStream
from openai.types.chat.chat_completion import ChatCompletion
@@ -14,7 +14,7 @@
generate_stream,
map_stream,
)
from aidial_adapter_openai.utils.tokens import Tokenizer, discard_messages
from aidial_adapter_openai.utils.tokens import Tokenizer, truncate_prompt


async def gpt_chat_completion(
@@ -29,20 +29,16 @@ async def gpt_chat_completion(
if "max_prompt_tokens" in data:
max_prompt_tokens = data["max_prompt_tokens"]
if not isinstance(max_prompt_tokens, int):
raise DialException(
raise InvalidRequestError(
f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'",
400,
"invalid_request_error",
)
if max_prompt_tokens < 1:
raise DialException(
raise InvalidRequestError(
f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'",
400,
"invalid_request_error",
)
del data["max_prompt_tokens"]

data["messages"], discarded_messages = discard_messages(
data["messages"], discarded_messages = truncate_prompt(
tokenizer, data["messages"], max_prompt_tokens
)

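gpt.py renames discard_messages to truncate_prompt and moves the max_prompt_tokens validation onto InvalidRequestError. The adapter's own truncation code is not shown in this diff; the sketch below only illustrates the contract visible at the call site (fit the messages into the token budget and report which original indices were dropped). The function name, the word-count "tokenizer", and the messages are illustrative, not the repository's Tokenizer:

    from typing import Callable, Dict, List, Tuple

    Message = Dict[str, str]


    def truncate_prompt_sketch(
        count_tokens: Callable[[List[Message]], int],
        messages: List[Message],
        max_prompt_tokens: int,
    ) -> Tuple[List[Message], List[int]]:
        # Drop the oldest non-system messages until the prompt fits the budget,
        # remembering the original indices of everything that was discarded.
        kept = list(enumerate(messages))
        discarded: List[int] = []

        while count_tokens([m for _, m in kept]) > max_prompt_tokens:
            drop_pos = next(
                (pos for pos, (_, m) in enumerate(kept) if m.get("role") != "system"),
                None,
            )
            if drop_pos is None:
                break  # only system messages remain; nothing left to drop
            discarded.append(kept[drop_pos][0])
            del kept[drop_pos]

        return [m for _, m in kept], discarded


    def count_words(msgs: List[Message]) -> int:
        return sum(len(m["content"].split()) for m in msgs)


    history = [
        {"role": "system", "content": "be brief"},
        {"role": "user", "content": "an old question with many extra words"},
        {"role": "user", "content": "new question"},
    ]
    print(truncate_prompt_sketch(count_words, history, 6))
    # -> keeps the system message and "new question", reports [1] as discarded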
20 changes: 4 additions & 16 deletions aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
@@ -12,6 +12,7 @@

import aiohttp
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError, RequestValidationError
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aidial_adapter_openai.gpt4_multi_modal.download import (
@@ -171,19 +172,11 @@ async def chat_completion(
) -> Response:

if request.get("n", 1) > 1:
raise DialException(
status_code=422,
message="The deployment doesn't support n > 1",
type="invalid_request_error",
)
raise RequestValidationError("The deployment doesn't support n > 1")

messages: List[Any] = request["messages"]
if len(messages) == 0:
raise DialException(
status_code=422,
message="The request doesn't contain any messages",
type="invalid_request_error",
)
raise RequestValidationError("The request doesn't contain any messages")

api_url = f"{upstream_endpoint}?api-version={api_version}"

@@ -194,12 +187,7 @@

chunk = create_stage_chunk("Usage", USAGE, is_stream)

exc = DialException(
status_code=400,
message=result,
display_message=result,
type="invalid_request_error",
)
exc = InvalidRequestError(message=result, display_message=result)

return create_response_from_chunk(chunk, exc, is_stream)

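In the multi-modal adapter the usage-stage failure collapses into a single InvalidRequestError carrying both message and display_message, as the hunk above shows. A short sketch, assuming message targets developers while display_message is the text surfaced to end users by DIAL clients; the error text here is illustrative:

    from aidial_sdk.exceptions import InvalidRequestError

    result = "The image size exceeds the supported limit"  # illustrative text
    exc = InvalidRequestError(message=result, display_message=result)

    # In the adapter, exc is then passed to create_response_from_chunk(),
    # which turns it into a streamed error chunk or a plain error response
    # depending on is_stream.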
6 changes: 1 addition & 5 deletions aidial_adapter_openai/utils/auth.py
@@ -38,11 +38,7 @@ async def get_api_key() -> str:
logger.error(
f"Default Azure credential failed with the error: {e.message}"
)
raise DialException(
"Authentication failed",
401,
"Unauthorized",
)
raise DialException("Authentication failed", 401, "Unauthorized")

return access_token.token

12 changes: 0 additions & 12 deletions aidial_adapter_openai/utils/errors.py

This file was deleted.

16 changes: 5 additions & 11 deletions aidial_adapter_openai/utils/parsers.py
@@ -3,7 +3,7 @@
from json import JSONDecodeError
from typing import Any, Dict, List, TypedDict

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError
from fastapi import Request
from openai import AsyncAzureOpenAI, AsyncOpenAI, Timeout
from pydantic import BaseModel
@@ -80,9 +80,7 @@ class EndpointParser(BaseModel):
def parse(self, endpoint: str) -> AzureOpenAIEndpoint | OpenAIEndpoint:
if result := _parse_endpoint(self.name, endpoint):
return result
raise DialException(
"Invalid upstream endpoint format", 400, "invalid_request_error"
)
raise InvalidRequestError("Invalid upstream endpoint format")


class CompletionsParser(BaseModel):
Expand All @@ -104,16 +102,12 @@ async def parse_body(request: Request) -> Dict[str, Any]:
try:
data = await request.json()
except JSONDecodeError as e:
raise DialException(
"Your request contained invalid JSON: " + str(e),
400,
"invalid_request_error",
raise InvalidRequestError(
"Your request contained invalid JSON: " + str(e)
)

if not isinstance(data, dict):
raise DialException(
str(data) + " is not of type 'object'", 400, "invalid_request_error"
)
raise InvalidRequestError(str(data) + " is not of type 'object'")

return data

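parsers.py keeps the same request-body validation but raises the predefined InvalidRequestError. A sketch of the resulting behaviour, assuming the exception maps to a 400 / "invalid_request_error" response as the replaced code did; parse_body_sketch is an illustrative synchronous stand-in for the async parse_body, and the payloads are made up:

    import json

    from aidial_sdk.exceptions import InvalidRequestError


    def parse_body_sketch(raw_body: str) -> dict:
        try:
            data = json.loads(raw_body)
        except json.JSONDecodeError as e:
            raise InvalidRequestError("Your request contained invalid JSON: " + str(e))
        if not isinstance(data, dict):
            raise InvalidRequestError(str(data) + " is not of type 'object'")
        return data


    parse_body_sketch('{"messages": []}')   # -> {'messages': []}
    # parse_body_sketch('[1, 2, 3]')        # -> 400: "[1, 2, 3] is not of type 'object'"
    # parse_body_sketch('not json')         # -> 400: "Your request contained invalid JSON: ..."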
8 changes: 3 additions & 5 deletions aidial_adapter_openai/utils/reflection.py
@@ -2,7 +2,7 @@
import inspect
from typing import Any, Callable, Coroutine, TypeVar

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import InvalidRequestError


@functools.lru_cache(maxsize=64)
Expand All @@ -27,10 +27,8 @@ async def call_with_extra_body(
extra_args = actual_args - expected_args

if extra_args and "extra_body" not in expected_args:
raise DialException(
f"Extra arguments aren't supported: {extra_args}.",
400,
"invalid_request_error",
raise InvalidRequestError(
f"Extra arguments aren't supported: {extra_args}."
)

arg["extra_body"] = arg.get("extra_body") or {}
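reflection.py applies the same substitution inside call_with_extra_body. Only part of that function is visible in this diff, so the sketch below captures just the visible idea: unexpected request fields are either folded into extra_body when the target accepts it, or rejected with the predefined 400 error. The helper name is illustrative, not the adapter's function:

    import inspect
    from typing import Any, Dict

    from aidial_sdk.exceptions import InvalidRequestError


    async def call_with_extra_body_sketch(func, arg: Dict[str, Any]) -> Any:
        expected_args = set(inspect.signature(func).parameters)
        extra_args = set(arg) - expected_args

        if extra_args and "extra_body" not in expected_args:
            raise InvalidRequestError(f"Extra arguments aren't supported: {extra_args}.")

        if extra_args:
            arg = dict(arg)
            extra_body = {name: arg.pop(name) for name in extra_args}
            arg["extra_body"] = {**(arg.get("extra_body") or {}), **extra_body}

        return await func(**arg)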
16 changes: 6 additions & 10 deletions aidial_adapter_openai/utils/sse_stream.py
@@ -1,7 +1,7 @@
import json
from typing import Any, AsyncIterator, Mapping

from aidial_sdk.utils.errors import json_error
from aidial_sdk.exceptions import runtime_server_error

DATA_PREFIX = "data: "
OPENAI_END_MARKER = "[DONE]"
@@ -24,18 +24,16 @@ async def parse_openai_sse_stream(
try:
payload = line.decode("utf-8-sig").lstrip() # type: ignore
except Exception:
yield json_error(
message="Can't decode chunk to a string", type="runtime_error"
)
yield runtime_server_error(
"Can't decode chunk to a string"
).json_error()
return

if payload.strip() == "":
continue

if not payload.startswith(DATA_PREFIX):
yield json_error(
message="Invalid chunk format", type="runtime_error"
)
yield runtime_server_error("Invalid chunk format").json_error()
return

payload = payload[len(DATA_PREFIX) :]
@@ -46,9 +44,7 @@
try:
chunk = json.loads(payload)
except json.JSONDecodeError:
yield json_error(
message="Can't parse chunk to JSON", type="runtime_error"
)
yield runtime_server_error("Can't parse chunk to JSON").json_error()
return

yield chunk
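sse_stream.py replaces the aidial_sdk.utils.errors.json_error helper with the SDK's runtime_server_error() factory and yields the dict built by its json_error() method whenever a chunk cannot be decoded or parsed. A sketch of the shape involved; the commented-out envelope is an assumption about the SDK's output, not verified here:

    from aidial_sdk.exceptions import runtime_server_error

    # Emitted into the SSE stream when a chunk cannot be decoded or parsed.
    error_chunk = runtime_server_error("Can't parse chunk to JSON").json_error()

    # Approximate shape (assumption):
    # {"error": {"message": "Can't parse chunk to JSON", "type": "runtime_error", ...}}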

