diff --git a/langserve/callbacks.py b/langserve/callbacks.py index 767e6ad0..18dc9420 100644 --- a/langserve/callbacks.py +++ b/langserve/callbacks.py @@ -48,7 +48,7 @@ def log_callback(self, event: CallbackEventDict) -> None: async def on_chat_model_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], messages: List[List[BaseMessage]], *, run_id: UUID, @@ -73,7 +73,7 @@ async def on_chat_model_start( async def on_chain_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], inputs: Dict[str, Any], *, run_id: UUID, @@ -138,7 +138,7 @@ async def on_chain_error( async def on_retriever_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], query: str, *, run_id: UUID, @@ -202,7 +202,7 @@ async def on_retriever_error( async def on_tool_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], input_str: str, *, run_id: UUID, @@ -306,7 +306,7 @@ async def on_agent_finish( async def on_llm_start( self, - serialized: Dict[str, Any], + serialized: Optional[Dict[str, Any]], prompts: List[str], *, run_id: UUID, @@ -445,7 +445,14 @@ async def ahandle_callbacks( if event["parent_run_id"] is None: # How do we make sure it's None!? event["parent_run_id"] = callback_manager.run_id - event_data = {key: value for key, value in event.items() if key != "type"} + event_data = { + key: value + for key, value in event.items() + if key != "type" and key != "kwargs" + } + + if "kwargs" in event: + event_data.update(event["kwargs"]) await ahandle_event( # Unpacking like this may not work @@ -467,7 +474,14 @@ def handle_callbacks( if event["parent_run_id"] is None: # How do we make sure it's None!? event["parent_run_id"] = callback_manager.run_id - event_data = {key: value for key, value in event.items() if key != "type"} + event_data = { + key: value + for key, value in event.items() + if key != "type" and key != "kwargs" + } + + if "kwargs" in event: + event_data.update(event["kwargs"]) handle_event( # Unpacking like this may not work diff --git a/langserve/server_sent_events.py b/langserve/server_sent_events.py index f85375fa..2c8a06ed 100644 --- a/langserve/server_sent_events.py +++ b/langserve/server_sent_events.py @@ -1,8 +1,9 @@ """Adapted from https://github.com/florimondmanca/httpx-sse""" from contextlib import asynccontextmanager, contextmanager -from typing import Any, AsyncIterator, Iterator, List, Optional, TypedDict +from typing import Any, AsyncIterator, Iterator, List, Optional import httpx +from typing_extensions import TypedDict class ServerSentEvent(TypedDict): diff --git a/langserve/validation.py b/langserve/validation.py index 2ec56cc4..1256cdfb 100644 --- a/langserve/validation.py +++ b/langserve/validation.py @@ -396,27 +396,29 @@ class StreamEventsParameters(BaseModel): # status code and a message. -class OnChainStart(BaseModel): - """On Chain Start Callback Event.""" +class BaseCallback(BaseModel): + """Base class for all callback events.""" - serialized: Dict[str, Any] - inputs: Any run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + + +class OnChainStart(BaseCallback): + """On Chain Start Callback Event.""" + + serialized: Optional[Dict[str, Any]] = None + inputs: Any + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chain_start"] = "on_chain_start" -class OnChainEnd(BaseModel): +class OnChainEnd(BaseCallback): """On Chain End Callback Event.""" outputs: Any - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chain_end"] = "on_chain_end" @@ -428,38 +430,35 @@ class Error(BaseModel): type: Literal["error"] = "error" -class OnChainError(BaseModel): +class OnChainError(BaseCallback): """On Chain Error Callback Event.""" error: Error - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chain_error"] = "on_chain_error" -class OnToolStart(BaseModel): +class OnToolStart(BaseCallback): """On Tool Start Callback Event.""" - serialized: Dict[str, Any] + serialized: Optional[Dict[str, Any]] = None input_str: str run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_tool_start"] = "on_tool_start" -class OnToolEnd(BaseModel): +class OnToolEnd(BaseCallback): """On Tool End Callback Event.""" output: str run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_tool_end"] = "on_tool_end" @@ -467,36 +466,29 @@ class OnToolError(BaseModel): """On Tool Error Callback Event.""" error: Error - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_tool_error"] = "on_tool_error" -class OnChatModelStart(BaseModel): +class OnChatModelStart(BaseCallback): """On Chat Model Start Callback Event.""" - serialized: Dict[str, Any] + serialized: Optional[Dict[str, Any]] = None messages: List[List[BaseMessage]] - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_chat_model_start"] = "on_chat_model_start" -class OnLLMStart(BaseModel): +class OnLLMStart(BaseCallback): """On LLM Start Callback Event.""" - serialized: Dict[str, Any] + serialized: Optional[Dict[str, Any]] = None prompts: List[str] run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_llm_start"] = "on_llm_start" @@ -515,49 +507,39 @@ class LLMResult(BaseModel): """List of metadata info for model call for each input.""" -class OnLLMEnd(BaseModel): +class OnLLMEnd(BaseCallback): """On LLM End Callback Event.""" response: LLMResult - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_llm_end"] = "on_llm_end" -class OnRetrieverStart(BaseModel): +class OnRetrieverStart(BaseCallback): """On Retriever Start Callback Event.""" - serialized: Dict[str, Any] + serialized: Optional[Dict[str, Any]] = None query: str - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - metadata: Optional[Dict[str, Any]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_retriever_start"] = "on_retriever_start" -class OnRetrieverError(BaseModel): +class OnRetrieverError(BaseCallback): """On Retriever Error Callback Event.""" error: Error run_id: UUID parent_run_id: Optional[UUID] = None tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_retriever_error"] = "on_retriever_error" -class OnRetrieverEnd(BaseModel): +class OnRetrieverEnd(BaseCallback): """On Retriever End Callback Event.""" documents: Sequence[Document] - run_id: UUID - parent_run_id: Optional[UUID] = None - tags: Optional[List[str]] = None - kwargs: Any = None + kwargs: Optional[Dict[str, Any]] = None type: Literal["on_retriever_end"] = "on_retriever_end" diff --git a/poetry.lock b/poetry.lock index 75250b0e..4599fa04 100644 --- a/poetry.lock +++ b/poetry.lock @@ -826,13 +826,13 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "fastapi" -version = "0.114.1" +version = "0.114.2" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" files = [ - {file = "fastapi-0.114.1-py3-none-any.whl", hash = "sha256:5d4746f6e4b7dff0b4f6b6c6d5445645285f662fe75886e99af7ee2d6b58bb3e"}, - {file = "fastapi-0.114.1.tar.gz", hash = "sha256:1d7bbbeabbaae0acb0c22f0ab0b040f642d3093ca3645f8c876b6f91391861d8"}, + {file = "fastapi-0.114.2-py3-none-any.whl", hash = "sha256:44474a22913057b1acb973ab90f4b671ba5200482e7622816d79105dcece1ac5"}, + {file = "fastapi-0.114.2.tar.gz", hash = "sha256:0adb148b62edb09e8c6eeefa3ea934e8f276dabc038c5a82989ea6346050c3da"}, ] [package.dependencies] @@ -1564,33 +1564,36 @@ test = ["hatch", "ipykernel", "openapi-core (>=0.18.0,<0.19.0)", "openapi-spec-v [[package]] name = "langchain-core" -version = "0.3.0.dev5" +version = "0.3.0" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_core-0.3.0.dev5-py3-none-any.whl", hash = "sha256:8de38ea9dd15f9705d20ee44af7160216052467d47893d5473e0e878861685be"}, - {file = "langchain_core-0.3.0.dev5.tar.gz", hash = "sha256:d0986b3e2810522d90eb09d048ee4a92b079a334b313f4b7429624b4a062adff"}, + {file = "langchain_core-0.3.0-py3-none-any.whl", hash = "sha256:bee6dae2366d037ef0c5b87401fed14b5497cad26f97724e8c9ca7bc9239e847"}, + {file = "langchain_core-0.3.0.tar.gz", hash = "sha256:1249149ea3ba24c9c761011483c14091573a5eb1a773aa0db9c8ad155dd4a69d"}, ] [package.dependencies] jsonpatch = ">=1.33,<2.0" langsmith = ">=0.1.117,<0.2.0" packaging = ">=23.2,<25" -pydantic = ">=2.7.4,<3.0.0" +pydantic = [ + {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, + {version = ">=2.5.2,<3.0.0", markers = "python_full_version < \"3.12.4\""}, +] PyYAML = ">=5.3" tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" typing-extensions = ">=4.7" [[package]] name = "langsmith" -version = "0.1.117" +version = "0.1.120" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.117-py3-none-any.whl", hash = "sha256:e936ee9bcf8293b0496df7ba462a3702179fbe51f9dc28744b0fbec0dbf206ae"}, - {file = "langsmith-0.1.117.tar.gz", hash = "sha256:a1b532f49968b9339bcaff9118d141846d52ed3d803f342902e7448edf1d662b"}, + {file = "langsmith-0.1.120-py3-none-any.whl", hash = "sha256:54d2785e301646c0988e0a69ebe4d976488c87b41928b358cb153b6ddd8db62b"}, + {file = "langsmith-0.1.120.tar.gz", hash = "sha256:25499ca187b41bd89d784b272b97a8d76f60e0e21bdf20336e8a2aa6a9b23ac9"}, ] [package.dependencies] @@ -3264,13 +3267,13 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake [[package]] name = "urllib3" -version = "2.2.2" +version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, - {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, + {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, + {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] @@ -3857,13 +3860,13 @@ test = ["mypy", "pre-commit", "pytest", "pytest-asyncio", "websockets (>=10.0)"] [[package]] name = "zipp" -version = "3.20.1" +version = "3.20.2" description = "Backport of pathlib-compatible object wrapper for zip files" optional = false python-versions = ">=3.8" files = [ - {file = "zipp-3.20.1-py3-none-any.whl", hash = "sha256:9960cd8967c8f85a56f920d5d507274e74f9ff813a0ab8889a5b5be2daf44064"}, - {file = "zipp-3.20.1.tar.gz", hash = "sha256:c22b14cc4763c5a5b04134207736c107db42e9d3ef2d9779d465f5f1bcba572b"}, + {file = "zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350"}, + {file = "zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29"}, ] [package.extras] @@ -3882,4 +3885,4 @@ server = ["fastapi", "sse-starlette"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "bee76de01384e265c807d785959226cdf9c788b4d5dfc3772dcce5d70a8f6121" +content-hash = "dd831e6c53bd5f1c57629df0394b5b7ffa58d040fdafaf9a23bd9d17d502d2f9" diff --git a/pyproject.toml b/pyproject.toml index 651464bf..c27e203b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ python = "^3.9" httpx = ">=0.23.0" # May be able to decrease this version fastapi = {version = ">=0.90.1,<1", optional = true} sse-starlette = {version = "^1.3.0", optional = true} -langchain-core = "0.3.0dev5" +langchain-core = "^0.3" orjson = ">=2" pyproject-toml = "^0.0.10" pydantic = "^2.7" diff --git a/tests/unit_tests/test_serialization.py b/tests/unit_tests/test_serialization.py index 9d3f7a4e..33fcc712 100644 --- a/tests/unit_tests/test_serialization.py +++ b/tests/unit_tests/test_serialization.py @@ -189,9 +189,3 @@ def test_encoding_of_well_known_types(obj: Any, expected: str) -> None: """ lc_serializer = WellKnownLCSerializer() assert lc_serializer.dumpd(obj) == expected - - -@pytest.mark.xfail(reason="0.3") -def test_fail_03() -> None: - """This test will fail on purposes. It contains a TODO list for 0.3 release.""" - assert "CHatGeneration_Deserialized correct" == "UNcomment test above" diff --git a/tests/unit_tests/test_server_client.py b/tests/unit_tests/test_server_client.py index c87a69ee..2c8902b5 100644 --- a/tests/unit_tests/test_server_client.py +++ b/tests/unit_tests/test_server_client.py @@ -122,6 +122,19 @@ def _replace_run_id_in_stream_resp(streamed_resp: str) -> str: return streamed_resp.replace(uuid, "") +def _null_run_id_and_metadata_recursively(decoded_response: Any) -> None: + """Recursively traverse the object and delete any keys called run_id""" + if isinstance(decoded_response, dict): + for key, value in decoded_response.items(): + if key in {"run_id", "__langserve_version"}: + decoded_response[key] = None + else: + _null_run_id_and_metadata_recursively(value) + elif isinstance(decoded_response, list): + for item in decoded_response: + _null_run_id_and_metadata_recursively(item) + + @pytest.fixture(scope="module") def event_loop(): """Create an instance of the default event loop for each test case.""" @@ -523,6 +536,15 @@ def test_invoke(sync_remote_runnable: RemoteRunnable) -> None: assert remote_runnable_run.child_runs[0].name == "add_one_or_passthrough" +def test_batch_tracer_with_single_input(sync_remote_runnable: RemoteRunnable) -> None: + """Test passing a single tracer to batch.""" + tracer = FakeTracer() + assert sync_remote_runnable.batch([1], config={"callbacks": [tracer]}) == [2] + assert len(tracer.runs) == 1 + assert len(tracer.runs[0].child_runs) == 1 + assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough" + + def test_batch(sync_remote_runnable: RemoteRunnable) -> None: """Test sync batch.""" assert sync_remote_runnable.batch([]) == [] @@ -536,17 +558,9 @@ def test_batch(sync_remote_runnable: RemoteRunnable) -> None: tracer = FakeTracer() assert sync_remote_runnable.batch([1, 2], config={"callbacks": [tracer]}) == [2, 3] assert len(tracer.runs) == 2 - # Light test to verify that we're picking up information about the server side - # function being invoked via a callback. - assert tracer.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) - assert tracer.runs[1].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[1].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) + assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer.runs[1].child_runs[0].name == "add_one_or_passthrough" # Verify that each tracer gets its own run tracer1 = FakeTracer() @@ -558,17 +572,8 @@ def test_batch(sync_remote_runnable: RemoteRunnable) -> None: assert len(tracer2.runs) == 1 # Light test to verify that we're picking up information about the server side # function being invoked via a callback. - assert tracer1.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer1.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) - - assert tracer2.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer2.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) + assert tracer1.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer2.runs[0].child_runs[0].name == "add_one_or_passthrough" async def test_ainvoke(async_remote_runnable: RemoteRunnable) -> None: @@ -601,10 +606,7 @@ async def test_ainvoke(async_remote_runnable: RemoteRunnable) -> None: elif sys.version_info < (3, 11): assert len(tracer.runs) == 1, "Failed for python < 3.11" remote_runnable = tracer.runs[0] - assert ( - remote_runnable.child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) + assert remote_runnable.name == "RemoteRunnable" else: raise AssertionError(f"Unsupported python version {sys.version_info}") @@ -624,17 +626,9 @@ async def test_abatch(async_remote_runnable: RemoteRunnable) -> None: [1, 2], config={"callbacks": [tracer]} ) == [2, 3] assert len(tracer.runs) == 2 - # Light test to verify that we're picking up information about the server side - # function being invoked via a callback. - assert tracer.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[0].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) - assert tracer.runs[1].child_runs[0].name == "RunnableLambda" - assert ( - tracer.runs[1].child_runs[0].extra["kwargs"]["name"] == "add_one_or_passthrough" - ) + assert tracer.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer.runs[1].child_runs[0].name == "add_one_or_passthrough" # Verify that each tracer gets its own run tracer1 = FakeTracer() @@ -644,19 +638,9 @@ async def test_abatch(async_remote_runnable: RemoteRunnable) -> None: ) == [2, 3] assert len(tracer1.runs) == 1 assert len(tracer2.runs) == 1 - # Light test to verify that we're picking up information about the server side - # function being invoked via a callback. - assert tracer1.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer1.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) - assert tracer2.runs[0].child_runs[0].name == "RunnableLambda" - assert ( - tracer2.runs[0].child_runs[0].extra["kwargs"]["name"] - == "add_one_or_passthrough" - ) + assert tracer1.runs[0].child_runs[0].name == "add_one_or_passthrough" + assert tracer2.runs[0].child_runs[0].name == "add_one_or_passthrough" async def test_astream(async_remote_runnable: RemoteRunnable) -> None: @@ -1128,6 +1112,156 @@ async def add_one(x: int) -> int: assert config_seen["metadata"]["__langserve_endpoint"] == "invoke" +async def test_include_callback_events(mocker: MockerFixture) -> None: + """This test should not use a RemoteRunnable. + + Check if callback events are being sent back from the server. + + Do so using the raw client. + """ + + async def add_one(x: int) -> int: + """Add one to simulate a valid function""" + return x + 1 + + server_runnable = RunnableLambda(func=add_one) + + app = FastAPI() + add_routes(app, server_runnable, input_type=int, include_callback_events=True) + async with AsyncClient( + base_url="http://localhost:9999", transport=httpx.ASGITransport(app=app) + ) as async_client: + response = await async_client.post("/invoke", json={"input": 1}) + # Config should be ignored but default debug information + # will still be added + assert response.status_code == 200 + decoded_response = response.json() + # Remove any run_id from the response recursively + _null_run_id_and_metadata_recursively(decoded_response) + assert decoded_response == { + "callback_events": [ + { + "inputs": 1, + "kwargs": {"name": "add_one", "run_type": None}, + "metadata": { + "__langserve_endpoint": "invoke", + "__langserve_version": None, + "__useragent": "python-httpx/0.27.2", + }, + "parent_run_id": None, + "serialized": None, + "run_id": None, + "tags": [], + "type": "on_chain_start", + }, + { + "kwargs": {}, + "outputs": 2, + "metadata": None, + "parent_run_id": None, + "tags": [], + "run_id": None, + "type": "on_chain_end", + }, + ], + "metadata": { + "feedback_tokens": [], + "run_id": None, + }, + "output": 2, + } + + +async def test_include_callback_events_batch() -> None: + """This test should not use a RemoteRunnable. + + Check if callback events are being sent back from the server. + + Do so using the raw client. + """ + + async def add_one(x: int) -> int: + """Add one to simulate a valid function""" + return x + 1 + + server_runnable = RunnableLambda(func=add_one) + + app = FastAPI() + add_routes(app, server_runnable, input_type=int, include_callback_events=True) + async with AsyncClient( + base_url="http://localhost:9999", transport=httpx.ASGITransport(app=app) + ) as async_client: + response = await async_client.post("/batch", json={"inputs": [1, 2]}) + # Config should be ignored but default debug information + # will still be added + assert response.status_code == 200 + decoded_response = response.json() + # Remove any run_id from the response recursively + _null_run_id_and_metadata_recursively(decoded_response) + del decoded_response["metadata"]["run_ids"] + assert decoded_response == { + "callback_events": [ + [ + { + "inputs": 1, + "kwargs": {"name": "add_one", "run_type": None}, + "metadata": { + "__langserve_endpoint": "batch", + "__langserve_version": None, + "__useragent": "python-httpx/0.27.2", + }, + "parent_run_id": None, + "run_id": None, + "serialized": None, + "tags": [], + "type": "on_chain_start", + }, + { + "kwargs": {}, + "outputs": 2, + "parent_run_id": None, + "metadata": None, + "run_id": None, + "tags": [], + "type": "on_chain_end", + }, + ], + [ + { + "inputs": 2, + "kwargs": {"name": "add_one", "run_type": None}, + "metadata": { + "__langserve_endpoint": "batch", + "__langserve_version": None, + "__useragent": "python-httpx/0.27.2", + }, + "parent_run_id": None, + "run_id": None, + "serialized": None, + "tags": [], + "type": "on_chain_start", + }, + { + "kwargs": {}, + "outputs": 3, + "parent_run_id": None, + "metadata": None, + "run_id": None, + "tags": [], + "type": "on_chain_end", + }, + ], + ], + "metadata": { + "responses": [ + {"feedback_tokens": [], "run_id": None}, + {"feedback_tokens": [], "run_id": None}, + ], + }, + "output": [2, 3], + } + + async def test_input_validation(mocker: MockerFixture) -> None: """Test client side and server side exceptions.""" diff --git a/tests/unit_tests/utils/tracer.py b/tests/unit_tests/utils/tracer.py index a634b9b9..bdb39f44 100644 --- a/tests/unit_tests/utils/tracer.py +++ b/tests/unit_tests/utils/tracer.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Any, Dict, List, Optional from uuid import UUID from langchain_core.tracers import BaseTracer @@ -39,6 +39,34 @@ def _copy_run(self, run: Run) -> Run: } ) + def _create_chain_run( + self, + serialized: Dict[str, Any], + inputs: Dict[str, Any], + run_id: UUID, + tags: Optional[List[str]] = None, + parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, + run_type: Optional[str] = None, + name: Optional[str] = None, + **kwargs: Any, + ) -> Run: + if name is None: + # can't raise an exception from here, but can get a breakpoint + # import pdb; pdb.set_trace() + pass + return super()._create_chain_run( + serialized, + inputs, + run_id, + tags, + parent_run_id, + metadata, + run_type, + name, + **kwargs, + ) + def _persist_run(self, run: Run) -> None: """Persist a run.""" self.runs.append(self._copy_run(run))