Skip to content

Commit

Permalink
Add output type schema to InvokeResponse and BatchResponse (#23)
Browse files Browse the repository at this point in the history
Add output type information to invoke and batch responses
  • Loading branch information
eyurtsev authored Oct 10, 2023
1 parent 54c89a2 commit 90ec2e1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 37 deletions.
63 changes: 26 additions & 37 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
Any,
AsyncIterator,
Dict,
List,
Literal,
Mapping,
Sequence,
Expand All @@ -24,7 +23,9 @@
from langserve.serialization import simple_dumpd, simple_dumps
from langserve.validation import (
create_batch_request_model,
create_batch_response_model,
create_invoke_request_model,
create_invoke_response_model,
create_stream_log_request_model,
create_stream_request_model,
)
Expand All @@ -42,32 +43,6 @@ def _unpack_config(d: Union[BaseModel, Mapping], keys: Sequence[str]) -> Dict[st
return {k: _d[k] for k in keys if k in _d}


class InvokeResponse(BaseModel):
"""Response from invoking a runnable.
A container is used to allow adding additional fields in the future.
"""

output: Any
"""The output of the runnable.
An object that can be serialized to JSON using LangChain serialization.
"""


class BatchResponse(BaseModel):
"""Response from batch invoking runnables.
A container is used to allow adding additional fields in the future.
"""

output: List[Any]
"""The output of the runnable.
An object that can be serialized to JSON using LangChain serialization.
"""


def _unpack_input(validated_model: BaseModel) -> Any:
"""Unpack the decoded input from the validated model."""
if hasattr(validated_model, "__root__"):
Expand All @@ -85,19 +60,23 @@ def _unpack_input(validated_model: BaseModel) -> Any:
return model


# This is a global registry of models to avoid creating the same model
# multiple times.
# Duplicated model names break fastapi's openapi generation.
_MODEL_REGISTRY = {}


def _resolve_input_type(input_type: Union[Type, BaseModel]) -> BaseModel:
if isclass(input_type) and issubclass(input_type, BaseModel):
input_type_ = input_type
def _resolve_model(type_: Union[Type, BaseModel], default_name: str) -> BaseModel:
"""Resolve the input type to a BaseModel."""
if isclass(type_) and issubclass(type_, BaseModel):
model = type_
else:
input_type_ = create_model("Input", __root__=(input_type, ...))
model = create_model(default_name, __root__=(type_, ...))

hash_ = input_type_.schema_json()
hash_ = model.schema_json()

if hash_ not in _MODEL_REGISTRY:
_MODEL_REGISTRY[hash_] = input_type_
_MODEL_REGISTRY[hash_] = model

return _MODEL_REGISTRY[hash_]

Expand Down Expand Up @@ -144,6 +123,7 @@ def add_routes(
*,
path: str = "",
input_type: Union[Type, Literal["auto"], BaseModel] = "auto",
output_type: Union[Type, Literal["auto"], BaseModel] = "auto",
config_keys: Sequence[str] = (),
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
Expand All @@ -155,6 +135,9 @@ def add_routes(
input_type: type to use for input validation.
Default is "auto" which will use the InputType of the runnable.
User is free to provide a custom type annotation.
output_type: type to use for output validation.
Default is "auto" which will use the OutputType of the runnable.
User is free to provide a custom type annotation.
config_keys: list of config keys that will be accepted, by default
no config keys are accepted.
"""
Expand All @@ -167,10 +150,13 @@ def add_routes(
"Use `pip install sse_starlette` to install."
)

if input_type == "auto":
input_type_ = _resolve_input_type(runnable.input_schema)
else:
input_type_ = _resolve_input_type(input_type)
input_type_ = _resolve_model(
runnable.input_schema if input_type == "auto" else input_type, "Input"
)

output_type_ = _resolve_model(
runnable.output_schema if output_type == "auto" else output_type, "Output"
)

namespace = path or ""

Expand All @@ -185,6 +171,9 @@ def add_routes(
StreamLogRequest = create_stream_log_request_model(
model_namespace, input_type_, config
)
# Generate the response models
InvokeResponse = create_invoke_response_model(model_namespace, output_type_)
BatchResponse = create_batch_response_model(model_namespace, output_type_)

@app.post(
f"{namespace}/invoke",
Expand Down
26 changes: 26 additions & 0 deletions langserve/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,29 @@ def create_stream_log_request_model(
)
stream_log_request.update_forward_refs()
return stream_log_request


def create_invoke_response_model(
namespace: str,
output_type: InputValidator,
) -> Type[BaseModel]:
"""Create a pydantic model for the invoke response."""
invoke_response_type = create_model(
f"{namespace}InvokeResponse",
output=(output_type, ...),
)
invoke_response_type.update_forward_refs()
return invoke_response_type


def create_batch_response_model(
namespace: str,
output_type: InputValidator,
) -> Type[BaseModel]:
"""Create a pydantic model for the batch response."""
batch_response_type = create_model(
f"{namespace}BatchResponse",
output=(List[output_type], ...),
)
batch_response_type.update_forward_refs()
return batch_response_type

0 comments on commit 90ec2e1

Please sign in to comment.