Skip to content

Commit

Permalink
validate
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos committed Oct 17, 2023
1 parent 4d249e2 commit ab2e809
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
45 changes: 32 additions & 13 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ def _config_from_hash(config_hash: str) -> Dict[str, Any]:


def _unpack_config(
*configs: Union[BaseModel, Mapping, str], keys: Sequence[str]
*configs: Union[BaseModel, Mapping, str],
keys: Sequence[str],
model: Type[BaseModel],
) -> Dict[str, Any]:
"""Merge configs, and project the given keys from the merged dict."""
config_dicts = []
for config in configs:
if isinstance(config, str):
config_dicts.append(_config_from_hash(config))
config_dicts.append(model(**_config_from_hash(config)).dict())
elif isinstance(config, BaseModel):
config_dicts.append(config.dict())
else:
Expand Down Expand Up @@ -232,14 +234,20 @@ def add_routes(

model_namespace = path.strip("/").replace("/", "_")

config = _add_namespace_to_model(
ConfigPayload = _add_namespace_to_model(
model_namespace, runnable.config_schema(include=config_keys)
)
InvokeRequest = create_invoke_request_model(model_namespace, input_type_, config)
BatchRequest = create_batch_request_model(model_namespace, input_type_, config)
StreamRequest = create_stream_request_model(model_namespace, input_type_, config)
InvokeRequest = create_invoke_request_model(
model_namespace, input_type_, ConfigPayload
)
BatchRequest = create_batch_request_model(
model_namespace, input_type_, ConfigPayload
)
StreamRequest = create_stream_request_model(
model_namespace, input_type_, ConfigPayload
)
StreamLogRequest = create_stream_log_request_model(
model_namespace, input_type_, config
model_namespace, input_type_, ConfigPayload
)
# Generate the response models
InvokeResponse = create_invoke_response_model(model_namespace, output_type_)
Expand All @@ -255,7 +263,9 @@ async def invoke(
"""Invoke the runnable with the given input and config."""
# Request is first validated using InvokeRequest which takes into account
# config_keys as well as input_type.
config = _unpack_config(config_hash, invoke_request.config, keys=config_keys)
config = _unpack_config(
config_hash, invoke_request.config, keys=config_keys, model=ConfigPayload
)
_add_tracing_info_to_metadata(config, request)
output = await runnable.ainvoke(
_unpack_input(invoke_request.input), config=config
Expand All @@ -273,14 +283,18 @@ async def batch(
"""Invoke the runnable with the given inputs and config."""
if isinstance(batch_request.config, list):
config = [
_unpack_config(config, keys=config_keys)
_unpack_config(
config_hash, config, keys=config_keys, model=ConfigPayload
)
for config in batch_request.config
]

for c in config:
_add_tracing_info_to_metadata(c, request)
else:
config = _unpack_config(config_hash, batch_request.config, keys=config_keys)
config = _unpack_config(
config_hash, batch_request.config, keys=config_keys, model=ConfigPayload
)
_add_tracing_info_to_metadata(config, request)
inputs = [_unpack_input(input_) for input_ in batch_request.inputs]
output = await runnable.abatch(inputs, config=config)
Expand Down Expand Up @@ -329,7 +343,9 @@ async def stream(
# config_keys as well as input_type.
# After validation, the input is loaded using LangChain's load function.
input_ = _unpack_input(stream_request.input)
config = _unpack_config(config_hash, stream_request.config, keys=config_keys)
config = _unpack_config(
config_hash, stream_request.config, keys=config_keys, model=ConfigPayload
)
_add_tracing_info_to_metadata(config, request)

async def _stream() -> AsyncIterator[dict]:
Expand Down Expand Up @@ -387,7 +403,10 @@ async def stream_log(
# After validation, the input is loaded using LangChain's load function.
input_ = _unpack_input(stream_log_request.input)
config = _unpack_config(
config_hash, stream_log_request.config, keys=config_keys
config_hash,
stream_log_request.config,
keys=config_keys,
model=ConfigPayload,
)
_add_tracing_info_to_metadata(config, request)

Expand Down Expand Up @@ -448,4 +467,4 @@ async def output_schema(config_hash: str = "") -> Any:
@app.get(f"{namespace}/config_schema")
async def config_schema(config_hash: str = "") -> Any:
"""Return the config schema of the runnable."""
return runnable.config_schema(include=config_keys).schema()
return ConfigPayload.schema()
3 changes: 2 additions & 1 deletion tests/unit_tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def test_invoke_request_with_runnables() -> None:
input={"name": "bob"},
).config,
keys=[],
model=config,
)
== {}
)
Expand All @@ -177,6 +178,6 @@ def test_invoke_request_with_runnables() -> None:
"template": "goodbye {name}",
}

assert _unpack_config(request.config, keys=["configurable"]) == {
assert _unpack_config(request.config, keys=["configurable"], model=config) == {
"configurable": {"template": "goodbye {name}"},
}

0 comments on commit ab2e809

Please sign in to comment.