Skip to content

Commit

Permalink
Fix issue with callback events sent from server (#765)
Browse files Browse the repository at this point in the history
This properly propagates the name of the runs from the server to the client if one enables callbacks.
  • Loading branch information
eyurtsev authored Sep 14, 2024
1 parent 8b4d8df commit c747e20
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 133 deletions.
28 changes: 21 additions & 7 deletions langserve/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def log_callback(self, event: CallbackEventDict) -> None:

async def on_chat_model_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
Expand All @@ -73,7 +73,7 @@ async def on_chat_model_start(

async def on_chain_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
inputs: Dict[str, Any],
*,
run_id: UUID,
Expand Down Expand Up @@ -138,7 +138,7 @@ async def on_chain_error(

async def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
query: str,
*,
run_id: UUID,
Expand Down Expand Up @@ -202,7 +202,7 @@ async def on_retriever_error(

async def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
input_str: str,
*,
run_id: UUID,
Expand Down Expand Up @@ -306,7 +306,7 @@ async def on_agent_finish(

async def on_llm_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
prompts: List[str],
*,
run_id: UUID,
Expand Down Expand Up @@ -445,7 +445,14 @@ async def ahandle_callbacks(
if event["parent_run_id"] is None: # How do we make sure it's None!?
event["parent_run_id"] = callback_manager.run_id

event_data = {key: value for key, value in event.items() if key != "type"}
event_data = {
key: value
for key, value in event.items()
if key != "type" and key != "kwargs"
}

if "kwargs" in event:
event_data.update(event["kwargs"])

await ahandle_event(
# Unpacking like this may not work
Expand All @@ -467,7 +474,14 @@ def handle_callbacks(
if event["parent_run_id"] is None: # How do we make sure it's None!?
event["parent_run_id"] = callback_manager.run_id

event_data = {key: value for key, value in event.items() if key != "type"}
event_data = {
key: value
for key, value in event.items()
if key != "type" and key != "kwargs"
}

if "kwargs" in event:
event_data.update(event["kwargs"])

handle_event(
# Unpacking like this may not work
Expand Down
3 changes: 2 additions & 1 deletion langserve/server_sent_events.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Adapted from https://github.com/florimondmanca/httpx-sse"""
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncIterator, Iterator, List, Optional, TypedDict
from typing import Any, AsyncIterator, Iterator, List, Optional

import httpx
from typing_extensions import TypedDict


class ServerSentEvent(TypedDict):
Expand Down
88 changes: 35 additions & 53 deletions langserve/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,27 +396,29 @@ class StreamEventsParameters(BaseModel):
# status code and a message.


class OnChainStart(BaseModel):
"""On Chain Start Callback Event."""
class BaseCallback(BaseModel):
"""Base class for all callback events."""

serialized: Dict[str, Any]
inputs: Any
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None


class OnChainStart(BaseCallback):
"""On Chain Start Callback Event."""

serialized: Optional[Dict[str, Any]] = None
inputs: Any
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_start"] = "on_chain_start"


class OnChainEnd(BaseModel):
class OnChainEnd(BaseCallback):
"""On Chain End Callback Event."""

outputs: Any
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_end"] = "on_chain_end"


Expand All @@ -428,75 +430,65 @@ class Error(BaseModel):
type: Literal["error"] = "error"


class OnChainError(BaseModel):
class OnChainError(BaseCallback):
"""On Chain Error Callback Event."""

error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chain_error"] = "on_chain_error"


class OnToolStart(BaseModel):
class OnToolStart(BaseCallback):
"""On Tool Start Callback Event."""

serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
input_str: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_start"] = "on_tool_start"


class OnToolEnd(BaseModel):
class OnToolEnd(BaseCallback):
"""On Tool End Callback Event."""

output: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_end"] = "on_tool_end"


class OnToolError(BaseModel):
"""On Tool Error Callback Event."""

error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_tool_error"] = "on_tool_error"


class OnChatModelStart(BaseModel):
class OnChatModelStart(BaseCallback):
"""On Chat Model Start Callback Event."""

serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
messages: List[List[BaseMessage]]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_chat_model_start"] = "on_chat_model_start"


class OnLLMStart(BaseModel):
class OnLLMStart(BaseCallback):
"""On LLM Start Callback Event."""

serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
prompts: List[str]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_llm_start"] = "on_llm_start"


Expand All @@ -515,49 +507,39 @@ class LLMResult(BaseModel):
"""List of metadata info for model call for each input."""


class OnLLMEnd(BaseModel):
class OnLLMEnd(BaseCallback):
"""On LLM End Callback Event."""

response: LLMResult
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_llm_end"] = "on_llm_end"


class OnRetrieverStart(BaseModel):
class OnRetrieverStart(BaseCallback):
"""On Retriever Start Callback Event."""

serialized: Dict[str, Any]
serialized: Optional[Dict[str, Any]] = None
query: str
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
metadata: Optional[Dict[str, Any]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_start"] = "on_retriever_start"


class OnRetrieverError(BaseModel):
class OnRetrieverError(BaseCallback):
"""On Retriever Error Callback Event."""

error: Error
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_error"] = "on_retriever_error"


class OnRetrieverEnd(BaseModel):
class OnRetrieverEnd(BaseCallback):
"""On Retriever End Callback Event."""

documents: Sequence[Document]
run_id: UUID
parent_run_id: Optional[UUID] = None
tags: Optional[List[str]] = None
kwargs: Any = None
kwargs: Optional[Dict[str, Any]] = None
type: Literal["on_retriever_end"] = "on_retriever_end"


Expand Down
37 changes: 20 additions & 17 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ python = "^3.9"
httpx = ">=0.23.0" # May be able to decrease this version
fastapi = {version = ">=0.90.1,<1", optional = true}
sse-starlette = {version = "^1.3.0", optional = true}
langchain-core = "0.3.0dev5"
langchain-core = "^0.3"
orjson = ">=2"
pyproject-toml = "^0.0.10"
pydantic = "^2.7"
Expand Down
6 changes: 0 additions & 6 deletions tests/unit_tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,3 @@ def test_encoding_of_well_known_types(obj: Any, expected: str) -> None:
"""
lc_serializer = WellKnownLCSerializer()
assert lc_serializer.dumpd(obj) == expected


@pytest.mark.xfail(reason="0.3")
def test_fail_03() -> None:
"""This test will fail on purposes. It contains a TODO list for 0.3 release."""
assert "CHatGeneration_Deserialized correct" == "UNcomment test above"
Loading

0 comments on commit c747e20

Please sign in to comment.