Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 40 additions & 25 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dataclasses import field, replace
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, cast

import anyio
from opentelemetry.trace import Tracer
from typing_extensions import TypeVar, assert_never

Expand Down Expand Up @@ -663,8 +664,7 @@ async def _handle_tool_calls(
yield event

if output_final_result:
final_result = output_final_result[0]
self._next_node = self._handle_final_result(ctx, final_result, output_parts)
self._next_node = self._handle_final_result(ctx, output_final_result[0], output_parts)
else:
instructions = await ctx.deps.get_instructions(run_context)
self._next_node = ModelRequestNode[DepsT, NodeRunEndT](
Expand Down Expand Up @@ -882,7 +882,7 @@ async def process_tool_calls( # noqa: C901
output_final_result.append(final_result)


async def _call_tools(
async def _call_tools( # noqa: C901
tool_manager: ToolManager[DepsT],
tool_calls: list[_messages.ToolCallPart],
tool_call_results: dict[str, DeferredToolResult],
Expand Down Expand Up @@ -940,30 +940,45 @@ async def handle_call_or_result(

return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)

if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
if event := await handle_call_or_result(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
index,
):
yield event
send_stream, receive_stream = anyio.create_memory_object_stream[_messages.HandleResponseEvent]()

async def _run_tools():
async with send_stream:
assert tool_manager.ctx is not None, 'ToolManager.ctx needs to be set'
tool_manager.ctx.event_stream = send_stream

if tool_manager.should_call_sequentially(tool_calls):
for index, call in enumerate(tool_calls):
if event := await handle_call_or_result(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
index,
):
await send_stream.send(event)

else:
tasks = [
asyncio.create_task(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
name=call.tool_name,
)
for call in tool_calls
]

else:
tasks = [
asyncio.create_task(
_call_tool(tool_manager, call, tool_call_results.get(call.tool_call_id)),
name=call.tool_name,
)
for call in tool_calls
]
pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
if event := await handle_call_or_result(coro_or_task=task, index=index):
await send_stream.send(event)

task = asyncio.create_task(_run_tools())

async with receive_stream:
async for message in receive_stream:
yield message

pending = tasks
while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for task in done:
index = tasks.index(task)
if event := await handle_call_or_result(coro_or_task=task, index=index):
yield event
await task

# We append the results at the end, rather than as they are received, to retain a consistent ordering
# This is mostly just to simplify testing
Expand Down
31 changes: 29 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@
from pydantic_core import SchemaValidator, core_schema
from typing_extensions import ParamSpec, TypeIs, TypeVar

from pydantic_ai.messages import CustomEvent, ToolReturn

from ._griffe import doc_descriptions
from ._run_context import RunContext
from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor
from ._utils import (
check_object_json_schema,
is_async_callable,
is_async_iterator_callable,
is_model_like,
run_in_executor,
)

if TYPE_CHECKING:
from .tools import DocstringFormat, ObjectJsonSchema
Expand All @@ -41,13 +49,31 @@ class FunctionSchema:
# if not None, the function takes a single by that name (besides potentially `info`)
takes_ctx: bool
is_async: bool
is_async_iterator: bool
single_arg_name: str | None = None
positional_fields: list[str] = field(default_factory=list)
var_positional_field: str | None = None

async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any:
args, kwargs = self._call_args(args_dict, ctx)
if self.is_async:
if self.is_async_iterator:
assert ctx.event_stream is not None, (
'RunContext.event_stream needs to be set to use FunctionSchema.call with async iterators'
)

async for event_payload in self.function(*args, **kwargs):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure we handle exceptions correctly

if isinstance(event_payload, ToolReturn):
return event_payload

event = (
cast(CustomEvent, event_payload)
if isinstance(event_payload, CustomEvent)
else CustomEvent(payload=event_payload)
)
await ctx.event_stream.send(event)
# TODO (DouweM): Raise if events are yielded after ToolReturn
return None
elif self.is_async:
function = cast(Callable[[Any], Awaitable[str]], self.function)
return await function(*args, **kwargs)
else:
Expand Down Expand Up @@ -221,6 +247,7 @@ def function_schema( # noqa: C901
var_positional_field=var_positional_field,
takes_ctx=takes_ctx,
is_async=is_async_callable(function),
is_async_iterator=is_async_iterator_callable(function),
function=function,
)

Expand Down
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dataclasses import field
from typing import TYPE_CHECKING, Generic

from anyio.streams.memory import MemoryObjectSendStream
from opentelemetry.trace import NoOpTracer, Tracer
from typing_extensions import TypeVar

Expand Down Expand Up @@ -36,6 +37,8 @@ class RunContext(Generic[AgentDepsT]):
"""Messages exchanged in the conversation so far."""
tracer: Tracer = field(default_factory=NoOpTracer)
"""The tracer to use for tracing the run."""
event_stream: MemoryObjectSendStream[_messages.CustomEvent] | None = None
"""The event stream to use for handling custom events."""
trace_include_content: bool = False
"""Whether to include the content of the messages in the trace."""
instrumentation_version: int = DEFAULT_INSTRUMENTATION_VERSION
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDe
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
if self.ctx is not None:
if ctx.run_step == self.ctx.run_step:
# TODO (DouweM): Refactor to make sure it's always set

if ctx.event_stream and not self.ctx.event_stream:
self.ctx.event_stream = ctx.event_stream
return self

retries = {
Expand Down
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,12 @@ def is_async_callable(obj: Any) -> Any:
while isinstance(obj, functools.partial):
obj = obj.func

return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # type: ignore
return inspect.iscoroutinefunction(obj) or (callable(obj) and inspect.iscoroutinefunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess]


def is_async_iterator_callable(obj: Any) -> bool:
"""Check if a callable is an async iterator."""
return inspect.isasyncgenfunction(obj) or (callable(obj) and inspect.isasyncgenfunction(obj.__call__)) # pyright: ignore[reportFunctionMemberAccess]


def _update_mapped_json_schema_refs(s: dict[str, Any], name_mapping: dict[str, str]) -> None:
Expand Down
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
BaseToolCallPart,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CustomEvent,
FunctionToolResultEvent,
ModelMessage,
ModelRequest,
Expand Down Expand Up @@ -431,6 +432,8 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
if isinstance(event, FunctionToolResultEvent):
async for msg in _handle_tool_result_event(stream_ctx, event):
yield msg
elif isinstance(event, CustomEvent) and isinstance(event.payload, BaseEvent):
yield event.payload


async def _handle_model_request_event( # noqa: C901
Expand Down Expand Up @@ -582,6 +585,8 @@ async def _handle_tool_result_event(
content=result.model_response_str(),
)

# TODO (DouweM): Stream `event.content` as if they were user parts?

# Now check for AG-UI events returned by the tool calls.
possible_event = result.metadata or result.content
if isinstance(possible_event, BaseEvent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(

async def call_tool_activity(params: _CallToolParams, deps: AgentDepsT) -> _CallToolResult:
name = params.name
# TODO (DouweM): RunContext.event_stream -> call event_stream_handler directly?
ctx = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
try:
tool = (await toolset.get_tools(ctx))[name]
Expand Down
28 changes: 26 additions & 2 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from dataclasses import KW_ONLY, dataclass, field, replace
from datetime import datetime
from mimetypes import guess_type
from typing import TYPE_CHECKING, Annotated, Any, Literal, TypeAlias, cast, overload
from typing import TYPE_CHECKING, Annotated, Any, Generic, Literal, TypeAlias, cast, overload

import pydantic
import pydantic_core
from genai_prices import calc_price, types as genai_types
from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Self, deprecated
from typing_extensions import Self, TypeVar, deprecated

from . import _otel_messages, _utils
from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
Expand All @@ -23,6 +23,8 @@
if TYPE_CHECKING:
from .models.instrumented import InstrumentationSettings

EventPayloadT = TypeVar('EventPayloadT', default=Any)


AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac']
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
Expand Down Expand Up @@ -1724,9 +1726,31 @@ class BuiltinToolResultEvent:
"""Event type identifier, used as a discriminator."""


@dataclass(repr=False)
class CustomEvent(Generic[EventPayloadT]):
"""An event indicating the result of a function tool call."""

payload: EventPayloadT
"""The payload of the custom event."""

_: KW_ONLY

name: str | None = None
"""The optional name of the custom event."""

id: str | None = None
"""The optional ID of the custom event."""

event_kind: Literal['custom'] = 'custom'
"""Event type identifier, used as a discriminator."""

__repr__ = _utils.dataclasses_no_defaults_repr


HandleResponseEvent = Annotated[
FunctionToolCallEvent
| FunctionToolResultEvent
| CustomEvent
| BuiltinToolCallEvent # pyright: ignore[reportDeprecated]
| BuiltinToolResultEvent, # pyright: ignore[reportDeprecated]
pydantic.Discriminator('event_kind'),
Expand Down
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def from_schema(
json_schema=json_schema,
takes_ctx=takes_ctx,
is_async=_utils.is_async_callable(function),
is_async_iterator=_utils.is_async_iterator_callable(function),
)

return cls(
Expand Down
81 changes: 81 additions & 0 deletions tests/test_ag_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,20 @@ async def send_custom() -> ToolReturn:
)


async def yield_custom() -> AsyncIterator[CustomEvent | ToolReturn]:
yield CustomEvent(
type=EventType.CUSTOM,
name='custom_event1',
value={'key1': 'value1'},
)
yield CustomEvent(
type=EventType.CUSTOM,
name='custom_event2',
value={'key2': 'value2'},
)
yield ToolReturn('Done')


def uuid_str() -> str:
"""Generate a random UUID string."""
return uuid.uuid4().hex
Expand Down Expand Up @@ -815,6 +829,73 @@ async def stream_function(
)


async def test_tool_local_yield_events() -> None:
"""Test local tool call that yields multiple events."""

async def stream_function(
messages: list[ModelMessage], agent_info: AgentInfo
) -> AsyncIterator[DeltaToolCalls | str]:
if len(messages) == 1:
# First call - make a tool call
yield {0: DeltaToolCall(name='yield_custom')}
yield {0: DeltaToolCall(json_args='{}')}
else:
# Second call - return text result
yield 'success yield_custom called'

agent = Agent(
model=FunctionModel(stream_function=stream_function),
tools=[yield_custom],
)

run_input = create_input(
UserMessage(
id='msg_1',
content='Please call yield_custom',
),
)
events = await run_and_collect_events(agent, run_input)

assert events == snapshot(
[
{
'type': 'RUN_STARTED',
'threadId': (thread_id := IsSameStr()),
'runId': (run_id := IsSameStr()),
},
{
'type': 'TOOL_CALL_START',
'toolCallId': (tool_call_id := IsSameStr()),
'toolCallName': 'yield_custom',
'parentMessageId': IsStr(),
},
{'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, 'delta': '{}'},
{'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id},
{'type': 'CUSTOM', 'name': 'custom_event1', 'value': {'key1': 'value1'}},
{'type': 'CUSTOM', 'name': 'custom_event2', 'value': {'key2': 'value2'}},
{
'type': 'TOOL_CALL_RESULT',
'messageId': IsStr(),
'toolCallId': tool_call_id,
'content': 'Done',
'role': 'tool',
},
{'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'},
{
'type': 'TEXT_MESSAGE_CONTENT',
'messageId': message_id,
'delta': 'success yield_custom called',
},
{'type': 'TEXT_MESSAGE_END', 'messageId': message_id},
{
'type': 'RUN_FINISHED',
'threadId': thread_id,
'runId': run_id,
},
]
)


async def test_tool_local_parts() -> None:
"""Test local tool call with streaming/parts."""

Expand Down
Loading
Loading