Add input/output/config endpoints to server and test out configurability of runnables (#9)

Code is in a working state, though there may be some minor kinks. It depends on
features on master, so it should not be landed before master is released.
eyurtsev authored Oct 4, 2023
1 parent 48baf90 commit f973c6a
Showing 6 changed files with 131 additions and 108 deletions.
86 changes: 73 additions & 13 deletions langserve/server.py
@@ -11,19 +11,19 @@
    Union,
)

+from langchain.load.serializable import Serializable
from langchain.schema.runnable import Runnable
from typing_extensions import Annotated

try:
    from pydantic.v1 import BaseModel, create_model
except ImportError:
-    from pydantic import BaseModel, create_model
+    from pydantic import BaseModel, Field, create_model

from langserve.serialization import simple_dumpd, simple_dumps
from langserve.validation import (
    create_batch_request_model,
    create_invoke_request_model,
-    create_runnable_config_model,
    create_stream_log_request_model,
    create_stream_request_model,
)
@@ -35,9 +35,10 @@
    APIRouter = FastAPI = Any


-def _project_dict(d: Mapping, keys: Sequence[str]) -> Dict[str, Any]:
+def _unpack_config(d: Union[BaseModel, Mapping], keys: Sequence[str]) -> Dict[str, Any]:
    """Project the given keys from the given dict."""
-    return {k: d[k] for k in keys if k in d}
+    _d = d.dict() if isinstance(d, BaseModel) else d
+    return {k: _d[k] for k in keys if k in _d}


class InvokeResponse(BaseModel):
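To illustrate the change above: _unpack_config now accepts either a plain mapping or a pydantic model and projects only the whitelisted keys. A minimal standalone sketch of the behavior (the sample model and keys are illustrative, not part of this diff):

from typing import Any, Dict, List, Mapping, Sequence, Union

from pydantic import BaseModel


def _unpack_config(d: Union[BaseModel, Mapping], keys: Sequence[str]) -> Dict[str, Any]:
    """Project the given keys, whether the config arrives as a dict or a model."""
    _d = d.dict() if isinstance(d, BaseModel) else d
    return {k: _d[k] for k in keys if k in _d}


class ExampleConfig(BaseModel):  # stand-in for a server-generated config model
    tags: List[str] = []
    metadata: dict = {}


assert _unpack_config({"tags": ["demo"], "unknown": 1}, ["tags"]) == {"tags": ["demo"]}
assert _unpack_config(ExampleConfig(tags=["demo"]), ["tags"]) == {"tags": ["demo"]}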
@@ -69,9 +70,18 @@ class BatchResponse(BaseModel):
def _unpack_input(validated_model: BaseModel) -> Any:
    """Unpack the decoded input from the validated model."""
    if hasattr(validated_model, "__root__"):
-        return validated_model.__root__
+        model = validated_model.__root__
    else:
-        return validated_model
+        model = validated_model
+
+    if isinstance(model, BaseModel) and not isinstance(model, Serializable):
+        # If the model is a pydantic model, but not a Serializable, then
+        # it was created by the server as part of validation and isn't expected
+        # to be accepted by the runnables as input as a pydantic model,
+        # instead we need to convert it into a corresponding python dict.
+        return model.dict()
+
+    return model


_MODEL_REGISTRY = {}
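A condensed sketch of what the reworked _unpack_input does with the two kinds of models (the wrapper and input models here are hypothetical, and the Serializable pass-through is elided):

from pydantic import BaseModel


class Wrapped(BaseModel):  # hypothetical __root__ wrapper built during validation
    __root__: str


class ExampleInput(BaseModel):  # hypothetical server-generated input model
    topic: str


def unpack(validated: BaseModel):
    # condensed _unpack_input without the Serializable check
    model = validated.__root__ if hasattr(validated, "__root__") else validated
    return model.dict() if isinstance(model, BaseModel) else model


assert unpack(Wrapped(__root__="hello")) == "hello"
assert unpack(ExampleInput(topic="cats")) == {"topic": "cats"}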
@@ -91,6 +101,39 @@ def _resolve_input_type(input_type: Union[Type, BaseModel]) -> BaseModel:
    return _MODEL_REGISTRY[hash_]


+def _add_namespace_to_model(namespace: str, model: Type[BaseModel]) -> Type[BaseModel]:
+    """Create a unique name for the given model.
+
+    Args:
+        namespace: The namespace to use for the model.
+        model: The model to create a unique name for.
+
+    Returns:
+        A new model with name prepended with the given namespace.
+    """
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    model_with_unique_name = create_model(
+        f"{namespace}{model.__name__}",
+        config=Config,
+        **{
+            name: (
+                field.annotation,
+                Field(
+                    field.default,
+                    title=name,
+                    description=field.field_info.description,
+                ),
+            )
+            for name, field in model.__fields__.items()
+        },
+    )
+    model_with_unique_name.update_forward_refs()
+    return model_with_unique_name
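This helper exists so that models generated for different routes get distinct schema names in the OpenAPI output. A simplified standalone sketch of the same idea (it uses pydantic v1's documented create_model signature, and abbreviates the field copying relative to the code above):

from pydantic import BaseModel, create_model


def add_namespace(namespace: str, model: type) -> type:
    # Rebuild the model under a namespaced name, copying its fields.
    namespaced = create_model(
        f"{namespace}{model.__name__}",
        **{
            name: (field.outer_type_, field.field_info)
            for name, field in model.__fields__.items()
        },
    )
    namespaced.update_forward_refs()
    return namespaced


class MyConfig(BaseModel):  # hypothetical config model
    tags: list = []


Namespaced = add_namespace("my_route_", MyConfig)
assert Namespaced.__name__ == "my_route_MyConfig"
assert Namespaced().tags == []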


# PUBLIC API


@@ -132,10 +175,12 @@ def add_routes(

    model_namespace = path.strip("/").replace("/", "_")

-    config = create_runnable_config_model(model_namespace, config_keys)
+    config = _add_namespace_to_model(
+        model_namespace, runnable.config_schema(include=config_keys)
+    )

    InvokeRequest = create_invoke_request_model(model_namespace, input_type_, config)
    BatchRequest = create_batch_request_model(model_namespace, input_type_, config)
    # Stream request is the same as invoke request, but with a different response type
    StreamRequest = create_stream_request_model(model_namespace, input_type_, config)
    StreamLogRequest = create_stream_log_request_model(
        model_namespace, input_type_, config
@@ -152,7 +197,7 @@ async def invoke(
        # Request is first validated using InvokeRequest which takes into account
        # config_keys as well as input_type.
        # After validation, the input is loaded using LangChain's load function.
-        config = _project_dict(request.config, config_keys)
+        config = _unpack_config(request.config, config_keys)
        output = await runnable.ainvoke(
            _unpack_input(request.input), config=config, **request.kwargs
        )
@@ -164,9 +209,9 @@
    async def batch(request: Annotated[BatchRequest, BatchRequest]) -> BatchResponse:
        """Invoke the runnable with the given inputs and config."""
        if isinstance(request.config, list):
-            config = [_project_dict(config, config_keys) for config in request.config]
+            config = [_unpack_config(config, config_keys) for config in request.config]
        else:
-            config = _project_dict(request.config, config_keys)
+            config = _unpack_config(request.config, config_keys)
        inputs = [_unpack_input(input_) for input_ in request.inputs]
        output = await runnable.abatch(inputs, config=config, **request.kwargs)

@@ -181,7 +226,7 @@ async def stream(
        # config_keys as well as input_type.
        # After validation, the input is loaded using LangChain's load function.
        input_ = _unpack_input(request.input)
-        config = _project_dict(request.config, config_keys)
+        config = _unpack_config(request.config, config_keys)

        async def _stream() -> AsyncIterator[dict]:
            """Stream the output of the runnable."""
@@ -204,7 +249,7 @@ async def stream_log(
        # config_keys as well as input_type.
        # After validation, the input is loaded using LangChain's load function.
        input_ = _unpack_input(request.input)
-        config = _project_dict(request.config, config_keys)
+        config = _unpack_config(request.config, config_keys)

        async def _stream_log() -> AsyncIterator[dict]:
            """Stream the output of the runnable."""
@@ -227,3 +272,18 @@ async def _stream_log() -> AsyncIterator[dict]:
            yield {"event": "end"}

        return EventSourceResponse(_stream_log())
+
+    @app.get(f"{namespace}/input_schema")
+    async def input_schema() -> Any:
+        """Return the input schema of the runnable."""
+        return runnable.input_schema.schema()
+
+    @app.get(f"{namespace}/output_schema")
+    async def output_schema() -> Any:
+        """Return the output schema of the runnable."""
+        return runnable.output_schema.schema()
+
+    @app.get(f"{namespace}/config_schema")
+    async def config_schema() -> Any:
+        """Return the config schema of the runnable."""
+        return runnable.config_schema(include=config_keys).schema()
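With these routes in place, a client can introspect a deployed runnable over plain GET requests. A hypothetical session against a local server (path, host, and port are illustrative, not from this diff):

import httpx

# Assumes add_routes(app, runnable, path="/my_runnable") and uvicorn on :8000.
for name in ("input_schema", "output_schema", "config_schema"):
    schema = httpx.get(f"http://localhost:8000/my_runnable/{name}").json()
    print(name, "->", schema.get("title"))  # each endpoint returns a JSON Schema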
25 changes: 2 additions & 23 deletions langserve/validation.py
@@ -1,39 +1,18 @@
-from typing import List, Optional, Sequence, Type, Union
-
-from langchain.schema.runnable import RunnableConfig
+from typing import List, Optional, Sequence, Union

try:
    from pydantic.v1 import BaseModel, Field, create_model
except ImportError:
    from pydantic import BaseModel, Field, create_model

-from typing_extensions import TypedDict
+from typing_extensions import Type, TypedDict

InputValidator = Union[Type[BaseModel], type]
# The following langchain objects are considered to be safe to load.

# PUBLIC API


-def create_runnable_config_model(
-    ns: str, config_keys: Sequence[str]
-) -> type(TypedDict):
-    """Create a projection of the runnable config type.
-
-    Args:
-        ns: The namespace of the runnable config type.
-        config_keys: The keys to include in the projection.
-    """
-    subset_dict = {}
-    for key in config_keys:
-        if key in RunnableConfig.__annotations__:
-            subset_dict[key] = RunnableConfig.__annotations__[key]
-        else:
-            raise AssertionError(f"Key {key} not in RunnableConfig.")
-
-    return TypedDict(f"{ns}RunnableConfig", subset_dict, total=False)


def create_invoke_request_model(
    namespace: str,
    input_type: InputValidator,
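The hand-rolled projection removed above is superseded by asking the runnable itself for its config model, as server.py now does. A rough sketch of the replacement mechanism (assumes langchain >= 0.0.307, per the pyproject.toml bump below):

from langchain.schema.runnable import RunnableLambda

runnable = RunnableLambda(lambda x: x)
# config_schema(include=...) yields a pydantic model limited to the requested keys.
ConfigModel = runnable.config_schema(include=["tags"])
print(sorted(ConfigModel.schema().get("properties", {})))  # expect ['tags']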
69 changes: 13 additions & 56 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -14,7 +14,7 @@ fastapi = {version = ">=0.90.1", optional = true}
sse-starlette = {version = "^1.3.0", optional = true}
httpx-sse = {version = ">=0.3.1", optional = true}
pydantic = "^1"
-langchain = ">=0.0.306"
+langchain = ">=0.0.307"

[tool.poetry.group.dev.dependencies]
jupyterlab = "^3.6.1"