diff --git a/python/flink_agents/plan/actions/chat_model_action.py b/python/flink_agents/plan/actions/chat_model_action.py index bb74a5f0..865bb250 100644 --- a/python/flink_agents/plan/actions/chat_model_action.py +++ b/python/flink_agents/plan/actions/chat_model_action.py @@ -40,21 +40,33 @@ def chat( model: str, messages: List[ChatMessage], ctx: RunnerContext, -) -> None: +): """Chat with llm. If there is no tool call generated, we return the chat response event directly, otherwise, we generate tool request event according to the tool calls in chat model response, and save the request and response messages in tool call context. + + This function uses async execution for the chat model call to avoid blocking. """ + # 1. Get resource BEFORE async execution (resource access must happen before async) chat_model = cast( "BaseChatModelSetup", ctx.get_resource(model, ResourceType.CHAT_MODEL) ) - # TODO: support async execution of chat. - response = chat_model.chat(messages) + # 2. Read memory BEFORE async execution (memory access must happen before async) short_term_memory = ctx.short_term_memory + tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT) + if not tool_call_context: + tool_call_context = {} + + # 3. Execute chat asynchronously (no memory access inside async function) + # The lambda captures chat_model and messages, but cannot access memory + response = yield from ctx.execute_async( + lambda: chat_model.chat(messages) + ) + # 4. Process response and update memory AFTER async execution # generate tool request event according tool calls in response if len(response.tool_calls) > 0: # TODO: Because memory doesn't support remove currently, so we use @@ -66,9 +78,6 @@ def chat( # to store and remove the specific tool context directly. 
# save tool call context - tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT) - if not tool_call_context: - tool_call_context = {} if initial_request_id not in tool_call_context: tool_call_context[initial_request_id] = copy.deepcopy(messages) # append response to tool call context @@ -95,7 +104,6 @@ def chat( # if there is no tool call generated, return chat response directly else: # clear tool call context related to specific request id - tool_call_context = short_term_memory.get(_TOOL_CALL_CONTEXT) if tool_call_context and initial_request_id in tool_call_context: tool_call_context.pop(initial_request_id) short_term_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) @@ -107,15 +115,17 @@ def chat( ) -def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> None: +def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext): """Built-in action for processing a chat request or tool response. Internally, this action will use short term memory to save the tool call context, which is a dict mapping request id to chat messages. + + This function uses async execution for chat operations. 
""" short_term_memory = ctx.short_term_memory if isinstance(event, ChatRequestEvent): - chat( + yield from chat( initial_request_id=event.id, model=event.model, messages=event.messages, @@ -150,7 +160,7 @@ def process_chat_request_or_tool_response(event: Event, ctx: RunnerContext) -> N ) short_term_memory.set(_TOOL_CALL_CONTEXT, tool_call_context) - chat( + yield from chat( initial_request_id=initial_request_id, model=model, messages=tool_call_context[initial_request_id], diff --git a/python/flink_agents/plan/actions/tool_call_action.py b/python/flink_agents/plan/actions/tool_call_action.py index 4d56e36a..b8a6684f 100644 --- a/python/flink_agents/plan/actions/tool_call_action.py +++ b/python/flink_agents/plan/actions/tool_call_action.py @@ -22,22 +22,48 @@ from flink_agents.plan.function import PythonFunction -def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext) -> None: - """Built-in action for processing tool call requests.""" - responses = {} - external_ids = {} +def process_tool_request(event: ToolRequestEvent, ctx: RunnerContext): + """Built-in action for processing tool call requests. + + This function uses async execution for tool calls to avoid blocking. + """ + # 1. Get all resources BEFORE async execution (resource access must happen before async) + tools = {} + tool_call_info = [] for tool_call in event.tool_calls: id = tool_call["id"] name = tool_call["function"]["name"] kwargs = tool_call["function"]["arguments"] - tool = ctx.get_resource(name, ResourceType.TOOL) external_id = tool_call.get("original_id") - if not tool: - response = f"Tool `{name}` does not exist." + + # Get tool resource before async execution + tool = ctx.get_resource(name, ResourceType.TOOL) + tools[id] = tool + tool_call_info.append({ + "id": id, + "name": name, + "kwargs": kwargs, + "external_id": external_id, + "tool": tool, + }) + + # 2. 
Execute tools asynchronously (no memory access inside async function) + responses = {} + external_ids = {} + for info in tool_call_info: + if not info["tool"]: + responses[info["id"]] = f"Tool `{info['name']}` does not exist." + external_ids[info["id"]] = info["external_id"] + else: + # Execute tool call asynchronously + # The lambda captures tool and kwargs, but cannot access memory + response = yield from ctx.execute_async( + lambda t=info["tool"], k=info["kwargs"]: t.call(**k) + ) + responses[info["id"]] = response + external_ids[info["id"]] = info["external_id"] + + # 3. Send event AFTER async execution ctx.send_event( ToolResponseEvent( request_id=event.id, responses=responses, external_ids=external_ids diff --git a/python/flink_agents/runtime/tests/test_built_in_actions.py b/python/flink_agents/runtime/tests/test_built_in_actions.py index 560a625c..0fa2e576 100644 --- a/python/flink_agents/runtime/tests/test_built_in_actions.py +++ b/python/flink_agents/runtime/tests/test_built_in_actions.py @@ -205,3 +205,160 @@ def test_built_in_actions() -> None: # noqa: D103 "3" } ] + + +def test_built_in_actions_async_execution() -> None: + """Test that built-in actions use async execution correctly. + + This test verifies that chat_model_action and tool_call_action work + correctly with async execution, ensuring backward compatibility. 
+ """ + import time + + class SlowMockChatModelConnection(BaseChatModelConnection): + """Mock ChatModel that simulates slow network calls.""" + + def chat( + self, + messages: Sequence[ChatMessage], + tools: List | None = None, + **kwargs: Any, + ) -> ChatMessage: + """Simulate slow network call.""" + time.sleep(0.1) # Simulate network delay + if "sum" in messages[-1].content: + input = messages[-1].content + function = {"name": "add", "arguments": {"a": 1, "b": 2}} + tool_call = { + "id": uuid.uuid4(), + "type": ToolType.FUNCTION, + "function": function, + } + return ChatMessage( + role=MessageRole.ASSISTANT, content=input, tool_calls=[tool_call] + ) + else: + content = "\n".join([message.content for message in messages]) + return ChatMessage(role=MessageRole.ASSISTANT, content=content) + + class SlowMockChatModel(BaseChatModelSetup): + """Mock ChatModel with slow connection.""" + + @property + def model_kwargs(self) -> Dict[str, Any]: + return {} + + def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatMessage: + server = self.get_resource(self.connection, ResourceType.CHAT_MODEL_CONNECTION) + if self.prompt is not None: + if isinstance(self.prompt, str): + prompt = self.get_resource(self.prompt, ResourceType.PROMPT) + else: + prompt = self.prompt + if "sum" in messages[-1].content: + input_variable = {} + for msg in messages: + str_extra_args = {k: str(v) for k, v in msg.extra_args.items()} + input_variable.update(str_extra_args) + messages = prompt.format_messages(**input_variable) + tools = None + if self.tools is not None: + tools = [ + self.get_resource(tool_name, ResourceType.TOOL) + for tool_name in self.tools + ] + return server.chat(messages, tools=tools, **kwargs) + + class SlowTool: + """Mock tool that simulates slow execution.""" + + def __init__(self, name: str): + self.name = name + + def call(self, **kwargs: Any) -> Any: + """Simulate slow tool execution.""" + time.sleep(0.05) # Simulate slow tool execution + if self.name == "add": 
+ return kwargs.get("a", 0) + kwargs.get("b", 0) + return None + + class AsyncTestAgent(Agent): + """Agent for testing async execution.""" + + @prompt + @staticmethod + def prompt() -> Prompt: + return Prompt.from_text( + text="Please call the appropriate tool to do the following task: {task}", + ) + + @chat_model_connection + @staticmethod + def slow_connection() -> ResourceDescriptor: + return ResourceDescriptor(clazz=SlowMockChatModelConnection) + + @chat_model_setup + @staticmethod + def slow_chat_model() -> ResourceDescriptor: + return ResourceDescriptor( + clazz=SlowMockChatModel, + connection="slow_connection", + prompt="prompt", + tools=["add"], + ) + + @tool + @staticmethod + def add(a: int, b: int) -> int: + """Calculate the sum of a and b.""" + time.sleep(0.05) # Simulate slow tool execution + return a + b + + @action(InputEvent) + @staticmethod + def process_input(event: InputEvent, ctx: RunnerContext) -> None: + input = event.input + ctx.send_event( + ChatRequestEvent( + model="slow_chat_model", + messages=[ + ChatMessage( + role=MessageRole.USER, content=input, extra_args={"task": input} + ) + ], + ) + ) + + @action(ChatResponseEvent) + @staticmethod + def process_chat_response(event: ChatResponseEvent, ctx: RunnerContext) -> None: + input = event.response + ctx.send_event(OutputEvent(output=input.content)) + + # Test that async execution works correctly + env = AgentsExecutionEnvironment.get_execution_environment() + input_list = [] + agent = AsyncTestAgent() + + output_list = env.from_list(input_list).apply(agent).to_list() + + input_list.append({"key": "0001", "value": "calculate the sum of 1 and 2."}) + + # Measure execution time to verify async doesn't block + start_time = time.time() + env.execute() + execution_time = time.time() - start_time + + # Verify results are correct + assert output_list == [ + { + "0001": "calculate the sum of 1 and 2.\n" + "Please call the appropriate tool to do the following task: " + "calculate the sum of 1 and 2.\n" 
+ "3" + } + ] + + # Verify execution completed (async should allow this to complete) + # Note: Exact timing depends on implementation, but it should complete + assert execution_time < 5.0 # Should complete within reasonable time