diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index c87a69ee..e196bf68 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -56,6 +56,7 @@ from langsmith import schemas as ls_schemas from langsmith.client import Client from langsmith.schemas import FeedbackIngestToken +from orjson import orjson from pydantic import BaseModel, Field, __version__ from pytest import MonkeyPatch from pytest_mock import MockerFixture @@ -231,13 +232,17 @@ async def get_async_test_client( @asynccontextmanager async def get_async_remote_runnable( - server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True + server: FastAPI, + *, + path: Optional[str] = None, + raise_app_exceptions: bool = True, + **kwargs: Any, ) -> RemoteRunnable: """Get an async client.""" url = "http://localhost:9999" if path: url += path - remote_runnable_client = RemoteRunnable(url=url) + remote_runnable_client = RemoteRunnable(url=url, **kwargs) async with get_async_test_client( server, path=path, raise_app_exceptions=raise_app_exceptions @@ -2146,6 +2151,46 @@ async def check_types(inputs: VariousTypes) -> int: ) +async def test_custom_serialization() -> None: + """Test updating the config based on the raw request object.""" + from langserve.serialization import Serializer + + class CustomObject: + def __init__(self, x: int) -> None: + self.x = x + + class CustomSerializer(Serializer): + def dumpd(self, obj: Any) -> Any: + """Convert the given object to a JSON serializable object.""" + return orjson.loads(orjson.dumps(obj)) + + def dumps(self, obj: Any) -> bytes: + """Dump the given object as a JSON string.""" + return orjson.dumps(obj) + + def loadd(self, obj: Any) -> Any: + """Load the given object.""" + raise NotImplementedError() + + def loads(self, s: bytes) -> Any: + """Load the given JSON string.""" + return orjson.loads(s) + + def foo(x: int) -> Any: + """Add one to simulate a valid function.""" + return 2 + + app = FastAPI() + server_runnable = RunnableLambda(foo) + add_routes(app, server_runnable, serializer=CustomSerializer()) + + async with get_async_remote_runnable( + app, raise_app_exceptions=True, serializer=CustomSerializer() + ) as runnable: + result = await runnable.ainvoke(5) + assert result == {} + + async def test_endpoint_configurations() -> None: """Test enabling/disabling endpoints.""" app = FastAPI()