From 69a33bc516602b5a3a5311bb0116d418aaa1b942 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 15 Dec 2023 15:53:02 -0500 Subject: [PATCH] x --- examples/api_handler_simple_server/server.py | 115 +++++++++++++++++++ langserve/__init__.py | 9 +- langserve/api_handler.py | 39 +++---- langserve/server.py | 4 +- 4 files changed, 141 insertions(+), 26 deletions(-) create mode 100755 examples/api_handler_simple_server/server.py diff --git a/examples/api_handler_simple_server/server.py b/examples/api_handler_simple_server/server.py new file mode 100755 index 00000000..69a654c8 --- /dev/null +++ b/examples/api_handler_simple_server/server.py @@ -0,0 +1,115 @@ +"""An example that shows how to use the API handler directly. + +For this to work with RemoteClient, the routes must match those expected +by the client; i.e., /invoke, /batch, /stream, etc. No trailing slashes should be used. +""" +from importlib import metadata +from typing import Annotated + +from fastapi import Depends, FastAPI, Request, Response +from langchain_core.runnables import RunnableLambda +from sse_starlette import EventSourceResponse + +from langserve import APIHandler + +PYDANTIC_VERSION = metadata.version("pydantic") +_PYDANTIC_MAJOR_VERSION: int = int(PYDANTIC_VERSION.split(".")[0]) + +app = FastAPI( + title="LangChain Server", + version="1.0", + description="Spin up a simple api server using Langchain's Runnable interfaces", +) + + +## +# Example 1 +# Add an endpoint for invoke, and batch together +def add_one(x: int) -> int: + """Add one to the given number.""" + return x + 1 + + +chain = RunnableLambda(add_one) + +api_handler = APIHandler(chain, "/simple") + + +@app.post("/simple/invoke", include_in_schema=False) +async def simple_invoke(request: Request) -> Response: + """Handle a request.""" + # The API Handler validates the request parts that are defined + return await api_handler.invoke(request) + + +@app.post("/simple/batch", include_in_schema=False) +async def simple_batch(request: 
Request) -> Response: + """Handle a request.""" + # The API Handler validates the request parts that are defined + return await api_handler.batch(request) + + +# Here, we show how to populate the documentation for the endpoint. +# Please note that this is done separately from the actual endpoint. +# This happens due to two reasons: +# 1. FastAPI does not support using pydantic.v1 models in the docs endpoint. +# "https://github.com/tiangolo/fastapi/issues/10360" +# LangChain uses pydantic.v1 models! +# 2. Configurable Runnables have a *dynamic* schema, which means that +# the shape of the input depends on the config. +# In this case, the openapi schema is a best effort showing the documentation +# that will work for the default config (and any non-conflicting configs). +if _PYDANTIC_MAJOR_VERSION == 1: # Do not use in your own code + # Add documentation + @app.post("/simple/invoke") + async def simple_invoke_docs( + request: api_handler.InvokeRequest, + ) -> api_handler.InvokeResponse: + """API endpoint used only for documentation purposes. Populate /docs endpoint""" + raise NotImplementedError( + "This endpoint is only used for documentation purposes" + ) + + @app.post("/simple/batch") + async def simple_batch_docs( + request: api_handler.BatchRequest, + ) -> api_handler.BatchResponse: + """API endpoint used only for documentation purposes.
Populate /docs endpoint""" + raise NotImplementedError( + "This endpoint is only used for documentation purposes" + ) + +else: + print( + "Skipping documentation generation for pydantic v2: " + "https://github.com/tiangolo/fastapi/issues/10360" + ) + + +async def _get_api_handler() -> APIHandler: + """Prepare an APIHandler that wraps a RunnableLambda.""" + return APIHandler(RunnableLambda(add_one), "/v2") + + +@app.post("/v2/invoke") +async def v2_invoke( + request: Request, runnable: Annotated[APIHandler, Depends(_get_api_handler)] +) -> Response: + """Handle a request.""" + # The API Handler validates the request parts that are defined + return await runnable.invoke(request) + + +@app.post("/v2/stream") +async def v2_stream( + request: Request, runnable: Annotated[APIHandler, Depends(_get_api_handler)] +) -> EventSourceResponse: + """Handle a request.""" + # The API Handler validates the request parts that are defined + return await runnable.stream(request) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="localhost", port=8000) diff --git a/langserve/__init__.py b/langserve/__init__.py index 1498b3df..bdddd5bb 100644 --- a/langserve/__init__.py +++ b/langserve/__init__.py @@ -4,9 +4,16 @@ to be considered private and subject to change without notice.
""" +from langserve.api_handler import APIHandler from langserve.client import RemoteRunnable from langserve.schema import CustomUserType from langserve.server import add_routes from langserve.version import __version__ -__all__ = ["RemoteRunnable", "add_routes", "__version__", "CustomUserType"] +__all__ = [ + "RemoteRunnable", + "APIHandler", + "add_routes", + "__version__", + "CustomUserType", +] diff --git a/langserve/api_handler.py b/langserve/api_handler.py index c9d9d6bd..b1a22ffb 100644 --- a/langserve/api_handler.py +++ b/langserve/api_handler.py @@ -381,7 +381,14 @@ def _add_callbacks( config["callbacks"].extend(callbacks) -class _APIHandler: +_MODEL_REGISTRY = {} +_SEEN_NAMES = set() + + +# PUBLIC API + + +class APIHandler: """Implementation of the various API endpoints for a runnable server. This is a private class whose API is expected to change. @@ -394,8 +401,8 @@ class _APIHandler: def __init__( self, runnable: Runnable, + path: str, # The path under which the runnable is served. *, - path: str = "", prefix: str = "", input_type: Union[Type, Literal["auto"], BaseModel] = "auto", output_type: Union[Type, Literal["auto"], BaseModel] = "auto", @@ -461,7 +468,14 @@ def __init__( "Cannot configure run_name. Please remove it from config_keys." ) + if path and not path.startswith("/"): + raise ValueError( + f"Got an invalid path: {path}. 
" + f"If specifying path please start it with a `/`" + ) + self._config_keys = config_keys + self._path = path self._base_url = prefix + path self._include_callback_events = include_callback_events @@ -525,22 +539,6 @@ def __init__( ) self._BatchResponse = create_batch_response_model(model_namespace, output_type_) - def _route_name(name: str) -> str: - """Return the route name with the given name.""" - return f"{path.strip('/')} {name}" if path else name - - self._route_name = _route_name - - def _route_name_with_config(name: str) -> str: - """Return the route name with the given name.""" - return ( - f"{path.strip('/')} {name} with config" - if path - else f"{name} with config" - ) - - self._route_name_with_config = _route_name_with_config - @property def InvokeRequest(self) -> Type[BaseModel]: """Return the invoke request model.""" @@ -1074,7 +1072,6 @@ async def _check_feedback_enabled(self, config_hash: str = "") -> None: """Check if feedback is enabled for the runnable. This endpoint is private since it will be deprecated in the future. 
- """ if not (await self.check_feedback_enabled(config_hash)): raise HTTPException( @@ -1086,7 +1083,3 @@ async def _check_feedback_enabled(self, config_hash: str = "") -> None: async def check_feedback_enabled(self, config_hash: str = "") -> bool: """Check if feedback is enabled for the runnable.""" return self._enable_feedback_endpoint or not tracing_is_enabled() - - -_MODEL_REGISTRY = {} -_SEEN_NAMES = set() diff --git a/langserve/server.py b/langserve/server.py index 8f6d09ba..dd6a88f1 100644 --- a/langserve/server.py +++ b/langserve/server.py @@ -18,7 +18,7 @@ from langchain.schema.runnable import Runnable from typing_extensions import Annotated -from langserve.api_handler import PerRequestConfigModifier, _APIHandler, _is_hosted +from langserve.api_handler import APIHandler, PerRequestConfigModifier, _is_hosted from langserve.pydantic_v1 import ( _PYDANTIC_MAJOR_VERSION, PYDANTIC_VERSION, @@ -354,7 +354,7 @@ def add_routes( # Determine the base URL for the playground endpoint prefix = app.prefix if isinstance(app, APIRouter) else "" # type: ignore - api_handler = _APIHandler( + api_handler = APIHandler( runnable, path=path, prefix=prefix,