From 6586af3071221b4e8bd02d57c980b3ac86f35006 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 25 Feb 2025 00:02:09 -0800 Subject: [PATCH 1/8] Use .iter() API to fully replace existing streaming implementation --- docs/agents.md | 175 +++++++++++++++++++ pydantic_ai_slim/pydantic_ai/_agent_graph.py | 38 +++- pydantic_ai_slim/pydantic_ai/agent.py | 5 +- pydantic_ai_slim/pydantic_ai/messages.py | 17 +- pydantic_ai_slim/pydantic_ai/result.py | 122 ++++++++++++- 5 files changed, 351 insertions(+), 6 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 818e2db88..177577719 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -220,6 +220,181 @@ Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pyd --- +### Streaming + +Here is an example of streaming in combination with `async for`: + +```python {title="streaming.py"} +import asyncio +from dataclasses import dataclass +from datetime import date + +from pydantic_ai import ( + Agent, + capture_run_messages, +) +from pydantic_ai.agent import is_handle_response_node, is_model_request_node +from pydantic_ai.messages import ( + PartStartEvent, + PartDeltaEvent, + FunctionToolCallEvent, + FunctionToolResultEvent, + FinalResultEvent, TextPartDelta, ToolCallPartDelta, +) +from pydantic_ai.tools import RunContext +from pydantic_graph import End + + +@dataclass +class WeatherService: + async def get_forecast(self, location: str, forecast_date: date) -> str: + # In real code: call weather API, DB queries, etc. + return f"The forecast in {location} on {forecast_date} is 24°C and sunny." + + async def get_historic_weather(self, location: str, forecast_date: date) -> str: + # In real code: call a historical weather API or DB + return f"The weather in {location} on {forecast_date} was 18°C and partly cloudy." + + +weather_agent = Agent[WeatherService, str]( + "openai:gpt-4o", + deps_type=WeatherService, + result_type=str, # We'll produce a final answer as plain text + system_prompt="Providing a weather forecast at the locations the user provides.", +) + + +@weather_agent.tool +async def weather_forecast( + ctx: RunContext[WeatherService], + location: str, + forecast_date: date, +) -> str: + if forecast_date >= date.today(): + return await ctx.deps.get_forecast(location, forecast_date) + else: + return await ctx.deps.get_historic_weather(location, forecast_date) + + +async def main(): + # The user asks for tomorrow's weather in Paris + user_prompt = "What will the weather be like in Paris tomorrow?" 
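
    # Note: `weather_agent.iter(...)` (used below) drives the run node by node,
    # so we can stream events from each model request and each tool call as it happens.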
+ + # We'll capture raw messages for debugging + with capture_run_messages() as messages: + # Provide a WeatherService instance as the agent's dependencies + deps = WeatherService() + + # Begin a node-by-node, streaming iteration + with weather_agent.iter(user_prompt, deps=deps) as run: + node = run.next_node # The first node to run + while not isinstance(node, End): + if is_model_request_node(node): + # A model request node => We can stream tokens from the model's request + print("=== ModelRequestNode: streaming partial request tokens ===") + async with node.stream(run.ctx) as request_stream: + async for event in request_stream: + if isinstance(event, PartStartEvent): + print(f"[Request] Starting part {event.index}: {event.part!r}") + elif isinstance(event, PartDeltaEvent): + if isinstance(event.delta, TextPartDelta): + print(f"[Request] Part {event.index} text delta: {event.delta.content_delta!r}") + elif isinstance(event.delta, ToolCallPartDelta): + print(f"[Request] Part {event.index} args_delta={event.delta.args_delta}") + elif isinstance(event, FinalResultEvent): + print(f"[Result] The model produced a final result (tool_name={event.tool_name})") + + elif is_handle_response_node(node): + # A handle-response node => The model returned some data, potentially calls a tool + print("=== HandleResponseNode: streaming partial response & tool usage ===") + async with node.stream(run.ctx) as handle_stream: + async for event in handle_stream: + if isinstance(event, FunctionToolCallEvent): + print(f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})") + elif isinstance(event, FunctionToolResultEvent): + print(f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}") + + node = await run.next(node) + + # Once an End node is reached, the agent run is complete + assert run.result is not None + print("\n=== Final Agent Output ===") + print("Forecast:", run.result.data) + + # Show the raw messages exchanged + print("\n=== Raw Messages Captured ===") + for m in messages: + print(" -", m) + + +if __name__ == "__main__": + asyncio.run(main()) + +""" +=== ModelRequestNode: streaming partial request tokens === +[Request] Starting part 0: ToolCallPart(tool_name='weather_forecast', args='', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', part_kind='tool-call') +[Request] Part 0 args_delta={" +[Request] Part 0 args_delta=location +[Request] Part 0 args_delta=":" +[Request] Part 0 args_delta=Paris +[Request] Part 0 args_delta="," +[Request] Part 0 args_delta=forecast +[Request] Part 0 args_delta=_date +[Request] Part 0 args_delta=":" +[Request] Part 0 args_delta=202 +[Request] Part 0 args_delta=3 +[Request] Part 0 args_delta=- +[Request] Part 0 args_delta=11 +[Request] Part 0 args_delta=- +[Request] Part 0 args_delta=02 +[Request] Part 0 args_delta="} +=== HandleResponseNode: streaming partial response & tool usage === +[Tools] The LLM calls tool='weather_forecast' with args={"location":"Paris","forecast_date":"2023-11-02"} (tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R') +[Tools] Tool call 'call_Q0QqiZfIhHyNViiLG7jT0G9R' returned => The weather in Paris on 2023-11-02 was 18°C and partly cloudy. 
+=== ModelRequestNode: streaming partial request tokens === +[Request] Starting part 0: TextPart(content='', part_kind='text') +[Result] The model produced a final result (tool_name=None) +[Request] Part 0 text delta: 'The' +[Request] Part 0 text delta: ' weather' +[Request] Part 0 text delta: ' forecast' +[Request] Part 0 text delta: ' for' +[Request] Part 0 text delta: ' Paris' +[Request] Part 0 text delta: ' tomorrow' +[Request] Part 0 text delta: ',' +[Request] Part 0 text delta: ' November' +[Request] Part 0 text delta: ' ' +[Request] Part 0 text delta: '2' +[Request] Part 0 text delta: ',' +[Request] Part 0 text delta: ' ' +[Request] Part 0 text delta: '202' +[Request] Part 0 text delta: '3' +[Request] Part 0 text delta: ',' +[Request] Part 0 text delta: ' is' +[Request] Part 0 text delta: ' expected' +[Request] Part 0 text delta: ' to' +[Request] Part 0 text delta: ' be' +[Request] Part 0 text delta: ' ' +[Request] Part 0 text delta: '18' +[Request] Part 0 text delta: '°C' +[Request] Part 0 text delta: ' and' +[Request] Part 0 text delta: ' partly' +[Request] Part 0 text delta: ' cloudy' +[Request] Part 0 text delta: '.' +=== HandleResponseNode: streaming partial response & tool usage === + +=== Final Agent Output === +Forecast: The weather forecast for Paris tomorrow, November 2, 2023, is expected to be 18°C and partly cloudy. + +=== Raw Messages Captured === + - ModelRequest(parts=[SystemPromptPart(content='Providing a weather forecast at the locations the user provides.', dynamic_ref=None, part_kind='system-prompt'), UserPromptPart(content='What will the weather be like in Paris tomorrow?', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 4, 867863, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request') + - ModelResponse(parts=[ToolCallPart(tool_name='weather_forecast', args='{"location":"Paris","forecast_date":"2023-11-02"}', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', part_kind='tool-call')], model_name='gpt-4o', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 8, tzinfo=datetime.timezone.utc), kind='response') + - ModelRequest(parts=[ToolReturnPart(tool_name='weather_forecast', content='The weather in Paris on 2023-11-02 was 18°C and partly cloudy.', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 9, 150432, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request') + - ModelResponse(parts=[TextPart(content='The weather forecast for Paris tomorrow, November 2, 2023, is expected to be 18°C and partly cloudy.', part_kind='text')], model_name='gpt-4o', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 9, tzinfo=datetime.timezone.utc), kind='response') +""" +``` + +--- + ### Additional Configuration #### Usage Limits diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index ec9f16568..859cf2a81 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -10,7 +10,7 @@ from typing import Any, Generic, Literal, Union, cast import logfire_api -from typing_extensions import TypeVar, assert_never +from typing_extensions import TypeGuard, TypeVar, assert_never from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT @@ -40,6 +40,8 @@ 'HandleResponseNode', 'build_run_context', 'capture_run_messages', + 'is_model_request_node', + 'is_handle_response_node', ) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') @@ -236,12 +238,30 @@ async def run( return 
await self._make_request(ctx) + @asynccontextmanager + async def stream( + self, + ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], + ) -> AsyncIterator[result.AgentStream[DepsT, T]]: + async with self._stream(ctx) as streamed_response: + agent_stream = result.AgentStream[DepsT, T]( + streamed_response, + ctx.deps.result_schema, + ctx.deps.result_validators, + build_run_context(ctx), + ctx.deps.usage_limits, + ) + yield agent_stream + # In case the user didn't manually consume the full stream, ensure it is fully consumed here, + # otherwise usage won't be properly counted: + async for _ in agent_stream: + pass + @asynccontextmanager async def _stream( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], ) -> AsyncIterator[models.StreamedResponse]: - # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public assert not self._did_stream, 'stream() should only be called once per node' model_settings, model_request_parameters = await self._prepare_request(ctx) @@ -575,7 +595,7 @@ async def process_function_tools( for task in done: index = tasks.index(task) result = task.result() - yield _messages.FunctionToolResultEvent(result, call_id=call_index_to_event_id[index]) + yield _messages.FunctionToolResultEvent(result, tool_call_id=call_index_to_event_id[index]) if isinstance(result, (_messages.ToolReturnPart, _messages.RetryPromptPart)): results_by_index[index] = result else: @@ -685,3 +705,15 @@ def build_agent_graph( auto_instrument=False, ) return graph + + +def is_model_request_node( + node: BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], +) -> TypeGuard[ModelRequestNode[DepsT, NodeRunEndT]]: + return isinstance(node, ModelRequestNode) + + +def is_handle_response_node( + node: BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], +) -> TypeGuard[HandleResponseNode[DepsT, NodeRunEndT]]: + return isinstance(node, HandleResponseNode) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index bbc6eaa61..739537783 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -45,7 +45,8 @@ HandleResponseNode = _agent_graph.HandleResponseNode ModelRequestNode = _agent_graph.ModelRequestNode UserPromptNode = _agent_graph.UserPromptNode - +is_handle_response_node = _agent_graph.is_handle_response_node +is_model_request_node = _agent_graph.is_model_request_node __all__ = ( 'Agent', @@ -56,6 +57,8 @@ 'HandleResponseNode', 'ModelRequestNode', 'UserPromptNode', + 'is_handle_response_node', + 'is_model_request_node', ) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 661ef089d..bda8c6201 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -533,9 +533,24 @@ class PartDeltaEvent: """Event type identifier, used as a discriminator.""" +@dataclass +class FinalResultEvent: + """An event indicating the response to the current model request matches the result schema.""" + + tool_name: str | None + """The name of the result tool that was called. 
`None` if the result is from text content and not from a tool.""" + event_kind: Literal['final_result'] = 'final_result' + """Event type identifier, used as a discriminator.""" + + ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')] """An event in the model response stream, either starting a new part or applying a delta to an existing one.""" +AgentStreamEvent = Annotated[ + Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind') +] +"""An event in the agent stream.""" + @dataclass class FunctionToolCallEvent: @@ -558,7 +573,7 @@ class FunctionToolResultEvent: result: ToolReturnPart | RetryPromptPart """The result of the call to the function tool.""" - call_id: str + tool_call_id: str """An ID used to match the result to its original call.""" event_kind: Literal['function_tool_result'] = 'function_tool_result' """Event type identifier, used as a discriminator.""" diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 7646de5bf..140e5a025 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -7,9 +7,10 @@ from typing import Generic, Union, cast import logfire_api -from typing_extensions import TypeVar +from typing_extensions import TypeVar, assert_type from . import _result, _utils, exceptions, messages as _messages, models +from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits @@ -51,6 +52,125 @@ _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') +@dataclass +class AgentStream(Generic[AgentDepsT, ResultDataT]): + _raw_stream_response: models.StreamedResponse + _result_schema: _result.ResultSchema[ResultDataT] | None + _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]] + _run_ctx: RunContext[AgentDepsT] + _usage_limits: UsageLimits | None + + _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) + _final_result_event: FinalResultEvent | None = field(default=None, init=False) + _initial_run_ctx_usage: Usage = field(init=False) + + def __post_init__(self): + self._initial_run_ctx_usage = copy(self._run_ctx.usage) + + async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]: + """Asynchronously stream the (validated) agent outputs.""" + async for response in self.stream_responses(debounce_by=debounce_by): + if self._final_result_event is not None: + yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True) + if self._final_result_event is not None: + yield await self._validate_response( + self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False + ) + + async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: + """Asynchronously stream the (unvalidated) model responses for the agent.""" + # if the message currently has any parts with content, yield before streaming + msg = self._raw_stream_response.get() + for part in msg.parts: + if part.has_content(): + yield msg + break + + async with _utils.group_by_temporal(self, debounce_by) as group_iter: + async for _items in group_iter: + yield self._raw_stream_response.get() # current state of the response + + def usage(self) -> Usage: + """Return the usage of the whole run. + + !!! 
note + This won't return the full usage until the stream is finished. + """ + return self._initial_run_ctx_usage + self._raw_stream_response.usage() + + async def _validate_response( + self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False + ) -> ResultDataT: + """Validate a structured result message.""" + if self._result_schema is not None and result_tool_name is not None: + match = self._result_schema.find_named_tool(message.parts, result_tool_name) + if match is None: + raise exceptions.UnexpectedModelBehavior( + f'Invalid response, unable to find tool: {self._result_schema.tool_names()}' + ) + + call, result_tool = match + result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) + + for validator in self._result_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) + return result_data + else: + text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) + for validator in self._result_validators: + text = await validator.validate( + text, + None, + self._run_ctx, + ) + # Since there is no result tool, we can assume that str is compatible with ResultDataT + return cast(ResultDataT, text) + + def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: + """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. + + This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches + on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the + first match is found. + """ + if self._agent_stream_iterator is not None: + return self._agent_stream_iterator + + async def aiter(): + result_schema = self._result_schema + allow_text_result = result_schema is None or result_schema.allow_text_result + + def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None: + """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" + if isinstance(e, _messages.PartStartEvent): + new_part = e.part + if isinstance(new_part, _messages.ToolCallPart): + if result_schema is not None and (match := result_schema.find_tool([new_part])): + call, _ = match + return _messages.FinalResultEvent(tool_name=call.tool_name) + elif allow_text_result: + assert_type(e, _messages.PartStartEvent) + return _messages.FinalResultEvent(tool_name=None) + + usage_checking_stream = _get_usage_checking_stream_response( + self._raw_stream_response, self._usage_limits, self.usage + ) + async for event in usage_checking_stream: + yield event + if (final_result_event := _get_final_result_event(event)) is not None: + self._final_result_event = final_result_event + yield final_result_event + break + + # If we broke out of the above loop, we need to yield the rest of the events + # If we didn't, this will just be a no-op + async for event in usage_checking_stream: + yield event + + self._agent_stream_iterator = aiter() + return self._agent_stream_iterator + + @dataclass class StreamedRunResult(Generic[AgentDepsT, ResultDataT]): """Result of a streamed run that returns structured data via a tool call.""" From 6659a0ddb63282039d281154997c8b3e11dc2778 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 25 Feb 2025 08:30:20 -0800 Subject: [PATCH 2/8] Use async for in the docs --- docs/agents.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff 
--git a/docs/agents.md b/docs/agents.md index 177577719..05298ea72 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -242,7 +242,6 @@ from pydantic_ai.messages import ( FinalResultEvent, TextPartDelta, ToolCallPartDelta, ) from pydantic_ai.tools import RunContext -from pydantic_graph import End @dataclass @@ -287,8 +286,7 @@ async def main(): # Begin a node-by-node, streaming iteration with weather_agent.iter(user_prompt, deps=deps) as run: - node = run.next_node # The first node to run - while not isinstance(node, End): + async for node in run: if is_model_request_node(node): # A model request node => We can stream tokens from the model's request print("=== ModelRequestNode: streaming partial request tokens ===") @@ -314,8 +312,6 @@ async def main(): elif isinstance(event, FunctionToolResultEvent): print(f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}") - node = await run.next(node) - # Once an End node is reached, the agent run is complete assert run.result is not None print("\n=== Final Agent Output ===") From 633d950ab5b0d0cf9a097b527c244940ea198733 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:28:51 -0800 Subject: [PATCH 3/8] A bit more clean-up --- docs/agents.md | 135 ++++++++++--------- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 45 ++++--- pydantic_ai_slim/pydantic_ai/agent.py | 115 +++++++++------- 3 files changed, 163 insertions(+), 132 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index e6fa41c36..8c93dfae0 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -222,24 +222,22 @@ Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pyd ### Streaming -Here is an example of streaming in combination with `async for`: +Here is an example of streaming an agent run in combination with `async for` iteration: -```python {title="streaming.py"} +```python {title="streaming.py"} {test="skip"} import asyncio from dataclasses import dataclass from datetime import date -from pydantic_ai import ( - Agent, - capture_run_messages, -) -from pydantic_ai.agent import is_handle_response_node, is_model_request_node +from pydantic_ai import Agent from pydantic_ai.messages import ( - PartStartEvent, - PartDeltaEvent, + FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, - FinalResultEvent, TextPartDelta, ToolCallPartDelta, + PartDeltaEvent, + PartStartEvent, + TextPartDelta, + ToolCallPartDelta, ) from pydantic_ai.tools import RunContext @@ -248,18 +246,20 @@ from pydantic_ai.tools import RunContext class WeatherService: async def get_forecast(self, location: str, forecast_date: date) -> str: # In real code: call weather API, DB queries, etc. - return f"The forecast in {location} on {forecast_date} is 24°C and sunny." + return f'The forecast in {location} on {forecast_date} is 24°C and sunny.' async def get_historic_weather(self, location: str, forecast_date: date) -> str: # In real code: call a historical weather API or DB - return f"The weather in {location} on {forecast_date} was 18°C and partly cloudy." + return ( + f'The weather in {location} on {forecast_date} was 18°C and partly cloudy.' 
+ ) weather_agent = Agent[WeatherService, str]( - "openai:gpt-4o", + 'openai:gpt-4o', deps_type=WeatherService, result_type=str, # We'll produce a final answer as plain text - system_prompt="Providing a weather forecast at the locations the user provides.", + system_prompt='Providing a weather forecast at the locations the user provides.', ) @@ -277,53 +277,60 @@ async def weather_forecast( async def main(): # The user asks for tomorrow's weather in Paris - user_prompt = "What will the weather be like in Paris tomorrow?" - - # We'll capture raw messages for debugging - with capture_run_messages() as messages: - # Provide a WeatherService instance as the agent's dependencies - deps = WeatherService() - - # Begin a node-by-node, streaming iteration - with weather_agent.iter(user_prompt, deps=deps) as run: - async for node in run: - if is_model_request_node(node): - # A model request node => We can stream tokens from the model's request - print("=== ModelRequestNode: streaming partial request tokens ===") - async with node.stream(run.ctx) as request_stream: - async for event in request_stream: - if isinstance(event, PartStartEvent): - print(f"[Request] Starting part {event.index}: {event.part!r}") - elif isinstance(event, PartDeltaEvent): - if isinstance(event.delta, TextPartDelta): - print(f"[Request] Part {event.index} text delta: {event.delta.content_delta!r}") - elif isinstance(event.delta, ToolCallPartDelta): - print(f"[Request] Part {event.index} args_delta={event.delta.args_delta}") - elif isinstance(event, FinalResultEvent): - print(f"[Result] The model produced a final result (tool_name={event.tool_name})") - - elif is_handle_response_node(node): - # A handle-response node => The model returned some data, potentially calls a tool - print("=== HandleResponseNode: streaming partial response & tool usage ===") - async with node.stream(run.ctx) as handle_stream: - async for event in handle_stream: - if isinstance(event, FunctionToolCallEvent): - print(f"[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})") - elif isinstance(event, FunctionToolResultEvent): - print(f"[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}") - - # Once an End node is reached, the agent run is complete - assert run.result is not None - print("\n=== Final Agent Output ===") - print("Forecast:", run.result.data) - - # Show the raw messages exchanged - print("\n=== Raw Messages Captured ===") - for m in messages: - print(" -", m) - - -if __name__ == "__main__": + user_prompt = 'What will the weather be like in Paris tomorrow?' 
+ + # Provide a WeatherService instance as the agent's dependencies + deps = WeatherService() + + # Begin a node-by-node, streaming iteration + with weather_agent.iter(user_prompt, deps=deps) as run: + async for node in run: + if Agent.is_model_request_node(node): + # A model request node => We can stream tokens from the model's request + print('=== ModelRequestNode: streaming partial request tokens ===') + async with node.stream(run.ctx) as request_stream: + async for event in request_stream: + if isinstance(event, PartStartEvent): + print( + f'[Request] Starting part {event.index}: {event.part!r}' + ) + elif isinstance(event, PartDeltaEvent): + if isinstance(event.delta, TextPartDelta): + print( + f'[Request] Part {event.index} text delta: {event.delta.content_delta!r}' + ) + elif isinstance(event.delta, ToolCallPartDelta): + print( + f'[Request] Part {event.index} args_delta={event.delta.args_delta}' + ) + elif isinstance(event, FinalResultEvent): + print( + f'[Result] The model produced a final result (tool_name={event.tool_name})' + ) + + elif Agent.is_handle_response_node(node): + # A handle-response node => The model returned some data, potentially calls a tool + print( + '=== HandleResponseNode: streaming partial response & tool usage ===' + ) + async with node.stream(run.ctx) as handle_stream: + async for event in handle_stream: + if isinstance(event, FunctionToolCallEvent): + print( + f'[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})' + ) + elif isinstance(event, FunctionToolResultEvent): + print( + f'[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}' + ) + + # Once an End node is reached, the agent run is complete + assert run.result is not None + print('\n=== Final Agent Output ===') + print('Forecast:', run.result.data) + + +if __name__ == '__main__': asyncio.run(main()) """ @@ -380,12 +387,6 @@ if __name__ == "__main__": === Final Agent Output === Forecast: The weather forecast for Paris tomorrow, November 2, 2023, is expected to be 18°C and partly cloudy. 
-
-=== Raw Messages Captured ===
- - ModelRequest(parts=[SystemPromptPart(content='Providing a weather forecast at the locations the user provides.', dynamic_ref=None, part_kind='system-prompt'), UserPromptPart(content='What will the weather be like in Paris tomorrow?', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 4, 867863, tzinfo=datetime.timezone.utc), part_kind='user-prompt')], kind='request')
- - ModelResponse(parts=[ToolCallPart(tool_name='weather_forecast', args='{"location":"Paris","forecast_date":"2023-11-02"}', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', part_kind='tool-call')], model_name='gpt-4o', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 8, tzinfo=datetime.timezone.utc), kind='response')
- - ModelRequest(parts=[ToolReturnPart(tool_name='weather_forecast', content='The weather in Paris on 2023-11-02 was 18°C and partly cloudy.', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 9, 150432, tzinfo=datetime.timezone.utc), part_kind='tool-return')], kind='request')
- - ModelResponse(parts=[TextPart(content='The weather forecast for Paris tomorrow, November 2, 2023, is expected to be 18°C and partly cloudy.', part_kind='text')], model_name='gpt-4o', timestamp=datetime.datetime(2025, 2, 25, 7, 16, 9, tzinfo=datetime.timezone.utc), kind='response')
 """
 ```

diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 859cf2a81..d4a7ac14f 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -2,7 +2,6 @@

 import asyncio
 import dataclasses
-from abc import ABC
 from collections.abc import AsyncIterator, Iterator, Sequence
 from contextlib import asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -40,8 +39,6 @@
     'HandleResponseNode',
     'build_run_context',
     'capture_run_messages',
-    'is_model_request_node',
-    'is_handle_response_node',
 )

 _logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
@@ -57,6 +54,7 @@
     logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

 T = TypeVar('T')
+S = TypeVar('S')
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -109,8 +107,31 @@ class GraphAgentDeps(Generic[DepsT, ResultDataT]):
     run_span: logfire_api.LogfireSpan


+class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
+    """The base class for all agent nodes.
+
+    Using a subclass of `BaseNode` for all nodes reduces the amount of generics boilerplate everywhere.
+    """
+
+
+def is_agent_node(
+    node: BaseNode[GraphAgentState, GraphAgentDeps[T, Any], result.FinalResult[S]] | End[result.FinalResult[S]],
+) -> TypeGuard[AgentNode[T, S]]:
+    """Check if the provided node is an instance of `AgentNode`.
+
+    Usage:
+
+        if is_agent_node(node):
+            # `node` is an AgentNode
+            ...
+
+    This method preserves the generic parameters on the narrowed type, unlike `isinstance(node, AgentNode)`.
+    """
+    return isinstance(node, AgentNode)
+
+
 @dataclasses.dataclass
-class UserPromptNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], ABC):
+class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
     user_prompt: str | Sequence[_messages.UserContent]

     system_prompts: tuple[str, ...]
@@ -217,7 +238,7 @@ async def add_tool(tool: Tool[DepsT]) -> None: @dataclasses.dataclass -class ModelRequestNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): +class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]): """Make a request to the model using the last message in state.message_history.""" request: _messages.ModelRequest @@ -339,7 +360,7 @@ def _finish_handling( @dataclasses.dataclass -class HandleResponseNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): +class HandleResponseNode(AgentNode[DepsT, NodeRunEndT]): """Process a model response, and decide whether to end the run or make a new request.""" model_response: _messages.ModelResponse @@ -705,15 +726,3 @@ def build_agent_graph( auto_instrument=False, ) return graph - - -def is_model_request_node( - node: BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], -) -> TypeGuard[ModelRequestNode[DepsT, NodeRunEndT]]: - return isinstance(node, ModelRequestNode) - - -def is_handle_response_node( - node: BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]], -) -> TypeGuard[HandleResponseNode[DepsT, NodeRunEndT]]: - return isinstance(node, HandleResponseNode) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index f737daf71..62789ea29 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -7,12 +7,12 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from copy import deepcopy from types import FrameType -from typing import Any, Callable, Generic, cast, final, overload +from typing import Any, Callable, Generic, TypeGuard, cast, final, overload import logfire_api from typing_extensions import TypeVar, deprecated -from pydantic_graph import BaseNode, End, Graph, GraphRun, GraphRunContext +from pydantic_graph import End, Graph, GraphRun, GraphRunContext from . 
import ( _agent_graph, @@ -45,8 +45,6 @@ HandleResponseNode = _agent_graph.HandleResponseNode ModelRequestNode = _agent_graph.ModelRequestNode UserPromptNode = _agent_graph.UserPromptNode -is_handle_response_node = _agent_graph.is_handle_response_node -is_model_request_node = _agent_graph.is_model_request_node __all__ = ( 'Agent', @@ -57,8 +55,6 @@ 'HandleResponseNode', 'ModelRequestNode', 'UserPromptNode', - 'is_handle_response_node', - 'is_model_request_node', ) _logfire = logfire_api.Logfire(otel_scope='pydantic-ai') @@ -74,6 +70,7 @@ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) T = TypeVar('T') +S = TypeVar('S') NoneType = type(None) RunResultDataT = TypeVar('RunResultDataT') """Type variable for the result data of a run where `result_type` was customized on the run call.""" @@ -649,10 +646,9 @@ async def main(): ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node - node: BaseNode[Any, Any, Any] = cast(BaseNode[Any, Any, Any], first_node) + node = first_node while True: - if isinstance(node, _agent_graph.ModelRequestNode): - node = cast(_agent_graph.ModelRequestNode[AgentDepsT, Any], node) + if self.is_model_request_node(node): graph_ctx = agent_run.ctx async with node._stream(graph_ctx) as streamed_response: # pyright: ignore[reportPrivateUsage] @@ -720,9 +716,9 @@ async def on_complete() -> None: ) break next_node = await agent_run.next(node) - if not isinstance(next_node, BaseNode): + if not isinstance(next_node, _agent_graph.AgentNode): raise exceptions.AgentRunError('Should have produced a StreamedRunResult before getting here') - node = cast(BaseNode[Any, Any, Any], next_node) + node = cast(_agent_graph.AgentNode[Any, Any], next_node) if not yielded: raise exceptions.AgentRunError('Agent run finished without producing a final result') @@ -1176,6 +1172,46 @@ def _prepare_result_schema( else: return self._result_schema # pyright: ignore[reportReturnType] + @staticmethod + def is_model_request_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeGuard[_agent_graph.ModelRequestNode[T, S]]: + """Check if the node is a `ModelRequestNode`, narrowing the type if it is. + + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.ModelRequestNode) + + @staticmethod + def is_handle_response_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeGuard[_agent_graph.HandleResponseNode[T, S]]: + """Check if the node is a `HandleResponseNode`, narrowing the type if it is. + + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.HandleResponseNode) + + @staticmethod + def is_user_prompt_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeGuard[_agent_graph.UserPromptNode[T, S]]: + """Check if the node is a `UserPromptNode`, narrowing the type if it is. + + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, _agent_graph.UserPromptNode) + + @staticmethod + def is_end_node( + node: _agent_graph.AgentNode[T, S] | End[result.FinalResult[S]], + ) -> TypeGuard[End[result.FinalResult[S]]]: + """Check if the node is a `End`, narrowing the type if it is. 
+ + This method preserves the generic parameters while narrowing the type, unlike a direct call to `isinstance`. + """ + return isinstance(node, End) + @dataclasses.dataclass(repr=False) class AgentRun(Generic[AgentDepsT, ResultDataT]): @@ -1247,15 +1283,17 @@ def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.Grap @property def next_node( self, - ) -> ( - BaseNode[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[ResultDataT]] - | End[FinalResult[ResultDataT]] - ): + ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]: """The next node that will be run in the agent graph. This is the next node that will be used during async iteration, or if a node is not passed to `self.next(...)`. """ - return self._graph_run.next_node + next_node = self._graph_run.next_node + if isinstance(next_node, End): + return next_node + if _agent_graph.is_agent_node(next_node): + return next_node + raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') @property def result(self) -> AgentRunResult[ResultDataT] | None: @@ -1276,45 +1314,24 @@ def result(self) -> AgentRunResult[ResultDataT] | None: def __aiter__( self, - ) -> AsyncIterator[ - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ] - | End[FinalResult[ResultDataT]] - ]: + ) -> AsyncIterator[_agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]]: """Provide async-iteration over the nodes in the agent run.""" return self async def __anext__( self, - ) -> ( - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ] - | End[FinalResult[ResultDataT]] - ): + ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]: """Advance to the next node automatically based on the last returned node.""" - return await self._graph_run.__anext__() + next_node = await self._graph_run.__anext__() + if _agent_graph.is_agent_node(next_node): + return next_node + assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' + return next_node async def next( self, - node: BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ], - ) -> ( - BaseNode[ - _agent_graph.GraphAgentState, - _agent_graph.GraphAgentDeps[AgentDepsT, Any], - FinalResult[ResultDataT], - ] - | End[FinalResult[ResultDataT]] - ): + node: _agent_graph.AgentNode[AgentDepsT, ResultDataT], + ) -> _agent_graph.AgentNode[AgentDepsT, ResultDataT] | End[FinalResult[ResultDataT]]: """Manually drive the agent run by passing in the node you want to run next. This lets you inspect or mutate the node before continuing execution, or skip certain nodes @@ -1381,7 +1398,11 @@ async def main(): """ # Note: It might be nice to expose a synchronous interface for iteration, but we shouldn't do it # on this class, or else IDEs won't warn you if you accidentally use `for` instead of `async for` to iterate. 
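        # The underlying graph yields plain `BaseNode`s; `is_agent_node` narrows them
        # back to `AgentNode` so the generic deps/result parameters are preserved for callers.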
- return await self._graph_run.next(node) + next_node = await self._graph_run.next(node) + if _agent_graph.is_agent_node(next_node): + return next_node + assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}' + return next_node def usage(self) -> _usage.Usage: """Get usage statistics for the run so far, including token usage, model requests, and so on.""" From 181578bfc776faa90b8f12dc243b0dcea767e227 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Tue, 25 Feb 2025 22:30:41 -0800 Subject: [PATCH 4/8] Fix typeguard import --- pydantic_ai_slim/pydantic_ai/agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 62789ea29..4d0979e95 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -7,10 +7,10 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager from copy import deepcopy from types import FrameType -from typing import Any, Callable, Generic, TypeGuard, cast, final, overload +from typing import Any, Callable, Generic, cast, final, overload import logfire_api -from typing_extensions import TypeVar, deprecated +from typing_extensions import TypeGuard, TypeVar, deprecated from pydantic_graph import End, Graph, GraphRun, GraphRunContext From 8f383b405d2c9b029085d170f6e33a0c5c58d2a4 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 27 Feb 2025 10:59:21 -0700 Subject: [PATCH 5/8] Update test --- docs/agents.md | 110 ++++++++++++++--------------------------- tests/test_examples.py | 78 +++++++++++++++++++---------- 2 files changed, 91 insertions(+), 97 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 8c93dfae0..6686bae5c 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -224,7 +224,7 @@ Once the run finishes, `agent_run.final_result` becomes a [`AgentRunResult`][pyd Here is an example of streaming an agent run in combination with `async for` iteration: -```python {title="streaming.py"} {test="skip"} +```python {title="streaming.py"} import asyncio from dataclasses import dataclass from datetime import date @@ -275,119 +275,85 @@ async def weather_forecast( return await ctx.deps.get_historic_weather(location, forecast_date) -async def main(): - # The user asks for tomorrow's weather in Paris - user_prompt = 'What will the weather be like in Paris tomorrow?' +output_messages: list[str] = [] + - # Provide a WeatherService instance as the agent's dependencies - deps = WeatherService() +async def main(): + user_prompt = 'What will the weather be like in Paris on Tuesday?' 
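    # For this prompt, iterating the run is expected to yield a UserPromptNode, a
    # ModelRequestNode/HandleResponseNode pair for the tool call, another pair for
    # the final text response, and finally an End node.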
# Begin a node-by-node, streaming iteration - with weather_agent.iter(user_prompt, deps=deps) as run: + with weather_agent.iter(user_prompt, deps=WeatherService()) as run: async for node in run: if Agent.is_model_request_node(node): # A model request node => We can stream tokens from the model's request - print('=== ModelRequestNode: streaming partial request tokens ===') + output_messages.append( + '=== ModelRequestNode: streaming partial request tokens ===' + ) async with node.stream(run.ctx) as request_stream: async for event in request_stream: if isinstance(event, PartStartEvent): - print( + output_messages.append( f'[Request] Starting part {event.index}: {event.part!r}' ) elif isinstance(event, PartDeltaEvent): if isinstance(event.delta, TextPartDelta): - print( + output_messages.append( f'[Request] Part {event.index} text delta: {event.delta.content_delta!r}' ) elif isinstance(event.delta, ToolCallPartDelta): - print( + output_messages.append( f'[Request] Part {event.index} args_delta={event.delta.args_delta}' ) elif isinstance(event, FinalResultEvent): - print( + output_messages.append( f'[Result] The model produced a final result (tool_name={event.tool_name})' ) elif Agent.is_handle_response_node(node): # A handle-response node => The model returned some data, potentially calls a tool - print( + output_messages.append( '=== HandleResponseNode: streaming partial response & tool usage ===' ) async with node.stream(run.ctx) as handle_stream: async for event in handle_stream: if isinstance(event, FunctionToolCallEvent): - print( + output_messages.append( f'[Tools] The LLM calls tool={event.part.tool_name!r} with args={event.part.args} (tool_call_id={event.part.tool_call_id!r})' ) elif isinstance(event, FunctionToolResultEvent): - print( + output_messages.append( f'[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}' ) # Once an End node is reached, the agent run is complete assert run.result is not None - print('\n=== Final Agent Output ===') - print('Forecast:', run.result.data) + output_messages.append(f'=== Final Agent Output: {run.result.data} ===') if __name__ == '__main__': asyncio.run(main()) -""" -=== ModelRequestNode: streaming partial request tokens === -[Request] Starting part 0: ToolCallPart(tool_name='weather_forecast', args='', tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R', part_kind='tool-call') -[Request] Part 0 args_delta={" -[Request] Part 0 args_delta=location -[Request] Part 0 args_delta=":" -[Request] Part 0 args_delta=Paris -[Request] Part 0 args_delta="," -[Request] Part 0 args_delta=forecast -[Request] Part 0 args_delta=_date -[Request] Part 0 args_delta=":" -[Request] Part 0 args_delta=202 -[Request] Part 0 args_delta=3 -[Request] Part 0 args_delta=- -[Request] Part 0 args_delta=11 -[Request] Part 0 args_delta=- -[Request] Part 0 args_delta=02 -[Request] Part 0 args_delta="} -=== HandleResponseNode: streaming partial response & tool usage === -[Tools] The LLM calls tool='weather_forecast' with args={"location":"Paris","forecast_date":"2023-11-02"} (tool_call_id='call_Q0QqiZfIhHyNViiLG7jT0G9R') -[Tools] Tool call 'call_Q0QqiZfIhHyNViiLG7jT0G9R' returned => The weather in Paris on 2023-11-02 was 18°C and partly cloudy. 
-=== ModelRequestNode: streaming partial request tokens === -[Request] Starting part 0: TextPart(content='', part_kind='text') -[Result] The model produced a final result (tool_name=None) -[Request] Part 0 text delta: 'The' -[Request] Part 0 text delta: ' weather' -[Request] Part 0 text delta: ' forecast' -[Request] Part 0 text delta: ' for' -[Request] Part 0 text delta: ' Paris' -[Request] Part 0 text delta: ' tomorrow' -[Request] Part 0 text delta: ',' -[Request] Part 0 text delta: ' November' -[Request] Part 0 text delta: ' ' -[Request] Part 0 text delta: '2' -[Request] Part 0 text delta: ',' -[Request] Part 0 text delta: ' ' -[Request] Part 0 text delta: '202' -[Request] Part 0 text delta: '3' -[Request] Part 0 text delta: ',' -[Request] Part 0 text delta: ' is' -[Request] Part 0 text delta: ' expected' -[Request] Part 0 text delta: ' to' -[Request] Part 0 text delta: ' be' -[Request] Part 0 text delta: ' ' -[Request] Part 0 text delta: '18' -[Request] Part 0 text delta: '°C' -[Request] Part 0 text delta: ' and' -[Request] Part 0 text delta: ' partly' -[Request] Part 0 text delta: ' cloudy' -[Request] Part 0 text delta: '.' -=== HandleResponseNode: streaming partial response & tool usage === - -=== Final Agent Output === -Forecast: The weather forecast for Paris tomorrow, November 2, 2023, is expected to be 18°C and partly cloudy. -""" + print(output_messages) + """ + [ + '=== ModelRequestNode: streaming partial request tokens ===', + '[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=None, part_kind=\'tool-call\')', + '[Request] Part 0 args_delta=ris","forecast_', + '[Request] Part 0 args_delta=date":"2030-01-', + '[Request] Part 0 args_delta=01"}', + '=== HandleResponseNode: streaming partial response & tool usage ===', + '[Tools] The LLM calls tool=\'weather_forecast\' with args={"location":"Paris","forecast_date":"2030-01-01"} (tool_call_id=None)', + "[Tools] Tool call 'ea89591f-d193-435f-9640-f095f165bb89' returned => The forecast in Paris on 2030-01-01 is 24°C and sunny.", + '=== ModelRequestNode: streaming partial request tokens ===', + "[Request] Starting part 0: TextPart(content='It will be ', part_kind='text')", + '[Result] The model produced a final result (tool_name=None)', + "[Request] Part 0 text delta: 'warm and sunny '", + "[Request] Part 0 text delta: 'in Paris on '", + "[Request] Part 0 text delta: 'Tuesday.'", + '=== HandleResponseNode: streaming partial response & tool usage ===', + '=== Final Agent Output: It will be warm and sunny in Paris on Tuesday. ===', + ] + """ ``` --- diff --git a/tests/test_examples.py b/tests/test_examples.py index 6974d319a..6715890ff 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -182,6 +182,9 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: 'What is the weather like in West London and in Wiltshire?': ( 'The weather in West London is raining, while in Wiltshire it is sunny.' ), + 'What will the weather be like in Paris on Tuesday?': ToolCallPart( + tool_name='weather_forecast', args={'location': 'Paris', 'forecast_date': '2030-01-01'} + ), 'Tell me a joke.': 'Did you hear about the toothpaste scandal? 
They called it Colgate.', 'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.', 'What is the capital of France?': 'Paris', @@ -270,6 +273,13 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: ), } +tool_responses: dict[tuple[str, str], str] = { + ( + 'weather_forecast', + 'The forecast in Paris on 2030-01-01 is 24°C and sunny.', + ): 'It will be warm and sunny in Paris on Tuesday.', +} + async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: # pragma: no cover # noqa: C901 m = messages[-1].parts[-1] @@ -348,35 +358,53 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes raise RuntimeError(f'Unexpected message: {m}') -async def stream_model_logic( +async def stream_model_logic( # noqa C901 messages: list[ModelMessage], info: AgentInfo ) -> AsyncIterator[str | DeltaToolCalls]: # pragma: no cover - m = messages[-1].parts[-1] - if isinstance(m, UserPromptPart): - assert isinstance(m.content, str) - if response := text_responses.get(m.content): - if isinstance(response, str): - words = response.split(' ') - chunk: list[str] = [] - for work in words: - chunk.append(work) - if len(chunk) == 3: - yield ' '.join(chunk) + ' ' - chunk.clear() - if chunk: - yield ' '.join(chunk) - return - else: - json_text = response.args_as_json_str() - - yield {1: DeltaToolCall(name=response.tool_name)} - for chunk_index in range(0, len(json_text), 15): - text_chunk = json_text[chunk_index : chunk_index + 15] - yield {1: DeltaToolCall(json_args=text_chunk)} - return + async def stream_text_response(r: str) -> AsyncIterator[str]: + if isinstance(r, str): + words = r.split(' ') + chunk: list[str] = [] + for word in words: + chunk.append(word) + if len(chunk) == 3: + yield ' '.join(chunk) + ' ' + chunk.clear() + if chunk: + yield ' '.join(chunk) + + async def stream_tool_call_response(r: ToolCallPart) -> AsyncIterator[DeltaToolCalls]: + json_text = r.args_as_json_str() + + yield {1: DeltaToolCall(name=r.tool_name)} + for chunk_index in range(0, len(json_text), 15): + text_chunk = json_text[chunk_index : chunk_index + 15] + yield {1: DeltaToolCall(json_args=text_chunk)} + + async def stream_part_response(r: str | ToolCallPart) -> AsyncIterator[str | DeltaToolCalls]: + if isinstance(r, str): + async for chunk in stream_text_response(r): + yield chunk + else: + async for chunk in stream_tool_call_response(r): + yield chunk + + last_part = messages[-1].parts[-1] + if isinstance(last_part, UserPromptPart): + assert isinstance(last_part.content, str) + if response := text_responses.get(last_part.content): + async for chunk in stream_part_response(response): + yield chunk + return + elif isinstance(last_part, ToolReturnPart): + assert isinstance(last_part.content, str) + if response := tool_responses.get((last_part.tool_name, last_part.content)): + async for chunk in stream_part_response(response): + yield chunk + return sys.stdout.write(str(debug.format(messages, info))) - raise RuntimeError(f'Unexpected message: {m}') + raise RuntimeError(f'Unexpected message: {last_part}') def mock_infer_model(model: Model | KnownModelName) -> Model: From 6e0e0c797e3f622c902e3f57d6bc05a6fb612cfb Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 27 Feb 2025 11:36:45 -0700 Subject: [PATCH 6/8] Fix test --- docs/agents.md | 6 +++--- pydantic_ai_slim/pydantic_ai/models/function.py | 4 +++- tests/test_examples.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) 
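Note: this patch adds a `tool_call_id` field to `DeltaToolCall` so streamed tool calls
carry a stable ID from the first chunk onward. A minimal sketch of a `FunctionModel`
stream function using the new field (the tool name, ID, and argument chunks here are
illustrative, not part of the patch):

```python
from collections.abc import AsyncIterator

from pydantic_ai.messages import ModelMessage
from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls


async def stream_forecast_call(
    messages: list[ModelMessage], info: AgentInfo
) -> AsyncIterator[DeltaToolCalls]:
    # The first chunk names the tool and pins the ID used to match the eventual result.
    yield {1: DeltaToolCall(name='weather_forecast', tool_call_id='0001')}
    # Later chunks stream the JSON arguments incrementally under the same part index.
    yield {1: DeltaToolCall(json_args='{"location":"Pa')}
    yield {1: DeltaToolCall(json_args='ris"}')}
```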
diff --git a/docs/agents.md b/docs/agents.md index 6686bae5c..42469a4bc 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -337,13 +337,13 @@ if __name__ == '__main__': """ [ '=== ModelRequestNode: streaming partial request tokens ===', - '[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=None, part_kind=\'tool-call\')', + '[Request] Starting part 0: ToolCallPart(tool_name=\'weather_forecast\', args=\'{"location":"Pa\', tool_call_id=\'0001\', part_kind=\'tool-call\')', '[Request] Part 0 args_delta=ris","forecast_', '[Request] Part 0 args_delta=date":"2030-01-', '[Request] Part 0 args_delta=01"}', '=== HandleResponseNode: streaming partial response & tool usage ===', - '[Tools] The LLM calls tool=\'weather_forecast\' with args={"location":"Paris","forecast_date":"2030-01-01"} (tool_call_id=None)', - "[Tools] Tool call 'ea89591f-d193-435f-9640-f095f165bb89' returned => The forecast in Paris on 2030-01-01 is 24°C and sunny.", + '[Tools] The LLM calls tool=\'weather_forecast\' with args={"location":"Paris","forecast_date":"2030-01-01"} (tool_call_id=\'0001\')', + "[Tools] Tool call '0001' returned => The forecast in Paris on 2030-01-01 is 24°C and sunny.", '=== ModelRequestNode: streaming partial request tokens ===', "[Request] Starting part 0: TextPart(content='It will be ', part_kind='text')", '[Result] The model produced a final result (tool_name=None)', diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 084f18df8..1d6d25085 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -177,6 +177,8 @@ class DeltaToolCall: """Incremental change to the name of the tool.""" json_args: str | None = None """Incremental change to the arguments as JSON""" + tool_call_id: str | None = None + """Incremental change to the tool call ID.""" DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall] @@ -224,7 +226,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=dtc_index, tool_name=delta_tool_call.name, args=delta_tool_call.json_args, - tool_call_id=None, + tool_call_id=delta_tool_call.tool_call_id, ) if maybe_event is not None: yield maybe_event diff --git a/tests/test_examples.py b/tests/test_examples.py index 6715890ff..3499a1918 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -183,7 +183,7 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: 'The weather in West London is raining, while in Wiltshire it is sunny.' ), 'What will the weather be like in Paris on Tuesday?': ToolCallPart( - tool_name='weather_forecast', args={'location': 'Paris', 'forecast_date': '2030-01-01'} + tool_name='weather_forecast', args={'location': 'Paris', 'forecast_date': '2030-01-01'}, tool_call_id='0001' ), 'Tell me a joke.': 'Did you hear about the toothpaste scandal? 
They called it Colgate.', 'Explain?': 'This is an excellent joke invented by Samuel Colvin, it needs no explanation.', @@ -376,7 +376,7 @@ async def stream_text_response(r: str) -> AsyncIterator[str]: async def stream_tool_call_response(r: ToolCallPart) -> AsyncIterator[DeltaToolCalls]: json_text = r.args_as_json_str() - yield {1: DeltaToolCall(name=r.tool_name)} + yield {1: DeltaToolCall(name=r.tool_name, tool_call_id=r.tool_call_id)} for chunk_index in range(0, len(json_text), 15): text_chunk = json_text[chunk_index : chunk_index + 15] yield {1: DeltaToolCall(json_args=text_chunk)} From 09b4319299aefd4b3221f6b8446546207a7e93f3 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 27 Feb 2025 15:45:16 -0700 Subject: [PATCH 7/8] Update tests a bit --- docs/agents.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 42469a4bc..26e041ac9 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -284,7 +284,10 @@ async def main(): # Begin a node-by-node, streaming iteration with weather_agent.iter(user_prompt, deps=WeatherService()) as run: async for node in run: - if Agent.is_model_request_node(node): + if Agent.is_user_prompt_node(node): + # A user prompt node => The user has provided input + output_messages.append(f'=== UserPromptNode: {node.user_prompt} ===') + elif Agent.is_model_request_node(node): # A model request node => We can stream tokens from the model's request output_messages.append( '=== ModelRequestNode: streaming partial request tokens ===' @@ -308,7 +311,6 @@ async def main(): output_messages.append( f'[Result] The model produced a final result (tool_name={event.tool_name})' ) - elif Agent.is_handle_response_node(node): # A handle-response node => The model returned some data, potentially calls a tool output_messages.append( @@ -324,10 +326,10 @@ async def main(): output_messages.append( f'[Tools] Tool call {event.tool_call_id!r} returned => {event.result.content}' ) - - # Once an End node is reached, the agent run is complete - assert run.result is not None - output_messages.append(f'=== Final Agent Output: {run.result.data} ===') + elif Agent.is_end_node(node): + assert run.result.data == node.data.data + # Once an End node is reached, the agent run is complete + output_messages.append(f'=== Final Agent Output: {run.result.data} ===') if __name__ == '__main__': From e73b85bca2b456dc64dbf4470a793b13bfe27491 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Thu, 27 Feb 2025 16:16:07 -0700 Subject: [PATCH 8/8] Update coverage --- pydantic_ai_slim/pydantic_ai/agent.py | 2 +- tests/test_streaming.py | 113 +++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 54c242683..041aeb99d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1293,7 +1293,7 @@ def next_node( return next_node if _agent_graph.is_agent_node(next_node): return next_node - raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') + raise exceptions.AgentRunError(f'Unexpected node type: {type(next_node)}') # pragma: no cover @property def result(self) -> AgentRunResult[ResultDataT] | None: diff --git a/tests/test_streaming.py b/tests/test_streaming.py index b4612e6c8..92bcc4a73 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -2,14 +2,18 @@ import 
datetime import json +import re from collections.abc import AsyncIterator +from copy import deepcopy from datetime import timezone +from typing import Union import pytest from inline_snapshot import snapshot from pydantic import BaseModel from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( ModelMessage, ModelRequest, @@ -22,7 +26,8 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import Usage +from pydantic_ai.result import AgentStream, FinalResult, Usage +from pydantic_graph import End from .conftest import IsNow @@ -739,3 +744,109 @@ async def test_custom_result_type_default_structured() -> None: async with agent.run_stream('test', result_type=str) as result: response = await result.get_data() assert response == snapshot('success (no tool calls)') + + +async def test_iter_stream_output(): + m = TestModel(custom_result_text='The cat sat on the mat.') + + agent = Agent(m) + + @agent.result_validator + def result_validator_simple(data: str) -> str: + # Make a substitution in the validated results + return re.sub('cat sat', 'bat sat', data) + + run: AgentRun + stream: AgentStream + messages: list[str] = [] + + stream_usage: Usage | None = None + with agent.iter('Hello') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for chunk in stream.stream_output(debounce_by=None): + messages.append(chunk) + stream_usage = deepcopy(stream.usage()) + assert run.next_node == End(data=FinalResult(data='The bat sat on the mat.', tool_name=None)) + assert ( + run.usage() + == stream_usage + == Usage(requests=1, request_tokens=51, response_tokens=7, total_tokens=58, details=None) + ) + + assert messages == [ + '', + 'The ', + 'The cat ', + 'The bat sat ', + 'The bat sat on ', + 'The bat sat on the ', + 'The bat sat on the mat.', + 'The bat sat on the mat.', + ] + + +async def test_iter_stream_responses(): + m = TestModel(custom_result_text='The cat sat on the mat.') + + agent = Agent(m) + + @agent.result_validator + def result_validator_simple(data: str) -> str: + # Make a substitution in the validated results + return re.sub('cat sat', 'bat sat', data) + + run: AgentRun + stream: AgentStream + messages: list[ModelResponse] = [] + with agent.iter('Hello') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for chunk in stream.stream_responses(debounce_by=None): + messages.append(chunk) + + assert messages == [ + ModelResponse( + parts=[TextPart(content=text, part_kind='text')], + model_name='test', + timestamp=IsNow(tz=timezone.utc), + kind='response', + ) + for text in [ + '', + '', + 'The ', + 'The cat ', + 'The cat sat ', + 'The cat sat on ', + 'The cat sat on the ', + 'The cat sat on the mat.', + ] + ] + + # Note: as you can see above, the result validator is not applied to the streamed responses, just the final result: + assert run.result is not None + assert run.result.data == 'The bat sat on the mat.' 
+ + +async def test_stream_iter_structured_validator() -> None: + class NotResultType(BaseModel): + not_value: str + + agent = Agent[None, Union[ResultType, NotResultType]]('test', result_type=Union[ResultType, NotResultType]) # pyright: ignore[reportArgumentType] + + @agent.result_validator + def result_validator(data: ResultType | NotResultType) -> ResultType | NotResultType: + assert isinstance(data, ResultType) + return ResultType(value=data.value + ' (validated)') + + outputs: list[ResultType] = [] + with agent.iter('test') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for output in stream.stream_output(debounce_by=None): + outputs.append(output) + assert outputs == [ResultType(value='a (validated)'), ResultType(value='a (validated)')]