From 43683b36713a7a1345243f4043d2c5ee057b322f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 12 Sep 2024 14:33:21 -0400 Subject: [PATCH] add ability to control version of astream events API (#760) add ability to control version of astream events API server side --- langserve/api_handler.py | 6 +++- langserve/client.py | 22 ++++++++++--- langserve/server.py | 11 ++++--- tests/unit_tests/test_server_client.py | 45 +++++++++++--------------- 4 files changed, 48 insertions(+), 36 deletions(-) diff --git a/langserve/api_handler.py b/langserve/api_handler.py index ed8f974d..437df869 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -534,6 +534,7 @@ def __init__( per_req_config_modifier: Optional[PerRequestConfigModifier] = None, stream_log_name_allow_list: Optional[Sequence[str]] = None, playground_type: Literal["default", "chat"] = "default", + astream_events_version: Literal["v1", "v2"] = "v2", ) -> None: """Create an API handler for the given runnable. @@ -595,6 +596,8 @@ def __init__( If not provided, then all logs will be allowed to be streamed. Use to also limit the events that can be streamed by the stream_events. TODO: Introduce deprecation for this parameter to rename it + astream_events_version: version of the stream events endpoint to use. + By default "v2". """ if importlib.util.find_spec("sse_starlette") is None: raise ImportError( @@ -632,6 +635,7 @@ def __init__( self._enable_feedback_endpoint = enable_feedback_endpoint self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint self._names_in_stream_allow_list = stream_log_name_allow_list + self._astream_events_version = astream_events_version if token_feedback_config: if len(token_feedback_config["key_configs"]) != 1: @@ -1343,7 +1347,7 @@ async def _stream_events() -> AsyncIterator[dict]: exclude_names=stream_events_request.exclude_names, exclude_types=stream_events_request.exclude_types, exclude_tags=stream_events_request.exclude_tags, - version="v1", + version=self._astream_events_version, ): if ( self._names_in_stream_allow_list is None diff --git a/langserve/client.py b/langserve/client.py index 6261290a..55968937 100644 --- a/langserve/client.py +++ b/langserve/client.py @@ -120,6 +120,12 @@ def _log_error_message_once(error_message: str) -> None: logger.error(error_message) +@lru_cache(maxsize=1_000) # Will accommodate up to 1_000 different error messages +def _log_info_message_once(error_message: str) -> None: + """Log an error message once.""" + logger.info(error_message) + + def _sanitize_request(request: httpx.Request) -> httpx.Request: """Remove sensitive headers from the request.""" accept_headers = { @@ -752,7 +758,7 @@ async def astream_events( input: Any, config: Optional[RunnableConfig] = None, *, - version: Literal["v1"], + version: Literal["v1", "v2", None] = None, include_names: Optional[Sequence[str]] = None, include_types: Optional[Sequence[str]] = None, include_tags: Optional[Sequence[str]] = None, @@ -775,7 +781,8 @@ async def astream_events( input: The input to the runnable config: The config to use for the runnable version: The version of the astream_events to use. - Currently only "v1" is supported. + Currently, this input is IGNORED on the client. + The server will return whatever format it's configured with. include_names: The names of the events to include include_types: The types of the events to include include_tags: The tags of the events to include @@ -783,13 +790,18 @@ async def astream_events( exclude_types: The types of the events to exclude exclude_tags: The tags of the events to exclude """ - if version != "v1": - raise ValueError(f"Unsupported version: {version}. Use 'v1'") - # Create a stream handler that will emit Log objects config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) + if version is not None: + _log_info_message_once( + "Versioning of the astream_events API is not supported on the client " + "side currently. The server will return events in whatever format " + "it was configured with in add_routes or APIHandler. " + "To stop seeing this message, remove the `version` argument." + ) + events = [] run_manager = await callback_manager.on_chain_start( diff --git a/langserve/server.py b/langserve/server.py index 75762618..1ba35b11 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -262,6 +262,7 @@ def add_routes( enabled_endpoints: Optional[Sequence[EndpointName]] = None, dependencies: Optional[Sequence[Depends]] = None, playground_type: Literal["default", "chat"] = "default", + astream_events_version: Literal["v1", "v2"] = "v2", ) -> None: """Register the routes on the given FastAPI app or APIRouter. @@ -380,14 +381,16 @@ def add_routes( - chat: UX is optimized for chat-like interactions. Please review the README in langserve for more details about constraints (e.g., which message types are supported etc.) + astream_events_version: version of the stream events endpoint to use. + By default "v2". """ # noqa: E501 if not isinstance(runnable, Runnable): raise TypeError( f"Expected a Runnable, got {type(runnable)}. " - f"The second argument to add_routes should be a Runnable instance." - f"add_route(app, runnable, ...) is the correct usage." - f"Please make sure that you are using a runnable which is an instance of " - f"langchain_core.runnables.Runnable." + "The second argument to add_routes should be a Runnable instance." + "add_route(app, runnable, ...) is the correct usage." + "Please make sure that you are using a runnable which is an instance of " + "langchain_core.runnables.Runnable." ) endpoint_configuration = _EndpointConfiguration( diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index 3cadda82..c87a69ee 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -2306,7 +2306,7 @@ def mul_two(y: int) -> int: # test client side error with pytest.raises(httpx.HTTPStatusError) as cb: # Invalid input type (expected string but got int) - async for _ in runnable.astream_events("foo", version="v1"): + async for _ in runnable.astream_events("foo", version="v2"): pass # Verify that this is a 422 error @@ -2315,7 +2315,7 @@ def mul_two(y: int) -> int: with pytest.raises(httpx.HTTPStatusError) as cb: # Invalid input type (expected string but got int) # include names should not be a list of lists - async for _ in runnable.astream_events(1, include_names=[[]], version="v1"): + async for _ in runnable.astream_events(1, include_names=[[]], version="v2"): pass # Verify that this is a 422 error @@ -2324,7 +2324,7 @@ def mul_two(y: int) -> int: # Test good requests events = [] - async for event in runnable.astream_events(1, version="v1"): + async for event in runnable.astream_events(1, version="v2"): events.append(event) # validate events @@ -2337,6 +2337,7 @@ def mul_two(y: int) -> int: assert not k.startswith("__") assert "metadata" in event del event["metadata"] + event["parent_ids"] = [] assert events == [ { @@ -2416,6 +2417,7 @@ def _clean_up_events(events: List[Dict[str, Any]]) -> None: assert not k.startswith("__") assert "metadata" in event del event["metadata"] + event["parent_ids"] = [] async def test_astream_events_with_serialization( @@ -2488,7 +2490,7 @@ def back_to_serializable(inputs) -> str: app, raise_app_exceptions=False, path="/doc_types" ) as runnable: # Test good requests - events = [event async for event in runnable.astream_events("foo", version="v1")] + events = [event async for event in runnable.astream_events("foo", version="v2")] _clean_up_events(events) assert events == [ @@ -2578,7 +2580,7 @@ def back_to_serializable(inputs) -> str: app, raise_app_exceptions=False, path="/get_pets" ) as runnable: # Test good requests - events = [event async for event in runnable.astream_events("foo", version="v1")] + events = [event async for event in runnable.astream_events("foo", version="v2")] _clean_up_events(events) assert events == [ { @@ -2613,7 +2615,7 @@ def back_to_serializable(inputs) -> str: ) as runnable: # Test good requests with pytest.raises(httpx.HTTPStatusError) as cb: - async for event in runnable.astream_events("foo", version="v1"): + async for event in runnable.astream_events("foo", version="v2"): pass assert cb.value.response.status_code == 500 @@ -2641,7 +2643,7 @@ async def test_astream_events_with_prompt_model_parser_chain( events = [ event async for event in runnable.astream_events( - {"question": "hello"}, version="v1" + {"question": "hello"}, version="v2" ) ] _clean_up_events(events) @@ -2850,25 +2852,16 @@ async def test_astream_events_with_prompt_model_parser_chain( ] }, "output": { - "generations": [ - [ - { - "generation_info": None, - "message": { - "additional_kwargs": {}, - "content": "Hello World!", - "name": None, - "response_metadata": {}, - "type": "AIMessageChunk", - }, - "text": "Hello World!", - "type": "ChatGenerationChunk", - } - ] - ], - "llm_output": None, - "run": None, - "type": "LLMResult", + "additional_kwargs": {}, + "content": "Hello World!", + "example": False, + "invalid_tool_calls": [], + "name": None, + "response_metadata": {}, + "tool_call_chunks": [], + "tool_calls": [], + "type": "AIMessageChunk", + "usage_metadata": None, }, }, "event": "on_chat_model_end",