
Bump langchain, update api #20

Merged · 5 commits · Oct 6, 2023
30 changes: 24 additions & 6 deletions langserve/client.py
@@ -3,11 +3,19 @@
import asyncio
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Iterator, List, Optional, Sequence, Union
from typing import (
Any,
AsyncIterator,
Iterator,
List,
Optional,
Sequence,
Union,
)
from urllib.parse import urljoin

import httpx
from langchain.callbacks.tracers.log_stream import RunLogPatch
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.load.dump import dumpd
from langchain.schema.runnable import Runnable
from langchain.schema.runnable.config import (
@@ -359,15 +367,17 @@ async def astream_log(
input: Input,
config: Optional[RunnableConfig] = None,
*,
diff: bool = False,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[RunLogPatch]:
) -> Union[AsyncIterator[RunLogPatch], AsyncIterator[RunLog]]:
"""Stream all output from a runnable, as reported to the callback system.

This includes all inner runs of LLMs, Retrievers, Tools, etc.

Output is streamed as Log objects, which include a list of
@@ -392,6 +402,7 @@ async def astream_log(
"input": simple_dumpd(input),
"config": _without_callbacks(config),
"kwargs": kwargs,
"diff": diff,
"include_names": include_names,
"include_types": include_types,
"include_tags": include_tags,
@@ -413,11 +424,18 @@
async for sse in event_source.aiter_sse():
if sse.event == "data":
data = simple_loads(sse.data)
chunk = RunLogPatch(*data["ops"])
if diff:
chunk = RunLogPatch(*data["ops"])
else:
chunk = RunLog(*data["ops"], state=data["state"])

yield chunk

if final_output:
final_output += chunk
if diff:
if final_output:
final_output += chunk
else:
final_output = chunk
else:
final_output = chunk
elif sse.event == "end":
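For context, here is a minimal sketch of how a caller might exercise the new `diff` flag on the updated client. It assumes the client class is exposed as `RemoteRunnable` and that a LangServe app is running at the URL shown; both are assumptions for illustration, not part of this PR.

```python
# Sketch only: RemoteRunnable and the /my_runnable path are assumed, not taken
# from this diff. Per the change above, diff=True yields RunLogPatch chunks
# (JSONPatch ops only), while diff=False yields RunLog chunks carrying the
# accumulated state.
import asyncio

from langserve import RemoteRunnable


async def main() -> None:
    chain = RemoteRunnable("http://localhost:8000/my_runnable")

    # diff=True: each chunk carries only the incremental JSONPatch operations.
    async for patch in chain.astream_log("hello", diff=True):
        print(patch.ops)

    # diff=False (the default in this signature): each chunk is a RunLog whose
    # .state reflects every patch applied so far.
    async for log in chain.astream_log("hello", diff=False):
        print(log.state)


asyncio.run(main())
```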
32 changes: 24 additions & 8 deletions langserve/server.py
@@ -11,6 +11,7 @@
Union,
)

from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.load.serializable import Serializable
from langchain.schema.runnable import Runnable
from typing_extensions import Annotated
@@ -38,12 +39,7 @@
def _unpack_config(d: Union[BaseModel, Mapping], keys: Sequence[str]) -> Dict[str, Any]:
"""Project the given keys from the given dict."""
_d = d.dict() if isinstance(d, BaseModel) else d
new_keys = list(keys)

if "configurable" not in new_keys:
new_keys.append("configurable")

return {k: _d[k] for k in new_keys if k in _d}
return {k: _d[k] for k in keys if k in _d}


class InvokeResponse(BaseModel):
@@ -256,9 +252,10 @@ async def stream_log(

async def _stream_log() -> AsyncIterator[dict]:
"""Stream the output of the runnable."""
async for run_log_patch in runnable.astream_log(
async for chunk in runnable.astream_log(
input_,
config=config,
diff=request.diff,
include_names=request.include_names,
include_types=request.include_types,
include_tags=request.include_tags,
@@ -267,9 +264,28 @@ async def _stream_log() -> AsyncIterator[dict]:
exclude_tags=request.exclude_tags,
**request.kwargs,
):
if request.diff: # Run log patch
if not isinstance(chunk, RunLogPatch):
raise AssertionError(
f"Expected a RunLog instance got {type(chunk)}"
)
data = {
"ops": chunk.ops,
}
else:
# Then it's a run log
if not isinstance(chunk, RunLog):
raise AssertionError(
f"Expected a RunLog instance got {type(chunk)}"
)
data = {
"state": chunk.state,
"ops": chunk.ops,
}

# Temporary adapter
yield {
"data": simple_dumps({"ops": run_log_patch.ops}),
"data": simple_dumps(data),
"event": "data",
}
yield {"event": "end"}
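To make the branching above concrete, here is an illustrative sketch of the two SSE `data` payload shapes `_stream_log` now produces. The keys inside `state` (id, streamed_output, final_output, logs) follow langchain's RunState of that era and are an assumption here; the values are invented.

```python
# Illustrative only: values are made up, and the state keys mirror what
# langchain's RunState exposed at the time (an assumption, not part of this diff).

# request.diff is True -> only the JSONPatch operations are forwarded:
patch_payload = {
    "ops": [
        {"op": "add", "path": "/streamed_output/-", "value": "Hel"},
    ],
}

# request.diff is False -> the accumulated run state is forwarded as well:
state_payload = {
    "state": {
        "id": "00000000-0000-0000-0000-000000000000",
        "streamed_output": ["Hel"],
        "final_output": None,
        "logs": {},
    },
    "ops": [
        {"op": "add", "path": "/streamed_output/-", "value": "Hel"},
    ],
}
```

Either payload is then serialized with `simple_dumps` and wrapped in a `{"data": ..., "event": "data"}` envelope, with a final `{"event": "end"}` event closing the stream.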
1 change: 1 addition & 0 deletions langserve/validation.py
@@ -71,6 +71,7 @@ def create_stream_log_request_model(
f"{namespace}StreamLogRequest",
input=(input_type, ...),
config=(config, Field(default_factory=dict)),
diff=(Optional[bool], False),
include_names=(Optional[Sequence[str]], None),
include_types=(Optional[Sequence[str]], None),
include_tags=(Optional[Sequence[str]], None),
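Finally, a minimal sketch of how the new `diff` field behaves in the dynamically created request model, assuming the surrounding code uses `pydantic.create_model` as the call pattern suggests; the model name and the reduced field set below are illustrative.

```python
# Sketch only: a cut-down stand-in for the generated StreamLogRequest model.
from typing import Optional, Sequence

from pydantic import create_model

ExampleStreamLogRequest = create_model(
    "ExampleStreamLogRequest",
    input=(str, ...),                               # required runnable input
    diff=(Optional[bool], False),                   # new in this PR, defaults to False
    include_names=(Optional[Sequence[str]], None),  # representative optional filter
)

# Requests that omit `diff` fall back to the default of False.
request = ExampleStreamLogRequest(input="hello")
print(request.diff)                                          # False
print(ExampleStreamLogRequest(input="hi", diff=True).diff)   # True
```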