diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index d7a54c5c71..10d66dfbbc 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -12,6 +12,7 @@ from dataclasses import field, replace from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast +import anyio from opentelemetry.trace import Tracer from typing_extensions import TypeVar, assert_never @@ -663,8 +664,7 @@ async def _handle_tool_calls( yield event if output_final_result: - final_result = output_final_result[0] - self._next_node = self._handle_final_result(ctx, final_result, output_parts) + self._next_node = self._handle_final_result(ctx, output_final_result[0], output_parts) else: instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( @@ -882,7 +882,7 @@ async def process_tool_calls( # noqa: C901 output_final_result.append(final_result) -async def _call_tools( +async def _call_tools( # noqa: C901 tool_manager: ToolManager[DepsT], tool_calls: list[_messages.ToolCallPart], tool_call_results: dict[str, DeferredToolResult], @@ -940,30 +940,45 @@ async def handle_call_or_result( return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content) - if tool_manager.should_call_sequentially(tool_calls): - for index, call in enumerate(tool_calls): - if event := await handle_call_or_result( - _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), - index, - ): - yield event + send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]() + + async def _run_tools(): + async with send_stream: + assert tool_manager.ctx is not None, 'ToolManager.ctx needs to be set' + tool_manager.ctx.event_stream = send_stream + + if tool_manager.should_call_sequentially(tool_calls): + for index, call in enumerate(tool_calls): + if event := await handle_call_or_result( + _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), + index, + ): + await send_stream.send(event) + + else: + tasks = [ + asyncio.create_task( + _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), + name=call.tool_name, + ) + for call in tool_calls + ] - else: - tasks = [ - asyncio.create_task( - _call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)), - name=call.tool_name, - ) - for call in tool_calls - ] + pending = tasks + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.index(task) + if event := await handle_call_or_result(coro_or_task=task, index=index): + await send_stream.send(event) + + task = asyncio.create_task(_run_tools()) + + async with receive_stream: + async for message in receive_stream: + yield message - pending = tasks - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - index = tasks.index(task) - if event := await handle_call_or_result(coro_or_task=task, index=index): - yield event + await task # We append the results at the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index 2b8270f322..46060dbd1c 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -19,9 +19,17 @@ from pydantic_core import SchemaValidator, core_schema from typing_extensions import ParamSpec, TypeIs, TypeVar +from pydantic_ai.messages import CustomEvent, ToolReturn + from ._griffe import doc_descriptions from ._run_context import RunContext -from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor +from ._utils import ( + check_object_json_schema, + is_async_callable, + is_async_iterator_callable, + is_model_like, + run_in_executor, +) if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema @@ -41,13 +49,31 @@ class FunctionSchema: # if not None, the function takes a single by that name (besides potentially `info`) takes_ctx: bool is_async: bool + is_async_iterator: bool single_arg_name: str | None = None positional_fields: list[str] = field(default_factory=list) var_positional_field: str | None = None async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any: args, kwargs = self._call_args(args_dict, ctx) - if self.is_async: + if self.is_async_iterator: + assert ctx.event_stream is not None, ( + 'RunContext.event_stream needs to be set to use FunctionSchema.call with async iterators' + ) + + async for event_payload in self.function(*args, **kwargs): + if isinstance(event_payload, ToolReturn): + return event_payload + + event = ( + cast(CustomEvent, event_payload) + if isinstance(event_payload, CustomEvent) + else CustomEvent(payload=event_payload) + ) + await ctx.event_stream.send(event) + # TODO (DouweM): Raise if events are yielded after ToolReturn + return None + elif self.is_async: function = cast(Callable[[Any], Awaitable[str]], self.function) return await function(*args, **kwargs) else: @@ -221,6 +247,7 @@ def function_schema( # noqa: C901 var_positional_field=var_positional_field, takes_ctx=takes_ctx, is_async=is_async_callable(function), + is_async_iterator=is_async_iterator_callable(function), function=function, ) diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py index df2a4c1b5a..9cdd1d793a 100644 --- a/pydantic_ai_slim/pydantic_ai/_run_context.py +++ b/pydantic_ai_slim/pydantic_ai/_run_context.py @@ -5,6 +5,7 @@ from dataclasses import field from typing import TYPE_CHECKING, Generic +from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import NoOpTracer, Tracer from typing_extensions import TypeVar @@ -36,6 +37,8 @@ class RunContext(Generic[AgentDepsT]): """Messages exchanged in the conversation so far.""" tracer: Tracer = field(default_factory=NoOpTracer) """The tracer to use for tracing the run.""" + event_stream: MemoryObjectSendStream[_messages.CustomEvent] | None = None + """The event stream to use for handling custom events.""" trace_include_content: bool = False """Whether to include the content of the messages in the trace.""" instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index a5546a4e01..b47e4ca6e6 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -50,6 +50,10 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe """Build a new tool manager for the next run step, carrying over the retries from the current run step.""" if self.ctx is not None: if ctx.run_step == self.ctx.run_step: + # TODO (DouweM): Refactor to make sure it's always set + + if ctx.event_stream and not self.ctx.event_stream: + self.ctx.event_stream = ctx.event_stream return self retries = { diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 6fc8a080ec..5087193eef 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -366,7 +366,12 @@ def is_async_callable(obj: Any) -> Any: while isinstance(obj, functools.partial): obj = obj.func - return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore + return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess] + + +def is_async_iterator_callable(obj: Any) -> bool: + """Check if a callable is an async iterator.""" + return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess] def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None: diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index fe0ed77951..5f33379275 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -32,6 +32,7 @@ BaseToolCallPart, BuiltinToolCallPart, BuiltinToolReturnPart, + CustomEvent, FunctionToolResultEvent, ModelMessage, ModelRequest, @@ -431,6 +432,8 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve if isinstance(event, FunctionToolResultEvent): async for msg in _handle_tool_result_event(stream_ctx, event): yield msg + elif isinstance(event, CustomEvent) and isinstance(event.payload, BaseEvent): + yield event.payload async def _handle_model_request_event( # noqa: C901 @@ -582,6 +585,8 @@ async def _handle_tool_result_event( content=result.model_response_str(), ) + # TODO (DouweM): Stream `event.content` as if they were user parts? + # Now check for AG-UI events returned by the tool calls. possible_event = result.metadata or result.content if isinstance(possible_event, BaseEvent): diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py index 4cad787b11..f58a3b526c 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py @@ -71,6 +71,7 @@ def __init__( async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult: name = params.name + # TODO (DouweM): RunContext.event_stream -> call event_stream_handler directly? ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps) try: tool = (await toolset.get_tools(ctx))[name] diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 4c8bcd48b1..9dcfbb02d9 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -7,13 +7,13 @@ from dataclasses import KW_ONLY, dataclass, field, replace from datetime import datetime from mimetypes import guess_type -from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload +from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeAlias, cast, overload import pydantic import pydantic_core from genai_prices import calc_price, types as genai_types from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import Self, deprecated +from typing_extensions import Self, TypeVar, deprecated from . import _otel_messages, _utils from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc @@ -23,6 +23,8 @@ if TYPE_CHECKING: from .models.instrumented import InstrumentationSettings +EventPayloadT = TypeVar('EventPayloadT', default=Any) + AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac'] ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp'] @@ -1724,9 +1726,31 @@ class BuiltinToolResultEvent: """Event type identifier, used as a discriminator.""" +@dataclass(repr=False) +class CustomEvent(Generic[EventPayloadT]): + """An event indicating the result of a function tool call.""" + + payload: EventPayloadT + """The payload of the custom event.""" + + _: KW_ONLY + + name: str | None = None + """The optional name of the custom event.""" + + id: str | None = None + """The optional ID of the custom event.""" + + event_kind: Literal['custom'] = 'custom' + """Event type identifier, used as a discriminator.""" + + __repr__ = _utils.dataclasses_no_defaults_repr + + HandleResponseEvent = Annotated[ FunctionToolCallEvent | FunctionToolResultEvent + | CustomEvent | BuiltinToolCallEvent # pyright: ignore[reportDeprecated] | BuiltinToolResultEvent, # pyright: ignore[reportDeprecated] pydantic.Discriminator('event_kind'), diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 844e99a25e..d73f3f1375 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -391,6 +391,7 @@ def from_schema( json_schema=json_schema, takes_ctx=takes_ctx, is_async=_utils.is_async_callable(function), + is_async_iterator=_utils.is_async_iterator_callable(function), ) return cls( diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index fcd0fea9c5..43f9993bd6 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -192,6 +192,20 @@ async def send_custom() -> ToolReturn: ) +async def yield_custom() -> AsyncIterator[CustomEvent | ToolReturn]: + yield CustomEvent( + type=EventType.CUSTOM, + name='custom_event1', + value={'key1': 'value1'}, + ) + yield CustomEvent( + type=EventType.CUSTOM, + name='custom_event2', + value={'key2': 'value2'}, + ) + yield ToolReturn('Done') + + def uuid_str() -> str: """Generate a random UUID string.""" return uuid.uuid4().hex @@ -815,6 +829,73 @@ async def stream_function( ) +async def test_tool_local_yield_events() -> None: + """Test local tool call that yields multiple events.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First call - make a tool call + yield {0: DeltaToolCall(name='yield_custom')} + yield {0: DeltaToolCall(json_args='{}')} + else: + # Second call - return text result + yield 'success yield_custom called' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + tools=[yield_custom], + ) + + run_input = create_input( + UserMessage( + id='msg_1', + content='Please call yield_custom', + ), + ) + events = await run_and_collect_events(agent, run_input) + + assert events == snapshot( + [ + { + 'type': 'RUN_STARTED', + 'threadId': (thread_id := IsSameStr()), + 'runId': (run_id := IsSameStr()), + }, + { + 'type': 'TOOL_CALL_START', + 'toolCallId': (tool_call_id := IsSameStr()), + 'toolCallName': 'yield_custom', + 'parentMessageId': IsStr(), + }, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, 'delta': '{}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id}, + {'type': 'CUSTOM', 'name': 'custom_event1', 'value': {'key1': 'value1'}}, + {'type': 'CUSTOM', 'name': 'custom_event2', 'value': {'key2': 'value2'}}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': IsStr(), + 'toolCallId': tool_call_id, + 'content': 'Done', + 'role': 'tool', + }, + {'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'}, + { + 'type': 'TEXT_MESSAGE_CONTENT', + 'messageId': message_id, + 'delta': 'success yield_custom called', + }, + {'type': 'TEXT_MESSAGE_END', 'messageId': message_id}, + { + 'type': 'RUN_FINISHED', + 'threadId': thread_id, + 'runId': run_id, + }, + ] + ) + + async def test_tool_local_parts() -> None: """Test local tool call with streaming/parts.""" diff --git a/tests/test_dbos.py b/tests/test_dbos.py index a9aac4b961..d6300b8f5b 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -265,7 +265,6 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D 'complex_agent__model.request_stream', 'event_stream_handler', 'event_stream_handler', - 'event_stream_handler', 'complex_agent__mcp_server__mcp.call_tool', 'event_stream_handler', 'complex_agent__mcp_server__mcp.get_tools', @@ -354,16 +353,9 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D content='running 2 tools', children=[ BasicSpan(content='running tool: get_country'), + BasicSpan(content='ctx.run_step=1'), BasicSpan( - content='event_stream_handler', - children=[ - BasicSpan(content='ctx.run_step=1'), - BasicSpan( - content=IsStr( - regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' - ) - ), - ], + content='{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":"2025-10-08T14:38:30.370338+00:00","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}' ), BasicSpan( content='running tool: get_product_name',