Skip to content

Commit 3d8f266

Browse files
committed
feat(openai_agents): forward previous_response_id from SDK kwarg
The OpenAI Agents SDK's `Model.get_response` abstract has three keyword-only parameters: `previous_response_id`, `conversation_id`, `prompt`. The SDK threads them down through `_ServerConversationTracker` when callers use `Runner.run(..., previous_response_id=X)` or set `RunConfig` with `auto_previous_response_id=True`.

`TemporalStreamingModel.get_response` was declared with `**kwargs # noqa: ARG002`, which silently swallowed all three. Callers who used the SDK's official chaining API saw their `previous_response_id` disappear and got no stateful behavior — without an error.

This commit:

- Replaces `**kwargs` with explicit `previous_response_id`, `conversation_id`, `prompt` params, matching the abstract.
- Forwards `previous_response_id` to `responses.create` via `_non_null_or_not_given` (so `None` resolves to `NOT_GIVEN` and the field is omitted from the request body — identical behavior to today for callers that don't opt in).
- Accepts `conversation_id` and `prompt` for SDK contract compliance but does not forward them yet (marked `# noqa: ARG002`); they can be wired through later if a use case appears.

## Compatibility with non-OpenAI backends

Same opt-in pattern as `prompt_cache_key`. `TemporalStreamingModel` calls `responses.create`, but the underlying client can be pointed at any OpenAI-compatible server (LiteLLM proxy, Foundry, vLLM, etc.). Some of those backends don't recognize `previous_response_id`. Because we forward it only when explicitly set, callers who don't opt in see no change in the wire request — the field is filtered out by `NOT_GIVEN`. Callers who opt in are responsible for knowing whether their backend supports it.

## Test housekeeping

The 27 existing tests that passed `task_id=sample_task_id` to `get_response` were relying on `**kwargs` to silently swallow it. Production reads `task_id` from a ContextVar (set by `ContextInterceptor` in real Temporal flows, set by the `_streaming_context_vars` fixture in tests), not from a function argument. The kwarg was vestigial cruft. Removed.
1 parent eb8ff68 commit 3d8f266

2 files changed

Lines changed: 103 additions & 28 deletions

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ResponseReasoningSummaryTextDeltaEvent,
5252
ResponseFunctionCallArgumentsDeltaEvent,
5353
)
54+
from openai.types.responses.response_prompt_param import ResponsePromptParam
5455

5556
# AgentEx SDK imports
5657
from agentex.lib import adk
@@ -465,12 +466,25 @@ async def get_response(
465466
output_schema: Optional[AgentOutputSchemaBase],
466467
handoffs: list[Handoff],
467468
tracing: ModelTracing, # noqa: ARG002
468-
**kwargs, # noqa: ARG002
469+
*,
470+
previous_response_id: Optional[str] = None,
471+
conversation_id: Optional[str] = None, # noqa: ARG002
472+
prompt: Optional[ResponsePromptParam] = None, # noqa: ARG002
469473
) -> ModelResponse:
470474
"""Get a non-streaming response from the model with streaming to Redis.
471475
472476
This method is used by Temporal activities and needs to return a complete
473477
response, but we stream the response to Redis while generating it.
478+
479+
``previous_response_id`` enables stateful multi-turn chaining on the
480+
Responses API: when set, the server retains the prior response's
481+
chain-of-thought and only the new input items need to be sent. Forwarded
482+
only when explicitly set — not all OpenAI-compatible backends support
483+
this parameter, so the default is omitted from the request body via
484+
``NOT_GIVEN``.
485+
486+
``conversation_id`` and ``prompt`` are accepted to satisfy the
487+
``Model.get_response`` abstract contract but not currently forwarded.
474488
"""
475489

476490
task_id = streaming_task_id.get()
@@ -595,6 +609,7 @@ async def get_response(
595609
extra_query=model_settings.extra_query,
596610
extra_body=model_settings.extra_body,
597611
prompt_cache_key=prompt_cache_key,
612+
previous_response_id=self._non_null_or_not_given(previous_response_id),
598613
# Any additional parameters from extra_args
599614
**extra_args,
600615
)

src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py

Lines changed: 87 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ async def test_temperature_setting(self, streaming_model, _streaming_context_var
4343
output_schema=None,
4444
handoffs=[],
4545
tracing=None,
46-
task_id=sample_task_id
4746
)
4847

4948
# Verify temperature was passed correctly
@@ -73,7 +72,6 @@ async def test_top_p_setting(self, streaming_model, _streaming_context_vars, sam
7372
output_schema=None,
7473
handoffs=[],
7574
tracing=None,
76-
task_id=sample_task_id
7775
)
7876

7977
create_call = streaming_model.client.responses.create.call_args
@@ -101,7 +99,6 @@ async def test_max_tokens_setting(self, streaming_model, _streaming_context_vars
10199
output_schema=None,
102100
handoffs=[],
103101
tracing=None,
104-
task_id=sample_task_id
105102
)
106103

107104
create_call = streaming_model.client.responses.create.call_args
@@ -131,7 +128,6 @@ async def test_reasoning_effort_settings(self, streaming_model, _streaming_conte
131128
output_schema=None,
132129
handoffs=[],
133130
tracing=None,
134-
task_id=sample_task_id
135131
)
136132

137133
create_call = streaming_model.client.responses.create.call_args
@@ -161,7 +157,6 @@ async def test_reasoning_summary_settings(self, streaming_model, _streaming_cont
161157
output_schema=None,
162158
handoffs=[],
163159
tracing=None,
164-
task_id=sample_task_id
165160
)
166161

167162
create_call = streaming_model.client.responses.create.call_args
@@ -199,7 +194,6 @@ async def test_tool_choice_variations(self, streaming_model, _streaming_context_
199194
output_schema=None,
200195
handoffs=[],
201196
tracing=None,
202-
task_id=sample_task_id
203197
)
204198

205199
create_call = streaming_model.client.responses.create.call_args
@@ -227,7 +221,6 @@ async def test_parallel_tool_calls(self, streaming_model, _streaming_context_var
227221
output_schema=None,
228222
handoffs=[],
229223
tracing=None,
230-
task_id=sample_task_id
231224
)
232225

233226
create_call = streaming_model.client.responses.create.call_args
@@ -255,7 +248,6 @@ async def test_truncation_strategy(self, streaming_model, _streaming_context_var
255248
output_schema=None,
256249
handoffs=[],
257250
tracing=None,
258-
task_id=sample_task_id
259251
)
260252

261253
create_call = streaming_model.client.responses.create.call_args
@@ -284,7 +276,6 @@ async def test_response_include(self, streaming_model, _streaming_context_vars,
284276
output_schema=None,
285277
handoffs=[],
286278
tracing=None,
287-
task_id=sample_task_id
288279
)
289280

290281
create_call = streaming_model.client.responses.create.call_args
@@ -314,7 +305,6 @@ async def test_verbosity(self, streaming_model, _streaming_context_vars, sample_
314305
output_schema=None,
315306
handoffs=[],
316307
tracing=None,
317-
task_id=sample_task_id
318308
)
319309

320310
create_call = streaming_model.client.responses.create.call_args
@@ -347,7 +337,6 @@ async def test_metadata_and_store(self, streaming_model, _streaming_context_vars
347337
output_schema=None,
348338
handoffs=[],
349339
tracing=None,
350-
task_id=sample_task_id
351340
)
352341

353342
create_call = streaming_model.client.responses.create.call_args
@@ -383,7 +372,6 @@ async def test_extra_headers_and_body(self, streaming_model, _streaming_context_
383372
output_schema=None,
384373
handoffs=[],
385374
tracing=None,
386-
task_id=sample_task_id
387375
)
388376

389377
create_call = streaming_model.client.responses.create.call_args
@@ -412,7 +400,6 @@ async def test_top_logprobs(self, streaming_model, _streaming_context_vars, samp
412400
output_schema=None,
413401
handoffs=[],
414402
tracing=None,
415-
task_id=sample_task_id
416403
)
417404

418405
create_call = streaming_model.client.responses.create.call_args
@@ -445,7 +432,6 @@ async def test_function_tool(self, streaming_model, _streaming_context_vars, sam
445432
output_schema=None,
446433
handoffs=[],
447434
tracing=None,
448-
task_id=sample_task_id
449435
)
450436

451437
create_call = streaming_model.client.responses.create.call_args
@@ -475,7 +461,6 @@ async def test_web_search_tool(self, streaming_model, _streaming_context_vars, s
475461
output_schema=None,
476462
handoffs=[],
477463
tracing=None,
478-
task_id=sample_task_id
479464
)
480465

481466
create_call = streaming_model.client.responses.create.call_args
@@ -502,7 +487,6 @@ async def test_file_search_tool(self, streaming_model, _streaming_context_vars,
502487
output_schema=None,
503488
handoffs=[],
504489
tracing=None,
505-
task_id=sample_task_id
506490
)
507491

508492
create_call = streaming_model.client.responses.create.call_args
@@ -531,7 +515,6 @@ async def test_computer_tool(self, streaming_model, _streaming_context_vars, sam
531515
output_schema=None,
532516
handoffs=[],
533517
tracing=None,
534-
task_id=sample_task_id
535518
)
536519

537520
create_call = streaming_model.client.responses.create.call_args
@@ -563,7 +546,6 @@ async def test_multiple_computer_tools_error(self, streaming_model, _streaming_c
563546
output_schema=None,
564547
handoffs=[],
565548
tracing=None,
566-
task_id=sample_task_id
567549
)
568550

569551
@pytest.mark.asyncio
@@ -585,7 +567,6 @@ async def test_hosted_mcp_tool(self, streaming_model, _streaming_context_vars, s
585567
output_schema=None,
586568
handoffs=[],
587569
tracing=None,
588-
task_id=sample_task_id
589570
)
590571

591572
create_call = streaming_model.client.responses.create.call_args
@@ -613,7 +594,6 @@ async def test_image_generation_tool(self, streaming_model, _streaming_context_v
613594
output_schema=None,
614595
handoffs=[],
615596
tracing=None,
616-
task_id=sample_task_id
617597
)
618598

619599
create_call = streaming_model.client.responses.create.call_args
@@ -640,7 +620,6 @@ async def test_code_interpreter_tool(self, streaming_model, _streaming_context_v
640620
output_schema=None,
641621
handoffs=[],
642622
tracing=None,
643-
task_id=sample_task_id
644623
)
645624

646625
create_call = streaming_model.client.responses.create.call_args
@@ -667,7 +646,6 @@ async def test_local_shell_tool(self, streaming_model, _streaming_context_vars,
667646
output_schema=None,
668647
handoffs=[],
669648
tracing=None,
670-
task_id=sample_task_id
671649
)
672650

673651
create_call = streaming_model.client.responses.create.call_args
@@ -695,7 +673,6 @@ async def test_handoffs(self, streaming_model, _streaming_context_vars, sample_t
695673
output_schema=None,
696674
handoffs=[sample_handoff],
697675
tracing=None,
698-
task_id=sample_task_id
699676
)
700677

701678
create_call = streaming_model.client.responses.create.call_args
@@ -725,7 +702,6 @@ async def test_mixed_tools(self, streaming_model, _streaming_context_vars, sampl
725702
output_schema=None,
726703
handoffs=[sample_handoff],
727704
tracing=None,
728-
task_id=sample_task_id
729705
)
730706

731707
create_call = streaming_model.client.responses.create.call_args
@@ -770,7 +746,6 @@ async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming
770746
output_schema=None,
771747
handoffs=[],
772748
tracing=None,
773-
task_id=sample_task_id
774749
)
775750

776751
# Verify streaming context was created
@@ -845,7 +820,6 @@ async def test_redis_context_creation(self, streaming_model, mock_adk_streaming,
845820
output_schema=None,
846821
handoffs=[],
847822
tracing=None,
848-
task_id=sample_task_id
849823
)
850824

851825
# Should create at least one context for reasoning
@@ -1113,4 +1087,90 @@ async def test_prompt_cache_key_forwarded_when_opted_in(
11131087
kwargs = model.client.responses.create.call_args.kwargs
11141088
assert kwargs["prompt_cache_key"] == "my-key"
11151089
# Must be popped from extra_args so the SDK doesn't see it twice.
1116-
assert list(kwargs).count("prompt_cache_key") == 1
1090+
assert list(kwargs).count("prompt_cache_key") == 1
1091+
1092+
@pytest.mark.asyncio
1093+
async def test_previous_response_id_not_sent_by_default(
1094+
self,
1095+
streaming_model_with_mock_tracer,
1096+
_streaming_context_vars, # noqa: ARG002
1097+
):
1098+
"""Without an opt-in, previous_response_id resolves to NOT_GIVEN.
1099+
1100+
Critical for non-Responses-API-native backends (e.g. Claude-via-LiteLLM)
1101+
where unknown fields on the request body could be rejected. NOT_GIVEN
1102+
is filtered before serialization, so the field is omitted entirely.
1103+
"""
1104+
model = streaming_model_with_mock_tracer
1105+
completed = self._make_response_completed_event()
1106+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1107+
1108+
await model.get_response(
1109+
system_instructions=None,
1110+
input="hi",
1111+
model_settings=ModelSettings(),
1112+
tools=[],
1113+
output_schema=None,
1114+
handoffs=[],
1115+
tracing=None,
1116+
)
1117+
1118+
kwargs = model.client.responses.create.call_args.kwargs
1119+
assert kwargs["previous_response_id"] is NOT_GIVEN
1120+
1121+
@pytest.mark.asyncio
1122+
async def test_previous_response_id_forwarded_via_sdk_kwarg(
1123+
self,
1124+
streaming_model_with_mock_tracer,
1125+
_streaming_context_vars, # noqa: ARG002
1126+
):
1127+
"""The SDK threads previous_response_id as a keyword arg per Model.get_response
1128+
abstract contract. Verify it reaches responses.create instead of being silently
1129+
swallowed (which was the prior behavior under **kwargs)."""
1130+
model = streaming_model_with_mock_tracer
1131+
completed = self._make_response_completed_event()
1132+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1133+
1134+
await model.get_response(
1135+
system_instructions=None,
1136+
input="hi",
1137+
model_settings=ModelSettings(),
1138+
tools=[],
1139+
output_schema=None,
1140+
handoffs=[],
1141+
tracing=None,
1142+
previous_response_id="resp_prior_turn",
1143+
)
1144+
1145+
kwargs = model.client.responses.create.call_args.kwargs
1146+
assert kwargs["previous_response_id"] == "resp_prior_turn"
1147+
1148+
@pytest.mark.asyncio
1149+
async def test_conversation_id_and_prompt_accepted_but_not_forwarded(
1150+
self,
1151+
streaming_model_with_mock_tracer,
1152+
_streaming_context_vars, # noqa: ARG002
1153+
):
1154+
"""conversation_id and prompt are accepted to satisfy the SDK abstract
1155+
contract but not currently forwarded to responses.create."""
1156+
model = streaming_model_with_mock_tracer
1157+
completed = self._make_response_completed_event()
1158+
model.client.responses.create = AsyncMock(return_value=self._async_iter([completed]))
1159+
1160+
# Should not raise — both kwargs are accepted by the signature.
1161+
await model.get_response(
1162+
system_instructions=None,
1163+
input="hi",
1164+
model_settings=ModelSettings(),
1165+
tools=[],
1166+
output_schema=None,
1167+
handoffs=[],
1168+
tracing=None,
1169+
conversation_id="conv_test",
1170+
prompt=None,
1171+
)
1172+
1173+
kwargs = model.client.responses.create.call_args.kwargs
1174+
# Neither should appear in the outgoing request kwargs.
1175+
assert "conversation_id" not in kwargs
1176+
assert "prompt" not in kwargs

0 commit comments

Comments
 (0)