Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions python/flink_agents/plan/actions/chat_model_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,33 @@ def chat(
model: str,
messages: List[ChatMessage],
ctx: RunnerContext,
) -> None:
):
"""Chat with llm.

If there is no tool call generated, we return the chat response event directly,
otherwise, we generate tool request event according to the tool calls in chat model
response, and save the request and response messages in tool call context.

This function uses async execution for the chat model call to avoid blocking.
"""
# 1. Get resource BEFORE async execution (resource access must happen before async)
chat_model = cast(
"BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL)
)

# TODO: support async execution of chat.
response = chat_model.chat(messages)
# 2. Read memory BEFORE async execution (memory access must happen before async)
short_term_memory = ctx.short_term_memory
tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT)
if not tool_call_context:
tool_call_context = {}

# 3. Execute chat asynchronously (no memory access inside async function)
# The lambda captures chat_model and messages, but cannot access memory
response = yield from ctx.execute_async(
lambda: chat_model.chat(messages)
)

# 4. Process response and update memory AFTER async execution
# generate tool request event according tool calls in response
if len(response.tool_calls) > 0:
# TODO: Because memory doesn't support remove currently, so we use
Expand All @@ -66,9 +78,6 @@ def chat(
# to store and remove the specific tool context directly.

# save tool call context
tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT)
if not tool_call_context:
tool_call_context = {}
if initial_request_id not in tool_call_context:
tool_call_context[initial_request_id] = copy.deepcopy(messages)
# append response to tool call context
Expand All @@ -95,7 +104,6 @@ def chat(
# if there is no tool call generated, return chat response directly
else:
# clear tool call context related to specific request id
tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT)
if tool_call_context and initial_request_id in tool_call_context:
tool_call_context.pop(initial_request_id)
short_term_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)
Expand All @@ -107,15 +115,17 @@ def chat(
)


def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None:
def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext):
"""Built-in action for processing a chat request or tool response.

Internally, this action will use short term memory to save the tool call context,
which is a dict mapping request id to chat messages.

This function uses async execution for chat operations.
"""
short_term_memory = ctx.short_term_memory
if isinstance(event, ChatRequestEvent):
chat(
yield from chat(
initial_request_id=event.id,
model=event.model,
messages=event.messages,
Expand Down Expand Up @@ -150,7 +160,7 @@ def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> N
)
short_term_memory.set(_TOOL_CALL_CONTEXT, tool_call_context)

chat(
yield from chat(
initial_request_id=initial_request_id,
model=model,
messages=tool_call_context[initial_request_id],
Expand Down
46 changes: 36 additions & 10 deletions python/flink_agents/plan/actions/tool_call_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,48 @@
from flink_agents.plan.function import PythonFunction


def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None:
"""Built-in action for processing tool call requests."""
responses = {}
external_ids = {}
def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext):
"""Built-in action for processing tool call requests.

This function uses async execution for tool calls to avoid blocking.
"""
# 1. Get all resources BEFORE async execution (resource access must happen before async)
tools = {}
tool_call_info = []
for tool_call in event.tool_calls:
id = tool_call["id"]
name = tool_call["function"]["name"]
kwargs = tool_call["function"]["arguments"]
tool = ctx.get_resource(name, ResourceType.TOOL)
external_id = tool_call.get("original_id")
if not tool:
response = f"Tool `{name}` does not exist."

# Get tool resource before async execution
tool = ctx.get_resource(name, ResourceType.TOOL)
tools[id] = tool
tool_call_info.append({
"id": id,
"name": name,
"kwargs": kwargs,
"external_id": external_id,
"tool": tool,
})

# 2. Execute tools asynchronously (no memory access inside async function)
responses = {}
external_ids = {}
for info in tool_call_info:
if not info["tool"]:
# Tool doesn't exist - handle synchronously
responses[info["id"]] = f"Tool `{info['name']}` does not exist."
else:
response = tool.call(**kwargs)
responses[id] = response
external_ids[id] = external_id
# Execute tool call asynchronously
# The lambda captures tool and kwargs, but cannot access memory
response = yield from ctx.execute_async(
lambda t=info["tool"], k=info["kwargs"]: t.call(**k)
)
responses[info["id"]] = response
external_ids[info["id"]] = info["external_id"]

# 3. Send event AFTER async execution
ctx.send_event(
ToolResponseEvent(
request_id=event.id, responses=responses, external_ids=external_ids
Expand Down
157 changes: 157 additions & 0 deletions python/flink_agents/runtime/tests/test_built_in_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,160 @@ def test_built_in_actions() -> None: # noqa: D103
"3"
}
]


def test_built_in_actions_async_execution() -> None:
    """Test that built-in actions use async execution correctly.

    This test verifies that chat_model_action and tool_call_action work
    correctly with async execution, ensuring backward compatibility.

    End-to-end flow exercised here:
    InputEvent -> ChatRequestEvent -> (mock chat model emits a tool call)
    -> tool execution -> second chat round -> ChatResponseEvent -> OutputEvent.
    The mocks sleep briefly so that the async execution path is actually
    taken rather than completing instantaneously.
    """
    import time

    class SlowMockChatModelConnection(BaseChatModelConnection):
        """Mock ChatModel that simulates slow network calls."""

        def chat(
            self,
            messages: Sequence[ChatMessage],
            tools: List | None = None,
            **kwargs: Any,
        ) -> ChatMessage:
            """Simulate slow network call."""
            time.sleep(0.1)  # Simulate network delay
            # First round: the user task mentions "sum", so answer with a
            # tool call requesting add(a=1, b=2) instead of plain content.
            if "sum" in messages[-1].content:
                input = messages[-1].content
                function = {"name": "add", "arguments": {"a": 1, "b": 2}}
                tool_call = {
                    "id": uuid.uuid4(),
                    "type": ToolType.FUNCTION,
                    "function": function,
                }
                return ChatMessage(
                    role=MessageRole.ASSISTANT, content=input, tool_calls=[tool_call]
                )
            else:
                # Second round (after the tool response has been appended):
                # echo the whole conversation back as a single newline-joined
                # string, which the final assertion below checks verbatim.
                content = "\n".join([message.content for message in messages])
                return ChatMessage(role=MessageRole.ASSISTANT, content=content)

    class SlowMockChatModel(BaseChatModelSetup):
        """Mock ChatModel with slow connection."""

        @property
        def model_kwargs(self) -> Dict[str, Any]:
            # No extra keyword arguments are forwarded to the connection.
            return {}

        def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage:
            # Resolve the underlying connection registered as "slow_connection".
            server = self.get_resource(self.connection, ResourceType.CHAT_MODEL_CONNECTION)
            if self.prompt is not None:
                # self.prompt may be either a resource name (str) or an
                # already-constructed Prompt object.
                if isinstance(self.prompt, str):
                    prompt = self.get_resource(self.prompt, ResourceType.PROMPT)
                else:
                    prompt = self.prompt
                # Only the initial request (the one containing "sum") is
                # rewritten through the prompt template; the follow-up round
                # with the tool response is passed through unchanged.
                if "sum" in messages[-1].content:
                    input_variable = {}
                    for msg in messages:
                        # Prompt variables must be strings; stringify extra_args.
                        str_extra_args = {k: str(v) for k, v in msg.extra_args.items()}
                        input_variable.update(str_extra_args)
                    messages = prompt.format_messages(**input_variable)
            tools = None
            if self.tools is not None:
                # Resolve tool resources by name so the connection can expose
                # them to the (mock) model.
                tools = [
                    self.get_resource(tool_name, ResourceType.TOOL)
                    for tool_name in self.tools
                ]
            return server.chat(messages, tools=tools, **kwargs)

    # NOTE(review): SlowTool is defined but never registered or invoked in
    # this test — the agent's `add` tool below is what actually runs.
    # Consider removing it or wiring it in.
    class SlowTool:
        """Mock tool that simulates slow execution."""

        def __init__(self, name: str):
            self.name = name

        def call(self, **kwargs: Any) -> Any:
            """Simulate slow tool execution."""
            time.sleep(0.05)  # Simulate slow tool execution
            if self.name == "add":
                return kwargs.get("a", 0) + kwargs.get("b", 0)
            return None

    class AsyncTestAgent(Agent):
        """Agent for testing async execution."""

        # Prompt template applied to the first chat round; {task} is filled
        # from the message's extra_args.
        @prompt
        @staticmethod
        def prompt() -> Prompt:
            return Prompt.from_text(
                text="Please call the appropriate tool to do the following task: {task}",
            )

        # Registers the slow mock connection under the name "slow_connection".
        @chat_model_connection
        @staticmethod
        def slow_connection() -> ResourceDescriptor:
            return ResourceDescriptor(clazz=SlowMockChatModelConnection)

        # Chat model setup wiring connection + prompt + the "add" tool together.
        @chat_model_setup
        @staticmethod
        def slow_chat_model() -> ResourceDescriptor:
            return ResourceDescriptor(
                clazz=SlowMockChatModel,
                connection="slow_connection",
                prompt="prompt",
                tools=["add"],
            )

        # The tool the mock model "calls"; sleeps to exercise async tool calls.
        @tool
        @staticmethod
        def add(a: int, b: int) -> int:
            """Calculate the sum of a and b."""
            time.sleep(0.05)  # Simulate slow tool execution
            return a + b

        # Entry action: converts each raw input record into a ChatRequestEvent
        # targeting the slow chat model, carrying the task in extra_args so the
        # prompt template can format it.
        @action(InputEvent)
        @staticmethod
        def process_input(event: InputEvent, ctx: RunnerContext) -> None:
            input = event.input
            ctx.send_event(
                ChatRequestEvent(
                    model="slow_chat_model",
                    messages=[
                        ChatMessage(
                            role=MessageRole.USER, content=input, extra_args={"task": input}
                        )
                    ],
                )
            )

        # Exit action: forwards the final chat response content as the output.
        @action(ChatResponseEvent)
        @staticmethod
        def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None:
            input = event.response
            ctx.send_event(OutputEvent(output=input.content))

    # Test that async execution works correctly
    env = AgentsExecutionEnvironment.get_execution_environment()
    input_list = []
    agent = AsyncTestAgent()

    # The pipeline is wired before inputs are appended; from_list reads the
    # list lazily at execute() time.
    output_list = env.from_list(input_list).apply(agent).to_list()

    input_list.append({"key": "0001", "value": "calculate the sum of 1 and 2."})

    # Measure execution time to verify async doesn't block
    start_time = time.time()
    env.execute()
    execution_time = time.time() - start_time

    # Verify results are correct: original task, formatted prompt, and the
    # tool result "3" joined by the second mock chat round.
    assert output_list == [
        {
            "0001": "calculate the sum of 1 and 2.\n"
            "Please call the appropriate tool to do the following task: "
            "calculate the sum of 1 and 2.\n"
            "3"
        }
    ]

    # Verify execution completed (async should allow this to complete)
    # Note: Exact timing depends on implementation, but it should complete
    assert execution_time < 5.0  # Should complete within reasonable time
Loading