Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to specify custom serializer #764

Merged
merged 7 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
PublicTraceLink,
PublicTraceLinkCreateRequest,
)
from langserve.serialization import WellKnownLCSerializer
from langserve.serialization import Serializer, WellKnownLCSerializer
from langserve.validation import (
BatchBaseResponse,
BatchRequestShallowValidator,
Expand Down Expand Up @@ -536,6 +536,7 @@ def __init__(
stream_log_name_allow_list: Optional[Sequence[str]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
serializer: Optional[Serializer] = None,
) -> None:
"""Create an API handler for the given runnable.

Expand Down Expand Up @@ -600,6 +601,8 @@ def __init__(
TODO: Introduce deprecation for this parameter to rename it
astream_events_version: version of the stream events endpoint to use.
By default "v2".
serializer: optional serializer to use for serializing the output.
If not provided, the default serializer will be used.
"""
if importlib.util.find_spec("sse_starlette") is None:
raise ImportError(
Expand Down Expand Up @@ -638,7 +641,7 @@ def __init__(
)
self._include_callback_events = include_callback_events
self._per_req_config_modifier = per_req_config_modifier
self._serializer = WellKnownLCSerializer()
self._serializer = serializer or WellKnownLCSerializer()
self._enable_feedback_endpoint = enable_feedback_endpoint
self._enable_public_trace_link_endpoint = enable_public_trace_link_endpoint
self._names_in_stream_allow_list = stream_log_name_allow_list
Expand Down
5 changes: 4 additions & 1 deletion langserve/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(
cert: Optional[CertTypes] = None,
client_kwargs: Optional[Dict[str, Any]] = None,
use_server_callback_events: bool = True,
serializer: Optional[Serializer] = None,
) -> None:
"""Initialize the client.

Expand All @@ -300,6 +301,8 @@ def __init__(
and async httpx clients
use_server_callback_events: Whether to invoke callbacks on any
callback events returned by the server.
serializer: The serializer to use for serializing and deserializing
data. If not provided, a default serializer will be used.
"""
_client_kwargs = client_kwargs or {}
# Enforce trailing slash
Expand Down Expand Up @@ -327,7 +330,7 @@ def __init__(

# Register cleanup handler once RemoteRunnable is garbage collected
weakref.finalize(self, _close_clients, self.sync_client, self.async_client)
self._lc_serializer = WellKnownLCSerializer()
self._lc_serializer = serializer or WellKnownLCSerializer()
self._use_server_callback_events = use_server_callback_events

def _invoke(
Expand Down
42 changes: 26 additions & 16 deletions langserve/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,39 +157,49 @@ def _decode_event_data(value: Any) -> Any:


class Serializer(abc.ABC):
@abc.abstractmethod
def dumpd(self, obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
return orjson.loads(self.dumps(obj))

def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
return self.loadd(orjson.loads(s))

@abc.abstractmethod
def dumps(self, obj: Any) -> bytes:
"""Dump the given object as a JSON string."""
"""Dump the given object to a JSON byte string."""

@abc.abstractmethod
def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
def loadd(self, s: bytes) -> Any:
"""Given a python object, load it into a well known object.

@abc.abstractmethod
def loadd(self, obj: Any) -> Any:
"""Load the given object."""
The obj represents content that was json loaded from a string, but
not yet validated or converted into a well known object.
"""


class WellKnownLCSerializer(Serializer):
def dumpd(self, obj: Any) -> Any:
"""Convert the given object to a JSON serializable object."""
return orjson.loads(orjson.dumps(obj, default=default))
"""A pre-defined serializer for well known LangChain objects.

This is the default serializer used by LangServe for serializing and
de-serializing well known LangChain objects.

If you need to extend the serialization capabilities for your own application,
feel free to create a subclass of the Serializer class and implement
the abstract methods dumps and loadd.
"""

def dumps(self, obj: Any) -> bytes:
"""Dump the given object as a JSON string."""
"""Dump the given object to a JSON byte string."""
return orjson.dumps(obj, default=default)

def loadd(self, obj: Any) -> Any:
"""Load the given object."""
return _decode_lc_objects(obj)
"""Given a python object, load it into a well known object.

def loads(self, s: bytes) -> Any:
"""Load the given JSON string."""
return self.loadd(orjson.loads(s))
The obj represents content that was json loaded from a string, but
not yet validated or converted into a well known object.
"""
return _decode_lc_objects(obj)


def _project_top_level(model: BaseModel) -> Dict[str, Any]:
Expand Down
5 changes: 5 additions & 0 deletions langserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TokenFeedbackConfig,
_is_hosted,
)
from langserve.serialization import Serializer

try:
from fastapi import APIRouter, Depends, FastAPI, Request, Response
Expand Down Expand Up @@ -263,6 +264,7 @@ def add_routes(
dependencies: Optional[Sequence[Depends]] = None,
playground_type: Literal["default", "chat"] = "default",
astream_events_version: Literal["v1", "v2"] = "v2",
serializer: Optional[Serializer] = None,
) -> None:
"""Register the routes on the given FastAPI app or APIRouter.

Expand Down Expand Up @@ -383,6 +385,8 @@ def add_routes(
which message types are supported etc.)
astream_events_version: version of the stream events endpoint to use.
By default "v2".
serializer: The serializer to use for serializing the output. If not provided,
the default serializer will be used.
""" # noqa: E501
if not isinstance(runnable, Runnable):
raise TypeError(
Expand Down Expand Up @@ -447,6 +451,7 @@ def add_routes(
stream_log_name_allow_list=stream_log_name_allow_list,
playground_type=playground_type,
astream_events_version=astream_events_version,
serializer=serializer,
)

namespace = path or ""
Expand Down
52 changes: 50 additions & 2 deletions tests/unit_tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from langsmith import schemas as ls_schemas
from langsmith.client import Client
from langsmith.schemas import FeedbackIngestToken
from orjson import orjson
from pydantic import BaseModel, Field, __version__
from pytest import MonkeyPatch
from pytest_mock import MockerFixture
Expand Down Expand Up @@ -244,13 +245,17 @@ async def get_async_test_client(

@asynccontextmanager
async def get_async_remote_runnable(
server: FastAPI, *, path: Optional[str] = None, raise_app_exceptions: bool = True
server: FastAPI,
*,
path: Optional[str] = None,
raise_app_exceptions: bool = True,
**kwargs: Any,
) -> RemoteRunnable:
"""Get an async client."""
url = "http://localhost:9999"
if path:
url += path
remote_runnable_client = RemoteRunnable(url=url)
remote_runnable_client = RemoteRunnable(url=url, **kwargs)

async with get_async_test_client(
server, path=path, raise_app_exceptions=raise_app_exceptions
Expand Down Expand Up @@ -2280,6 +2285,49 @@ async def check_types(inputs: VariousTypes) -> int:
)


async def test_custom_serialization() -> None:
    """Test that a custom serializer is used by both the server and the client.

    The server runnable returns a ``CustomObject``. The same ``CustomSerializer``
    is passed to ``add_routes`` and to the remote runnable, and the test checks
    that the client receives a ``CustomObject`` equal to the one the server
    produced.
    """
    from langserve.serialization import Serializer

    class CustomObject:
        """Plain user-defined type used as the runnable's output."""

        def __init__(self, x: int) -> None:
            self.x = x

        def __eq__(self, other: object) -> bool:
            # Defer to the other operand for foreign types instead of raising
            # AttributeError on a missing ``x`` attribute.
            if not isinstance(other, CustomObject):
                return NotImplemented
            return self.x == other.x

    class CustomSerializer(Serializer):
        """Serializer that encodes ``CustomObject`` as ``{"x": <value>}``."""

        def dumps(self, obj: Any) -> bytes:
            """Dump the given object to a JSON byte string."""
            if isinstance(obj, CustomObject):
                return orjson.dumps({"x": obj.x})
            return orjson.dumps(obj)

        def loadd(self, obj: Any) -> Any:
            """Turn an already JSON-decoded object back into a ``CustomObject``.

            Anything that is not a dict carrying an ``"x"`` key is returned
            unchanged. Checking for key presence (rather than truthiness)
            keeps falsy payloads such as ``{"x": 0}`` round-tripping.
            """
            if isinstance(obj, dict) and "x" in obj:
                return CustomObject(x=obj["x"])
            return obj

    def foo(x: int) -> Any:
        """Ignore the input and return a fixed ``CustomObject``."""
        return CustomObject(x=5)

    app = FastAPI()
    server_runnable = RunnableLambda(foo)
    add_routes(app, server_runnable, serializer=CustomSerializer())

    async with get_async_remote_runnable(
        app, raise_app_exceptions=True, serializer=CustomSerializer()
    ) as runnable:
        result = await runnable.ainvoke(5)
        assert isinstance(result, CustomObject)
        assert result == CustomObject(x=5)


async def test_endpoint_configurations() -> None:
"""Test enabling/disabling endpoints."""
app = FastAPI()
Expand Down
Loading