Skip to content

Commit

Permalink
Add more docs, disable **kwargs (#26)
Browse files Browse the repository at this point in the history
* Disable kwargs (since we're not yet doing validation on them)
* Add more in code docs
  • Loading branch information
eyurtsev authored Oct 11, 2023
1 parent 5506675 commit d2dc4a9
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 21 deletions.
7 changes: 6 additions & 1 deletion langserve/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@


class WellKnownLCObject(BaseModel):
"""A well known LangChain object."""
"""A well known LangChain object.
A pydantic model that defines what constitutes a well known LangChain object.
All well-known objects are allowed to be serialized and de-serialized.
"""

__root__: Union[
Document,
Expand Down
6 changes: 2 additions & 4 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ async def invoke(
config = _unpack_config(invoke_request.config, config_keys)
_add_tracing_info_to_metadata(config, request)
output = await runnable.ainvoke(
_unpack_input(invoke_request.input), config=config, **invoke_request.kwargs
_unpack_input(invoke_request.input), config=config
)

return InvokeResponse(output=simple_dumpd(output))
Expand All @@ -254,7 +254,7 @@ async def batch(
config = _unpack_config(batch_request.config, config_keys)
_add_tracing_info_to_metadata(config, request)
inputs = [_unpack_input(input_) for input_ in batch_request.inputs]
output = await runnable.abatch(inputs, config=config, **batch_request.kwargs)
output = await runnable.abatch(inputs, config=config)

return BatchResponse(output=simple_dumpd(output))

Expand Down Expand Up @@ -306,7 +306,6 @@ async def _stream() -> AsyncIterator[dict]:
async for chunk in runnable.astream(
input_,
config=config,
**stream_request.kwargs,
):
yield {"data": simple_dumps(chunk), "event": "data"}
yield {"event": "end"}
Expand Down Expand Up @@ -369,7 +368,6 @@ async def _stream_log() -> AsyncIterator[dict]:
exclude_names=stream_log_request.exclude_names,
exclude_types=stream_log_request.exclude_types,
exclude_tags=stream_log_request.exclude_tags,
**stream_log_request.kwargs,
):
if stream_log_request.diff: # Run log patch
if not isinstance(chunk, RunLogPatch):
Expand Down
122 changes: 106 additions & 16 deletions langserve/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,24 @@ def create_invoke_request_model(
"""Create a pydantic model for the invoke request."""
invoke_request_type = create_model(
f"{namespace}InvokeRequest",
input=(input_type, ...),
config=(config, Field(default_factory=dict)),
kwargs=(dict, Field(default_factory=dict)),
input=(input_type, Field(..., description="The input to the runnable.")),
config=(
config,
Field(
default_factory=dict,
description=(
"Subset of RunnableConfig object in LangChain. "
"Useful for passing information like tags, metadata etc."
),
),
),
kwargs=(
dict,
Field(
default_factory=dict,
description="Keyword arguments to the runnable. Currently ignored.",
),
),
)
invoke_request_type.update_forward_refs()
return invoke_request_type
Expand All @@ -56,9 +71,24 @@ def create_stream_request_model(
"""Create a pydantic model for the stream request."""
stream_request_model = create_model(
f"{namespace}StreamRequest",
input=(input_type, ...),
config=(config, Field(default_factory=dict)),
kwargs=(dict, Field(default_factory=dict)),
input=(input_type, Field(..., description="The input to the runnable.")),
config=(
config,
Field(
default_factory=dict,
description=(
"Subset of RunnableConfig object in LangChain. "
"Useful for passing information like tags, metadata etc."
),
),
),
kwargs=(
dict,
Field(
default_factory=dict,
description="Keyword arguments to the runnable. Currently ignored.",
),
),
)
stream_request_model.update_forward_refs()
return stream_request_model
Expand All @@ -73,8 +103,24 @@ def create_batch_request_model(
batch_request_type = create_model(
f"{namespace}BatchRequest",
inputs=(List[input_type], ...),
config=(Union[config, List[config]], Field(default_factory=dict)),
kwargs=(dict, Field(default_factory=dict)),
config=(
Union[config, List[config]],
Field(
default_factory=dict,
description=(
"Subset of RunnableConfig object in LangChain. Either specify one "
"config for all inputs or a list of configs with one per input. "
"Useful for passing information like tags, metadata etc."
),
),
),
kwargs=(
dict,
Field(
default_factory=dict,
description="Keyword arguments to the runnable. Currently ignored.",
),
),
)
batch_request_type.update_forward_refs()
return batch_request_type
Expand All @@ -91,12 +137,48 @@ def create_stream_log_request_model(
input=(input_type, ...),
config=(config, Field(default_factory=dict)),
diff=(Optional[bool], False),
include_names=(Optional[Sequence[str]], None),
include_types=(Optional[Sequence[str]], None),
include_tags=(Optional[Sequence[str]], None),
exclude_names=(Optional[Sequence[str]], None),
exclude_types=(Optional[Sequence[str]], None),
exclude_tags=(Optional[Sequence[str]], None),
include_names=(
Optional[Sequence[str]],
Field(
None,
description="If specified, filter to runnables with matching names",
),
),
include_types=(
Optional[Sequence[str]],
Field(
None,
description="If specified, filter to runnables with matching types",
),
),
include_tags=(
Optional[Sequence[str]],
Field(
None,
description="If specified, filter to runnables with matching tags",
),
),
exclude_names=(
Optional[Sequence[str]],
Field(
None,
description="If specified, exclude runnables with matching names",
),
),
exclude_types=(
Optional[Sequence[str]],
Field(
None,
description="If specified, exclude runnables with matching types",
),
),
exclude_tags=(
Optional[Sequence[str]],
Field(
None,
description="If specified, exclude runnables with matching tags",
),
),
kwargs=(dict, Field(default_factory=dict)),
)
stream_log_request.update_forward_refs()
Expand All @@ -112,7 +194,7 @@ def create_invoke_response_model(
# other information can be added to the response at a later date.
invoke_response_type = create_model(
f"{namespace}InvokeResponse",
output=(output_type, ...),
output=(output_type, Field(..., description="The output of the invocation.")),
)
invoke_response_type.update_forward_refs()
return invoke_response_type
Expand All @@ -127,7 +209,15 @@ def create_batch_response_model(
# other information can be added to the response at a later date.
batch_response_type = create_model(
f"{namespace}BatchResponse",
output=(List[output_type], ...),
output=(
List[output_type],
Field(
...,
description=(
"The outputs corresponding to the inputs the batch request."
),
),
),
)
batch_response_type.update_forward_refs()
return batch_response_type

0 comments on commit d2dc4a9

Please sign in to comment.