Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Sep 12, 2024
1 parent 45ddb62 commit f048281
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from langsmith import schemas as ls_schemas
from langsmith.client import Client
from langsmith.schemas import FeedbackIngestToken
from orjson import orjson
from pydantic import BaseModel, Field, __version__
from pytest import MonkeyPatch
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -231,13 +232,17 @@ async def get_async_test_client(

@asynccontextmanager
async def get_async_remote_runnable(
server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True
server: FastAPI,
*,
path: Optional[str] = None,
raise_app_exceptions: bool = True,
**kwargs: Any,
) -> RemoteRunnable:
"""Get an async client."""
url = "http://localhost:9999"
if path:
url += path
remote_runnable_client = RemoteRunnable(url=url)
remote_runnable_client = RemoteRunnable(url=url, **kwargs)

async with get_async_test_client(
server, path=path, raise_app_exceptions=raise_app_exceptions
Expand Down Expand Up @@ -2146,6 +2151,46 @@ async def check_types(inputs: VariousTypes) -> int:
)


async def test_custom_serialization() -> None:
"""Test updating the config based on the raw request object."""
from langserve.serialization import Serializer

class CustomObject:
def __init__(self, x: int) -> None:
self.x = x

class CustomSerializer(Serializer):
def dumpd(self, obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
return orjson.loads(orjson.dumps(obj))

def dumps(self, obj: Any) -> bytes:
"""Dump the given object as a JSON string."""
return orjson.dumps(obj)

def loadd(self, obj: Any) -> Any:
"""Load the given object."""
raise NotImplementedError()

def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
return orjson.loads(s)

def foo(x: int) -> Any:
"""Add one to simulate a valid function."""
return 2

app = FastAPI()
server_runnable = RunnableLambda(foo)
add_routes(app, server_runnable, serializer=CustomSerializer())

async with get_async_remote_runnable(
app, raise_app_exceptions=True, serializer=CustomSerializer()
) as runnable:
result = await runnable.ainvoke(5)
assert result == {}


async def test_endpoint_configurations() -> None:
"""Test enabling/disabling endpoints."""
app = FastAPI()
Expand Down

0 comments on commit f048281

Please sign in to comment.