Skip to content

Commit a83ed87

Browse files
declan-scaleclaude
andcommitted
refactor: address PR review comments on created_at plumbing
- Drop redundant in_temporal_workflow() guard in MessagesModule.create / create_batch; workflow_now_if_in_workflow() already returns None outside a workflow. - Extract the inline _take_created_at closure in providers/openai.py into a module-level _make_created_at_dispenser helper, eliminating the duplicated block between run_agent_auto_send and run_agent_streamed_auto_send. - Drop @pytest.mark.asyncio from tests/lib/adk/test_messages_module.py to match the project's asyncio_mode="auto" convention used in the sibling test_messages_service.py. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8982b27 commit a83ed87

3 files changed

Lines changed: 21 additions & 23 deletions

File tree

src/agentex/lib/adk/_modules/messages.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def create(
8585
"""
8686
# Default created_at to workflow.now() so two awaited adk.messages.create
8787
# calls from the same workflow are guaranteed monotonic at the server.
88-
if created_at is None and in_temporal_workflow():
88+
if created_at is None:
8989
created_at = workflow_now_if_in_workflow()
9090
params = CreateMessageParams(
9191
trace_id=trace_id,
@@ -185,7 +185,7 @@ async def create_batch(
185185
Returns:
186186
List[TaskMessageEntity]: The created messages.
187187
"""
188-
if created_at is None and in_temporal_workflow():
188+
if created_at is None:
189189
created_at = workflow_now_if_in_workflow()
190190
params = CreateMessagesBatchParams(
191191
task_id=task_id,

src/agentex/lib/core/services/adk/providers/openai.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Literal
55
from datetime import datetime
66
from contextlib import AsyncExitStack, asynccontextmanager
7+
from collections.abc import Callable
78

89
from mcp import StdioServerParameters
910
from agents import Agent, Runner, RunResult, RunResultStreaming
@@ -74,6 +75,22 @@ async def mcp_server_context(
7475
yield servers
7576

7677

78+
def _make_created_at_dispenser(initial: datetime | None) -> Callable[[], datetime | None]:
79+
# Returns a closure that yields the workflow-supplied created_at exactly
80+
# once (on the first call), then None forever after. Used to stamp the
81+
# first agent message of a turn with workflow.now() while letting
82+
# subsequent messages fall back to server wall-clock — see the call sites
83+
# in run_agent_auto_send / run_agent_streamed_auto_send for context.
84+
pending: list[datetime | None] = [initial]
85+
86+
def take() -> datetime | None:
87+
value = pending[0]
88+
pending[0] = None
89+
return value
90+
91+
return take
92+
93+
7794
class OpenAIService:
7895
"""Service for OpenAI agent operations using the agents library."""
7996

@@ -377,15 +394,7 @@ async def run_agent_auto_send(
377394
) as span:
378395
heartbeat_if_in_workflow("run agent auto send")
379396

380-
# See run_agent_streamed_auto_send for the rationale: only the
381-
# first message opened in this turn carries the workflow-supplied
382-
# created_at; the rest fall back to server wall clock.
383-
_pending_created_at: list[datetime | None] = [created_at]
384-
385-
def _take_created_at() -> datetime | None:
386-
value = _pending_created_at[0]
387-
_pending_created_at[0] = None
388-
return value
397+
_take_created_at = _make_created_at_dispenser(created_at)
389398

390399
async with mcp_server_context(mcp_server_params, mcp_timeout_seconds) as servers:
391400
tools = (
@@ -758,12 +767,7 @@ async def run_agent_streamed_auto_send(
758767
# user-echo at the server. Subsequent messages in the same turn are
759768
# separated by network/processing latency and rely on the server's
760769
# wall clock.
761-
_pending_created_at: list[datetime | None] = [created_at]
762-
763-
def _take_created_at() -> datetime | None:
764-
value = _pending_created_at[0]
765-
_pending_created_at[0] = None
766-
return value
770+
_take_created_at = _make_created_at_dispenser(created_at)
767771

768772
async with mcp_server_context(mcp_server_params, mcp_timeout_seconds) as servers:
769773
tools = (

tests/lib/adk/test_messages_module.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
from datetime import datetime, timezone
1212
from unittest.mock import AsyncMock, patch
1313

14-
import pytest
15-
1614
import agentex.lib.adk._modules.messages as _messages_mod
1715
from agentex.types.task_message import TaskMessage
1816
from agentex.types.text_content import TextContent
@@ -38,7 +36,6 @@ def _make_module() -> tuple[AsyncMock, MessagesModule]:
3836

3937

4038
class TestMessagesModuleCreate:
41-
@pytest.mark.asyncio
4239
async def test_outside_workflow_does_not_inject_created_at(self) -> None:
4340
mock_service, module = _make_module()
4441
mock_service.create_message.return_value = _make_task_message()
@@ -52,7 +49,6 @@ async def test_outside_workflow_does_not_inject_created_at(self) -> None:
5249
kwargs = mock_service.create_message.call_args.kwargs
5350
assert kwargs["created_at"] is None
5451

55-
@pytest.mark.asyncio
5652
async def test_inside_workflow_auto_injects_workflow_now(self) -> None:
5753
mock_service, module = _make_module()
5854
mock_service.create_message.return_value = _make_task_message()
@@ -80,7 +76,6 @@ async def fake_execute_activity(**call_kwargs):
8076
params = captured["request"]
8177
assert params.created_at == _FIXED_NOW
8278

83-
@pytest.mark.asyncio
8479
async def test_caller_supplied_created_at_is_respected(self) -> None:
8580
mock_service, module = _make_module()
8681
mock_service.create_message.return_value = _make_task_message()
@@ -99,7 +94,6 @@ async def test_caller_supplied_created_at_is_respected(self) -> None:
9994

10095

10196
class TestMessagesModuleCreateBatch:
102-
@pytest.mark.asyncio
10397
async def test_inside_workflow_auto_injects_workflow_now(self) -> None:
10498
mock_service, module = _make_module()
10599
mock_service.create_messages_batch.return_value = [_make_task_message()]

0 commit comments

Comments
 (0)