Skip to content

Commit

Permalink
xx
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Sep 12, 2024
1 parent 59b3c81 commit 7f1e754
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
6 changes: 5 additions & 1 deletion langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def __init__(
per_req_config_modifier: Optional[PerRequestConfigModifier] = None,
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
) -> None:
"""Create an API handler for the given runnable.
Expand Down Expand Up @@ -595,6 +596,8 @@ def __init__(
If not provided, then all logs will be allowed to be streamed.
Use to also limit the events that can be streamed by the stream_events.
TODO: Introduce deprecation for this parameter to rename it
astream_events_version: version of the stream events endpoint to use.
By default "v2".
"""
if importlib.util.find_spec("sse_starlette") is None:
raise ImportError(
Expand Down Expand Up @@ -632,6 +635,7 @@ def __init__(
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
self._astream_events_version = astream_events_version

if token_feedback_config:
if len(token_feedback_config["key_configs"]) != 1:
Expand Down Expand Up @@ -1343,7 +1347,7 @@ async def _stream_events() -> AsyncIterator[dict]:
exclude_names=stream_events_request.exclude_names,
exclude_types=stream_events_request.exclude_types,
exclude_tags=stream_events_request.exclude_tags,
version="v1",
version=self._astream_events_version,
):
if (
self._names_in_stream_allow_list is None
Expand Down
22 changes: 17 additions & 5 deletions langserve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def _log_error_message_once(error_message: str) -> None:
logger.error(error_message)


@lru_cache(maxsize=1_000) # Will accommodate up to 1_000 different error messages
def _log_info_message_once(error_message: str) -> None:
"""Log an error message once."""
logger.info(error_message)


def _sanitize_request(request: httpx.Request) -> httpx.Request:
"""Remove sensitive headers from the request."""
accept_headers = {
Expand Down Expand Up @@ -752,7 +758,7 @@ async def astream_events(
input: Any,
config: Optional[RunnableConfig] = None,
*,
version: Literal["v1"],
version: Literal["v1", "v2", None] = None,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
Expand All @@ -775,21 +781,27 @@ async def astream_events(
input: The input to the runnable
config: The config to use for the runnable
version: The version of the astream_events to use.
Currently only "v1" is supported.
Currently, this input is IGNORED on the client.
The server will return whatever format it's configured with.
include_names: The names of the events to include
include_types: The types of the events to include
include_tags: The tags of the events to include
exclude_names: The names of the events to exclude
exclude_types: The types of the events to exclude
exclude_tags: The tags of the events to exclude
"""
if version != "v1":
raise ValueError(f"Unsupported version: {version}. Use 'v1'")

# Create a stream handler that will emit Log objects
config = ensure_config(config)
callback_manager = get_async_callback_manager_for_config(config)

if version is not None:
_log_info_message_once(
"Versioning of the astream_events API is not supported on the client "
"side currently. The server will return events in whatever format "
"it was configured with in add_routes or APIHandler. "
"To stop seeing this message, remove the `version` argument."
)

events = []

run_manager = await callback_manager.on_chain_start(
Expand Down
11 changes: 7 additions & 4 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def add_routes(
enabled_endpoints: Optional[Sequence[EndpointName]] = None,
dependencies: Optional[Sequence[Depends]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.
Expand Down Expand Up @@ -380,14 +381,16 @@ def add_routes(
- chat: UX is optimized for chat-like interactions. Please review
the README in langserve for more details about constraints (e.g.,
which message types are supported etc.)
astream_events_version: version of the stream events endpoint to use.
By default "v2".
""" # noqa: E501
if not isinstance(runnable, Runnable):
raise TypeError(
f"Expected a Runnable, got {type(runnable)}. "
f"The second argument to add_routes should be a Runnable instance."
f"add_route(app, runnable, ...) is the correct usage."
f"Please make sure that you are using a runnable which is an instance of "
f"langchain_core.runnables.Runnable."
"The second argument to add_routes should be a Runnable instance."
"add_route(app, runnable, ...) is the correct usage."
"Please make sure that you are using a runnable which is an instance of "
"langchain_core.runnables.Runnable."
)

endpoint_configuration = _EndpointConfiguration(
Expand Down

0 comments on commit 7f1e754

Please sign in to comment.