Skip to content

Commit

Permalink
add ability to control version of astream events API (#760)
Browse files: browse the repository at this point in the history
add ability to control version of astream events API server side
  • Loading branch information
eyurtsev authored Sep 12, 2024
1 parent 59b3c81 commit 43683b3
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 36 deletions.
6 changes: 5 additions & 1 deletion langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def __init__(
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
) -> None:
"""Create an API handler for the given runnable.
Expand Down Expand Up @@ -595,6 +596,8 @@ def __init__(
If not provided, then all logs will be allowed to be streamed.
Use to also limit the events that can be streamed by the stream_events.
TODO: Introduce deprecation for this parameter to rename it
astream_events_version: version of the stream events endpoint to use.
By default "v2".
"""
if importlib.util.find_spec("sse_starlette") is None:
raise ImportError(
Expand Down Expand Up @@ -632,6 +635,7 @@ def __init__(
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
self._astream_events_version = astream_events_version

if token_feedback_config:
if len(token_feedback_config["key_configs"]) != 1:
Expand Down Expand Up @@ -1343,7 +1347,7 @@ async def _stream_events() -> AsyncIterator[dict]:
exclude_names=stream_events_request.exclude_names,
exclude_types=stream_events_request.exclude_types,
exclude_tags=stream_events_request.exclude_tags,
version="v1",
version=self._astream_events_version,
):
if (
self._names_in_stream_allow_list is None
Expand Down
22 changes: 17 additions & 5 deletions langserve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def _log_error_message_once(error_message: str) -> None:
logger.error(error_message)


@lru_cache(maxsize=1_000)  # Will accommodate up to 1_000 different info messages
def _log_info_message_once(error_message: str) -> None:
    """Log a message at INFO level, at most once per distinct message text.

    The ``lru_cache`` decorator memoizes calls by their argument, so calling
    this function again with an identical string is answered from the cache
    and does not reach the logger. Once more than 1_000 distinct messages
    have been seen, the least recently used entries are evicted and those
    messages may be logged again.

    Args:
        error_message: The text to log. The parameter keeps its historical
            name for compatibility with existing callers; the message is
            emitted at INFO level, not ERROR.
    """
    logger.info(error_message)


def _sanitize_request(request: httpx.Request) -> httpx.Request:
"""Remove sensitive headers from the request."""
accept_headers = {
Expand Down Expand Up @@ -752,7 +758,7 @@ async def astream_events(
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1"],
version: Literal["v1", "v2", None] = None,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
Expand All @@ -775,21 +781,27 @@ async def astream_events(
input: The input to the runnable
config: The config to use for the runnable
version: The version of the astream_events to use.
Currently only "v1" is supported.
Currently, this input is IGNORED on the client.
The server will return whatever format it's configured with.
include_names: The names of the events to include
include_types: The types of the events to include
include_tags: The tags of the events to include
exclude_names: The names of the events to exclude
exclude_types: The types of the events to exclude
exclude_tags: The tags of the events to exclude
"""
if version != "v1":
raise ValueError(f"Unsupported version: {version}. Use 'v1'")

# Create a stream handler that will emit Log objects
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)

if version is not None:
_log_info_message_once(
"Versioning of the astream_events API is not supported on the client "
"side currently. The server will return events in whatever format "
"it was configured with in add_routes or APIHandler. "
"To stop seeing this message, remove the `version` argument."
)

events = []

run_manager = await callback_manager.on_chain_start(
Expand Down
11 changes: 7 additions & 4 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def add_routes(
enabled_endpoints: Optional[Sequence[EndpointName]] = None,
dependencies: Optional[Sequence[Depends]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
Expand Down Expand Up @@ -380,14 +381,16 @@ def add_routes(
- chat: UX is optimized for chat-like interactions. Please review
the README in langserve for more details about constraints (e.g.,
which message types are supported etc.)
astream_events_version: version of the stream events endpoint to use.
By default "v2".
""" # noqa: E501
if not isinstance(runnable, Runnable):
raise TypeError(
f"Expected a Runnable, got {type(runnable)}. "
f"The second argument to add_routes should be a Runnable instance."
f"add_route(app, runnable, ...) is the correct usage."
f"Please make sure that you are using a runnable which is an instance of "
f"langchain_core.runnables.Runnable."
"The second argument to add_routes should be a Runnable instance."
"add_route(app, runnable, ...) is the correct usage."
"Please make sure that you are using a runnable which is an instance of "
"langchain_core.runnables.Runnable."
)

endpoint_configuration = _EndpointConfiguration(
Expand Down
45 changes: 19 additions & 26 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2306,7 +2306,7 @@ def mul_two(y: int) -> int:
# test client side error
with pytest.raises(httpx.HTTPStatusError) as cb:
# Invalid input type (expected string but got int)
async for _ in runnable.astream_events("foo", version="v1"):
async for _ in runnable.astream_events("foo", version="v2"):
pass

# Verify that this is a 422 error
Expand All @@ -2315,7 +2315,7 @@ def mul_two(y: int) -> int:
with pytest.raises(httpx.HTTPStatusError) as cb:
# Invalid input type (expected string but got int)
# include names should not be a list of lists
async for _ in runnable.astream_events(1, include_names=[[]], version="v1"):
async for _ in runnable.astream_events(1, include_names=[[]], version="v2"):
pass

# Verify that this is a 422 error
Expand All @@ -2324,7 +2324,7 @@ def mul_two(y: int) -> int:
# Test good requests
events = []

async for event in runnable.astream_events(1, version="v1"):
async for event in runnable.astream_events(1, version="v2"):
events.append(event)

# validate events
Expand All @@ -2337,6 +2337,7 @@ def mul_two(y: int) -> int:
assert not k.startswith("__")
assert "metadata" in event
del event["metadata"]
event["parent_ids"] = []

assert events == [
{
Expand Down Expand Up @@ -2416,6 +2417,7 @@ def _clean_up_events(events: List[Dict[str, Any]]) -> None:
assert not k.startswith("__")
assert "metadata" in event
del event["metadata"]
event["parent_ids"] = []


async def test_astream_events_with_serialization(
Expand Down Expand Up @@ -2488,7 +2490,7 @@ def back_to_serializable(inputs) -> str:
app, raise_app_exceptions=False, path="/doc_types"
) as runnable:
# Test good requests
events = [event async for event in runnable.astream_events("foo", version="v1")]
events = [event async for event in runnable.astream_events("foo", version="v2")]
_clean_up_events(events)

assert events == [
Expand Down Expand Up @@ -2578,7 +2580,7 @@ def back_to_serializable(inputs) -> str:
app, raise_app_exceptions=False, path="/get_pets"
) as runnable:
# Test good requests
events = [event async for event in runnable.astream_events("foo", version="v1")]
events = [event async for event in runnable.astream_events("foo", version="v2")]
_clean_up_events(events)
assert events == [
{
Expand Down Expand Up @@ -2613,7 +2615,7 @@ def back_to_serializable(inputs) -> str:
) as runnable:
# Test good requests
with pytest.raises(httpx.HTTPStatusError) as cb:
async for event in runnable.astream_events("foo", version="v1"):
async for event in runnable.astream_events("foo", version="v2"):
pass
assert cb.value.response.status_code == 500

Expand Down Expand Up @@ -2641,7 +2643,7 @@ async def test_astream_events_with_prompt_model_parser_chain(
events = [
event
async for event in runnable.astream_events(
{"question": "hello"}, version="v1"
{"question": "hello"}, version="v2"
)
]
_clean_up_events(events)
Expand Down Expand Up @@ -2850,25 +2852,16 @@ async def test_astream_events_with_prompt_model_parser_chain(
]
},
"output": {
"generations": [
[
{
"generation_info": None,
"message": {
"additional_kwargs": {},
"content": "Hello World!",
"name": None,
"response_metadata": {},
"type": "AIMessageChunk",
},
"text": "Hello World!",
"type": "ChatGenerationChunk",
}
]
],
"llm_output": None,
"run": None,
"type": "LLMResult",
"additional_kwargs": {},
"content": "Hello World!",
"example": False,
"invalid_tool_calls": [],
"name": None,
"response_metadata": {},
"tool_call_chunks": [],
"tool_calls": [],
"type": "AIMessageChunk",
"usage_metadata": None,
},
},
"event": "on_chat_model_end",
Expand Down

0 comments on commit 43683b3

Please sign in to comment.