diff --git a/langserve/serialization.py b/langserve/serialization.py index e7fdbb74..de6ac4fb 100644 --- a/langserve/serialization.py +++ b/langserve/serialization.py @@ -8,6 +8,7 @@ from langchain.prompts.base import StringPromptValue from langchain.prompts.chat import ChatPromptValueConcrete +from langchain.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish from langchain.schema.document import Document from langchain.schema.messages import ( AIMessage, @@ -50,6 +51,9 @@ class WellKnownLCObject(BaseModel): AIMessageChunk, StringPromptValue, ChatPromptValueConcrete, + AgentAction, + AgentFinish, + AgentActionMessageLog, ] diff --git a/tests/unit_tests/test_encoders.py b/tests/unit_tests/test_encoders.py deleted file mode 100644 index 507df3ec..00000000 --- a/tests/unit_tests/test_encoders.py +++ /dev/null @@ -1,168 +0,0 @@ -import json -from typing import Any - -import pytest -from langchain.schema.messages import ( - HumanMessage, - HumanMessageChunk, - SystemMessage, -) - -try: - from pydantic.v1 import BaseModel -except ImportError: - from pydantic import BaseModel - -from langserve.serialization import simple_dumps, simple_loads - - -@pytest.mark.parametrize( - "data, expected_json", - [ - # Test with python primitives - (1, 1), - ([], []), - ({}, {}), - ({"a": 1}, {"a": 1}), - ( - {"output": [HumanMessage(content="hello")]}, - { - "output": [ - { - "content": "hello", - "additional_kwargs": {}, - "type": "human", - "is_chunk": False, - "example": False, - } - ] - }, - ), - # Test with a single message (HumanMessage) - ( - HumanMessage(content="Hello"), - { - "additional_kwargs": {}, - "content": "Hello", - "example": False, - "type": "human", - "is_chunk": False, - }, - ), - # Test with a list containing mixed elements - ( - [HumanMessage(content="Hello"), SystemMessage(content="Hi"), 42, "world"], - [ - { - "additional_kwargs": {}, - "content": "Hello", - "example": False, - "type": "human", - "is_chunk": False, - }, - { - 
"additional_kwargs": {}, - "content": "Hi", - "type": "system", - "is_chunk": False, - }, - 42, - "world", - ], - ), - # Uncomment when langchain 0.0.306 is released - ( - [HumanMessage(content="Hello"), HumanMessageChunk(content="Hi")], - [ - { - "additional_kwargs": {}, - "content": "Hello", - "example": False, - "type": "human", - "is_chunk": False, - }, - { - "additional_kwargs": {}, - "content": "Hi", - "example": False, - "type": "human", - "is_chunk": True, - }, - ], - ), - # Attention: This test is not correct right now - # Test with full and chunk messages - ( - [HumanMessageChunk(content="Hello"), HumanMessage(content="Hi")], - [ - { - "additional_kwargs": {}, - "content": "Hello", - "example": False, - "type": "human", - "is_chunk": True, - }, - { - "additional_kwargs": {}, - "content": "Hi", - "example": False, - "type": "human", - "is_chunk": False, - }, - ], - ), - # Test with a dictionary containing mixed elements - ( - { - "message": HumanMessage(content="Greetings"), - "numbers": [1, 2, 3], - "boom": "Hello, world!", - }, - { - "message": { - "additional_kwargs": {}, - "content": "Greetings", - "example": False, - "type": "human", - "is_chunk": False, - }, - "numbers": [1, 2, 3], - "boom": "Hello, world!", - }, - ), - ], -) -def test_serialization(data: Any, expected_json: Any) -> None: - """Test that the LangChainEncoder encodes the data as expected.""" - # Test encoding - assert json.loads(simple_dumps(data)) == expected_json - # Test decoding - assert simple_loads(json.dumps(expected_json)) == data - # Test full representation are equivalent including the pydantic model classes - assert _get_full_representation(data) == _get_full_representation( - simple_loads(json.dumps(expected_json)) - ) - - -def _get_full_representation(data: Any) -> Any: - """Get the full representation of the data, replacing pydantic models with schema. 
- - Pydantic tests two different models for equality based on equality - of their schema; instead we will rely on the equality of their full - schema representation. This will make sure that both models have the - same name (e.g., HumanMessage vs. HumanMessageChunk). - - Args: - data: python primitives + pydantic models - - Returns: - data represented entirely with python primitives - """ - if isinstance(data, dict): - return {key: _get_full_representation(value) for key, value in data.items()} - elif isinstance(data, list): - return [_get_full_representation(value) for value in data] - elif isinstance(data, BaseModel): - return data.schema() - else: - return data diff --git a/tests/unit_tests/test_serialization.py b/tests/unit_tests/test_serialization.py new file mode 100644 index 00000000..69e12bc0 --- /dev/null +++ b/tests/unit_tests/test_serialization.py @@ -0,0 +1,77 @@ +from typing import Any + +import pytest +from langchain.schema.messages import ( + HumanMessage, + HumanMessageChunk, + SystemMessage, +) + +try: + from pydantic.v1 import BaseModel +except ImportError: + from pydantic import BaseModel + +from langserve.serialization import simple_dumps, simple_loads + + +@pytest.mark.parametrize( + "data", + [ + # Test with python primitives + 1, + [], + {}, + {"a": 1}, + {"output": [HumanMessage(content="hello")]}, + # Test with a single message (HumanMessage) + HumanMessage(content="Hello"), + # Test with a list containing mixed elements + [HumanMessage(content="Hello"), SystemMessage(content="Hi"), 42, "world"], + # Test with a full message followed by a chunk (was gated on langchain 0.0.306) + [HumanMessage(content="Hello"), HumanMessageChunk(content="Hi")], + # Attention: This test is not correct right now + # Test with full and chunk messages + [HumanMessageChunk(content="Hello"), HumanMessage(content="Hi")], + # Test with a dictionary containing mixed elements + { + "message": HumanMessage(content="Greetings"), + "numbers": [1, 2, 3], + "boom": "Hello, world!", + }, + ], +) +def
test_serialization(data: Any) -> None: + """Test that data survives a simple_dumps / simple_loads round trip.""" + # Test encoding + assert isinstance(simple_dumps(data), str) + # Test simple equality (does not include pydantic class names) + assert simple_loads(simple_dumps(data)) == data + # Test full representation equality (includes pydantic class names) + assert _get_full_representation( + simple_loads(simple_dumps(data)) + ) == _get_full_representation(data) + + +def _get_full_representation(data: Any) -> Any: + """Get the full representation of the data, replacing pydantic models with schema. + + Pydantic tests two different models for equality based on equality + of their schema; instead we will rely on the equality of their full + schema representation. This will make sure that both models have the + same name (e.g., HumanMessage vs. HumanMessageChunk). + + Args: + data: python primitives + pydantic models + + Returns: + data represented entirely with python primitives + """ + if isinstance(data, dict): + return {key: _get_full_representation(value) for key, value in data.items()} + elif isinstance(data, list): + return [_get_full_representation(value) for value in data] + elif isinstance(data, BaseModel): + return data.schema() + else: + return data