diff --git a/docs/ag-ui.md b/docs/ag-ui.md index 45bc27af87..b19f0baf7b 100644 --- a/docs/ag-ui.md +++ b/docs/ag-ui.md @@ -4,8 +4,6 @@ The [Agent User Interaction (AG-UI) Protocol](https://docs.ag-ui.com/introductio [CopilotKit](https://webflow.copilotkit.ai/blog/introducing-ag-ui-the-protocol-where-agents-meet-users) team that standardises how frontend applications communicate with AI agents, with support for streaming, frontend tools, shared state, and custom events. -Any Pydantic AI agent can be exposed as an AG-UI server using the [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] convenience method. - !!! note The AG-UI integration was originally built by the team at [Rocket Science](https://www.rocketscience.gg/) and contributed in collaboration with the Pydantic AI and CopilotKit teams. Thanks Rocket Science! @@ -13,8 +11,8 @@ Any Pydantic AI agent can be exposed as an AG-UI server using the [`Agent.to_ag_ The only dependencies are: -- [ag-ui-protocol](https://docs.ag-ui.com/introduction): to provide the AG-UI types and encoder -- [starlette](https://www.starlette.io): to expose the AG-UI server as an [ASGI application](https://asgi.readthedocs.io/en/latest/) +- [ag-ui-protocol](https://docs.ag-ui.com/introduction): to provide the AG-UI types and encoder. +- [starlette](https://www.starlette.io): to handle [ASGI](https://asgi.readthedocs.io/en/latest/) requests from a framework like FastAPI. You can install Pydantic AI with the `ag-ui` extra to ensure you have all the required AG-UI dependencies: @@ -31,9 +29,95 @@ To run the examples you'll also need: pip/uv-add uvicorn ``` -## Quick start +## Usage + +There are three ways to run a Pydantic AI agent based on AG-UI run input with streamed AG-UI events as output, from most to least flexible. If you're using a Starlette-based web framework like FastAPI, you'll typically want to use the second method. + +1. 
[`run_ag_ui()`][pydantic_ai.ag_ui.run_ag_ui] takes an agent and an AG-UI [`RunAgentInput`](https://docs.ag-ui.com/sdk/python/core/types#runagentinput) object, and returns a stream of AG-UI events encoded as strings. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`. Use this if you're using a web framework not based on Starlette (e.g. Django or Flask) or want to modify the input or output in some way.
+2. [`handle_ag_ui_request()`][pydantic_ai.ag_ui.handle_ag_ui_request] takes an agent and a Starlette request (e.g. from FastAPI) coming from an AG-UI frontend, and returns a streaming Starlette response of AG-UI events that you can return directly from your endpoint. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, that you can vary for each request (e.g. based on the authenticated user).
+3. [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] returns an ASGI application that handles every AG-UI request by running the agent. It also takes optional [`Agent.iter()`][pydantic_ai.Agent.iter] arguments including `deps`, but these will be the same for each request, with the exception of the AG-UI state that's injected as described under [state management](#state-management). This ASGI app can be [mounted](https://fastapi.tiangolo.com/advanced/sub-applications/) at a given path in an existing FastAPI app.
+
+### Handle run input and output directly
+
+This example uses [`run_ag_ui()`][pydantic_ai.ag_ui.run_ag_ui] and performs its own request parsing and response generation.
+This can be modified to work with any web framework.
+
+```py {title="run_ag_ui.py"}
+from ag_ui.core import RunAgentInput
+from fastapi import FastAPI
+from http import HTTPStatus
+from fastapi.requests import Request
+from fastapi.responses import Response, StreamingResponse
+from pydantic import ValidationError
+import json
+
+from pydantic_ai import Agent
+from pydantic_ai.ag_ui import run_ag_ui, SSE_CONTENT_TYPE
+
+
+agent = Agent('openai:gpt-4.1', instructions='Be fun!')
+
+app = FastAPI()
+
+
+@app.post("/")
+async def run_agent(request: Request) -> Response:
+    accept = request.headers.get('accept', SSE_CONTENT_TYPE)
+    try:
+        run_input = RunAgentInput.model_validate(await request.json())
+    except ValidationError as e:  # pragma: no cover
+        return Response(
+            content=json.dumps(e.json()),
+            media_type='application/json',
+            status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
+        )
+
+    event_stream = run_ag_ui(agent, run_input, accept=accept)
+
+    return StreamingResponse(event_stream, media_type=accept)
+```
+
+Since `app` is an ASGI application, it can be used with any ASGI server:
+
+```shell
+uvicorn run_ag_ui:app
+```
+
+This will expose the agent as an AG-UI server, and your frontend can start sending requests to it.
+
+### Handle a Starlette request
+
+This example uses [`handle_ag_ui_request()`][pydantic_ai.ag_ui.handle_ag_ui_request] to directly handle a FastAPI request and return a response. Something analogous to this will work with any Starlette-based web framework.
+ +```py {title="handle_ag_ui_request.py"} +from fastapi import FastAPI +from starlette.requests import Request +from starlette.responses import Response + +from pydantic_ai import Agent +from pydantic_ai.ag_ui import handle_ag_ui_request + + +agent = Agent('openai:gpt-4.1', instructions='Be fun!') + +app = FastAPI() + +@app.post("/") +async def run_agent(request: Request) -> Response: + return await handle_ag_ui_request(agent, request) +``` + +Since `app` is an ASGI application, it can be used with any ASGI server: -To expose a Pydantic AI agent as an AG-UI server, you can use the [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] method: +```shell +uvicorn handle_ag_ui_request:app +``` + +This will expose the agent as an AG-UI server, and your frontend can start sending requests to it. + +### Stand-alone ASGI app + +This example uses [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] to turn the agent into a stand-alone ASGI application: ```py {title="agent_to_ag_ui.py" py="3.10" hl_lines="4"} from pydantic_ai import Agent @@ -45,13 +129,11 @@ app = agent.to_ag_ui() Since `app` is an ASGI application, it can be used with any ASGI server: ```shell -uvicorn agent_to_ag_ui:app --host 0.0.0.0 --port 9000 +uvicorn agent_to_ag_ui:app ``` This will expose the agent as an AG-UI server, and your frontend can start sending requests to it. -The `to_ag_ui()` method accepts the same arguments as the [`Agent.iter()`][pydantic_ai.agent.Agent.iter] method as well as arguments that let you configure the [Starlette](https://www.starlette.io)-based ASGI app. 
- ## Design The Pydantic AI AG-UI integration supports all features of the spec: @@ -61,14 +143,11 @@ The Pydantic AI AG-UI integration supports all features of the spec: - [State Management](https://docs.ag-ui.com/concepts/state) - [Tools](https://docs.ag-ui.com/concepts/tools) -The app receives messages in the form of a -[`RunAgentInput`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput) -which describes the details of a request being passed to the agent including -messages and state. These are then converted to Pydantic AI types and passed to the -agent which then process the request. +The integration receives messages in the form of a +[`RunAgentInput`](https://docs.ag-ui.com/sdk/python/core/types#runagentinput) object +that describes the details of the requested agent run including message history, state, and available tools. -Events from the agent, including tool calls, are converted to AG-UI events and -streamed back to the caller as Server-Sent Events (SSE). +These are converted to Pydantic AI types and passed to the agent's run method. Events from the agent, including tool calls, are converted to AG-UI events and streamed back to the caller as Server-Sent Events (SSE). A user request may require multiple round trips between client UI and Pydantic AI server, depending on the tools and events needed. @@ -77,7 +156,7 @@ server, depending on the tools and events needed. ### State management -The adapter provides full support for +The integration provides full support for [AG-UI state management](https://docs.ag-ui.com/concepts/state), which enables real-time synchronization between agents and frontend applications. 
diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 447a4ba60d..1ea6f16eb7 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -8,11 +8,10 @@ import json import uuid -from collections.abc import Iterable, Mapping, Sequence -from dataclasses import Field, dataclass, field, replace +from collections.abc import AsyncIterator, Iterable, Mapping, Sequence +from dataclasses import Field, dataclass, replace from http import HTTPStatus from typing import ( - TYPE_CHECKING, Any, Callable, ClassVar, @@ -23,10 +22,36 @@ runtime_checkable, ) -from pydantic_ai.exceptions import UserError +from pydantic import BaseModel, ValidationError -if TYPE_CHECKING: - pass +from ._agent_graph import CallToolsNode, ModelRequestNode +from .agent import Agent, AgentRun +from .exceptions import UserError +from .messages import ( + AgentStreamEvent, + FunctionToolResultEvent, + ModelMessage, + ModelRequest, + ModelResponse, + PartDeltaEvent, + PartStartEvent, + SystemPromptPart, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, + ToolCallPart, + ToolCallPartDelta, + ToolReturnPart, + UserPromptPart, +) +from .models import KnownModelName, Model +from .output import DeferredToolCalls, OutputDataT, OutputSpec +from .settings import ModelSettings +from .tools import AgentDepsT, ToolDefinition +from .toolsets import AbstractToolset +from .toolsets.deferred import DeferredToolset +from .usage import Usage, UsageLimits try: from ag_ui.core import ( @@ -74,43 +99,13 @@ 'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`' ) from e -from collections.abc import AsyncGenerator - -from pydantic import BaseModel, ValidationError - -from ._agent_graph import CallToolsNode, ModelRequestNode -from .agent import Agent, AgentRun, RunOutputDataT -from .messages import ( - AgentStreamEvent, - FunctionToolResultEvent, - ModelMessage, - ModelRequest, - ModelResponse, - 
PartDeltaEvent, - PartStartEvent, - SystemPromptPart, - TextPart, - TextPartDelta, - ThinkingPart, - ThinkingPartDelta, - ToolCallPart, - ToolCallPartDelta, - ToolReturnPart, - UserPromptPart, -) -from .models import KnownModelName, Model -from .output import DeferredToolCalls, OutputDataT, OutputSpec -from .settings import ModelSettings -from .tools import AgentDepsT, ToolDefinition -from .toolsets import AbstractToolset -from .toolsets.deferred import DeferredToolset -from .usage import Usage, UsageLimits - __all__ = [ 'SSE_CONTENT_TYPE', 'StateDeps', 'StateHandler', 'AGUIApp', + 'handle_ag_ui_request', + 'run_ag_ui', ] SSE_CONTENT_TYPE: Final[str] = 'text/event-stream' @@ -125,7 +120,7 @@ def __init__( agent: Agent[AgentDepsT, OutputDataT], *, # Agent.iter parameters. - output_type: OutputSpec[OutputDataT] | None = None, + output_type: OutputSpec[Any] | None = None, model: Model | KnownModelName | str | None = None, deps: AgentDepsT = None, model_settings: ModelSettings | None = None, @@ -142,10 +137,16 @@ def __init__( on_shutdown: Sequence[Callable[[], Any]] | None = None, lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, ) -> None: - """Initialise the AG-UI application. + """An ASGI application that handles every AG-UI request by running the agent. + + Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's + injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol. + To provide different `deps` for each request (e.g. based on the authenticated user), + use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or + [`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead. Args: - agent: The Pydantic AI `Agent` to adapt. + agent: The agent to run. 
output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no output validators since output validators would expect an argument that matches the agent's @@ -156,7 +157,7 @@ def __init__( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + toolsets: Optional additional toolsets for this run. debug: Boolean indicating if debug tracebacks should be returned on errors. routes: A list of routes to serve incoming HTTP and WebSocket requests. @@ -185,320 +186,349 @@ def __init__( on_shutdown=on_shutdown, lifespan=lifespan, ) - adapter = _Adapter(agent=agent) - async def endpoint(request: Request) -> Response | StreamingResponse: + async def endpoint(request: Request) -> Response: """Endpoint to run the agent with the provided input data.""" - accept = request.headers.get('accept', SSE_CONTENT_TYPE) - try: - input_data = RunAgentInput.model_validate(await request.json()) - except ValidationError as e: # pragma: no cover - return Response( - content=json.dumps(e.json()), - media_type='application/json', - status_code=HTTPStatus.UNPROCESSABLE_ENTITY, - ) - - return StreamingResponse( - adapter.run( - input_data, - accept, - output_type=output_type, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=infer_name, - toolsets=toolsets, - ), - media_type=SSE_CONTENT_TYPE, + return await handle_ag_ui_request( + agent, + request, + output_type=output_type, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, ) self.router.add_route('/', endpoint, methods=['POST'], name='run_agent') 
-@dataclass(repr=False) -class _Adapter(Generic[AgentDepsT, OutputDataT]): - """An agent adapter providing AG-UI protocol support for Pydantic AI agents. - - This class manages the agent runs, tool calls, state storage and providing - an adapter for running agents with Server-Sent Event (SSE) streaming - responses using the AG-UI protocol. +async def handle_ag_ui_request( + agent: Agent[AgentDepsT, Any], + request: Request, + *, + output_type: OutputSpec[Any] | None = None, + model: Model | KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, +) -> Response: + """Handle an AG-UI request by running the agent and returning a streaming response. Args: - agent: The Pydantic AI `Agent` to adapt. + agent: The agent to run. + request: The Starlette request (e.g. from FastAPI) containing the AG-UI run input. + + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + + Returns: + A streaming Starlette response with AG-UI protocol events. 
""" + accept = request.headers.get('accept', SSE_CONTENT_TYPE) + try: + input_data = RunAgentInput.model_validate(await request.json()) + except ValidationError as e: # pragma: no cover + return Response( + content=json.dumps(e.json()), + media_type='application/json', + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) - agent: Agent[AgentDepsT, OutputDataT] = field(repr=False) - - async def run( - self, - run_input: RunAgentInput, - accept: str = SSE_CONTENT_TYPE, - *, - output_type: OutputSpec[RunOutputDataT] | None = None, - model: Model | KnownModelName | str | None = None, - deps: AgentDepsT = None, - model_settings: ModelSettings | None = None, - usage_limits: UsageLimits | None = None, - usage: Usage | None = None, - infer_name: bool = True, - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, - ) -> AsyncGenerator[str, None]: - """Run the agent with streaming response using AG-UI protocol events. + return StreamingResponse( + run_ag_ui( + agent, + input_data, + accept, + output_type=output_type, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ), + media_type=accept, + ) - The first two arguments are specific to `Adapter` the rest map directly to the `Agent.iter` method. - Args: - run_input: The AG-UI run input containing thread_id, run_id, messages, etc. - accept: The accept header value for the run. 
+async def run_ag_ui( + agent: Agent[AgentDepsT, Any], + run_input: RunAgentInput, + accept: str = SSE_CONTENT_TYPE, + *, + output_type: OutputSpec[Any] | None = None, + model: Model | KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, +) -> AsyncIterator[str]: + """Run the agent with the AG-UI run input and stream AG-UI protocol events. - output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no - output validators since output validators would expect an argument that matches the agent's output type. - model: Optional model to use for this run, required if `model` was not set when creating the agent. - deps: Optional dependencies to use for this run. - model_settings: Optional settings to use for this model's request. - usage_limits: Optional limits on model request count or token usage. - usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + Args: + agent: The agent to run. + run_input: The AG-UI run input containing thread_id, run_id, messages, etc. + accept: The accept header value for the run. + + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. 
+ usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. + + Yields: + Streaming event chunks encoded as strings according to the accept header value. + """ + encoder = EventEncoder(accept=accept) + if run_input.tools: + # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the + # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any + # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets. + toolset = DeferredToolset[AgentDepsT]( + [ + ToolDefinition( + name=tool.name, + description=tool.description, + parameters_json_schema=tool.parameters, + ) + for tool in run_input.tools + ] + ) + toolsets = [*toolsets, toolset] if toolsets else [toolset] + + try: + yield encoder.encode( + RunStartedEvent( + thread_id=run_input.thread_id, + run_id=run_input.run_id, + ), + ) - Yields: - Streaming SSE-formatted event chunks. - """ - encoder = EventEncoder(accept=accept) - if run_input.tools: - # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the - # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any - # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets. 
- toolset = DeferredToolset[AgentDepsT]( - [ - ToolDefinition( - name=tool.name, - description=tool.description, - parameters_json_schema=tool.parameters, - ) - for tool in run_input.tools - ] - ) - toolsets = [*toolsets, toolset] if toolsets else [toolset] - - try: - yield encoder.encode( - RunStartedEvent( - thread_id=run_input.thread_id, - run_id=run_input.run_id, - ), - ) + if not run_input.messages: + raise _NoMessagesError - if not run_input.messages: - raise _NoMessagesError - - raw_state: dict[str, Any] = run_input.state or {} - if isinstance(deps, StateHandler): - if isinstance(deps.state, BaseModel): - try: - state = type(deps.state).model_validate(raw_state) - except ValidationError as e: # pragma: no cover - raise _InvalidStateError from e - else: - state = raw_state - - deps = replace(deps, state=state) - elif raw_state: - raise UserError( - f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.' - ) + raw_state: dict[str, Any] = run_input.state or {} + if isinstance(deps, StateHandler): + if isinstance(deps.state, BaseModel): + try: + state = type(deps.state).model_validate(raw_state) + except ValidationError as e: # pragma: no cover + raise _InvalidStateError from e else: - # `deps` not being a `StateHandler` is OK if there is no state. 
- pass - - messages = _messages_from_ag_ui(run_input.messages) + state = raw_state - async with self.agent.iter( - user_prompt=None, - output_type=[output_type or self.agent.output_type, DeferredToolCalls], - message_history=messages, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - infer_name=infer_name, - toolsets=toolsets, - ) as run: - async for event in self._agent_stream(run): - yield encoder.encode(event) - except _RunError as e: - yield encoder.encode( - RunErrorEvent(message=e.message, code=e.code), - ) - except Exception as e: - yield encoder.encode( - RunErrorEvent(message=str(e)), + deps = replace(deps, state=state) + elif raw_state: + raise UserError( + f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.' ) - raise e else: - yield encoder.encode( - RunFinishedEvent( - thread_id=run_input.thread_id, - run_id=run_input.run_id, - ), - ) + # `deps` not being a `StateHandler` is OK if there is no state. + pass - async def _agent_stream( - self, - run: AgentRun[AgentDepsT, Any], - ) -> AsyncGenerator[BaseEvent, None]: - """Run the agent streaming responses using AG-UI protocol events. 
+ messages = _messages_from_ag_ui(run_input.messages) + + async with agent.iter( + user_prompt=None, + output_type=[output_type or agent.output_type, DeferredToolCalls], + message_history=messages, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as run: + async for event in _agent_stream(run): + yield encoder.encode(event) + except _RunError as e: + yield encoder.encode( + RunErrorEvent(message=e.message, code=e.code), + ) + except Exception as e: + yield encoder.encode( + RunErrorEvent(message=str(e)), + ) + raise e + else: + yield encoder.encode( + RunFinishedEvent( + thread_id=run_input.thread_id, + run_id=run_input.run_id, + ), + ) - Args: - run: The agent run to process. - Yields: - AG-UI Server-Sent Events (SSE). - """ - async for node in run: - stream_ctx = _RequestStreamContext() - if isinstance(node, ModelRequestNode): - async with node.stream(run.ctx) as request_stream: - async for agent_event in request_stream: - async for msg in self._handle_model_request_event(stream_ctx, agent_event): +async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEvent]: + """Run the agent streaming responses using AG-UI protocol events. + + Args: + run: The agent run to process. + + Yields: + AG-UI Server-Sent Events (SSE). 
+ """ + async for node in run: + stream_ctx = _RequestStreamContext() + if isinstance(node, ModelRequestNode): + async with node.stream(run.ctx) as request_stream: + async for agent_event in request_stream: + async for msg in _handle_model_request_event(stream_ctx, agent_event): + yield msg + + if stream_ctx.part_end: # pragma: no branch + yield stream_ctx.part_end + stream_ctx.part_end = None + elif isinstance(node, CallToolsNode): + async with node.stream(run.ctx) as handle_stream: + async for event in handle_stream: + if isinstance(event, FunctionToolResultEvent): + async for msg in _handle_tool_result_event(stream_ctx, event): yield msg - if stream_ctx.part_end: # pragma: no branch - yield stream_ctx.part_end - stream_ctx.part_end = None - elif isinstance(node, CallToolsNode): - async with node.stream(run.ctx) as handle_stream: - async for event in handle_stream: - if isinstance(event, FunctionToolResultEvent): - async for msg in self._handle_tool_result_event(stream_ctx, event): - yield msg - - async def _handle_model_request_event( - self, - stream_ctx: _RequestStreamContext, - agent_event: AgentStreamEvent, - ) -> AsyncGenerator[BaseEvent, None]: - """Handle an agent event and yield AG-UI protocol events. - Args: - stream_ctx: The request stream context to manage state. - agent_event: The agent event to process. +async def _handle_model_request_event( + stream_ctx: _RequestStreamContext, + agent_event: AgentStreamEvent, +) -> AsyncIterator[BaseEvent]: + """Handle an agent event and yield AG-UI protocol events. - Yields: - AG-UI Server-Sent Events (SSE) based on the agent event. - """ - if isinstance(agent_event, PartStartEvent): - if stream_ctx.part_end: - # End the previous part. 
- yield stream_ctx.part_end - stream_ctx.part_end = None - - part = agent_event.part - if isinstance(part, TextPart): - message_id = stream_ctx.new_message_id() - yield TextMessageStartEvent( - message_id=message_id, - ) - if part.content: # pragma: no branch - yield TextMessageContentEvent( - message_id=message_id, - delta=part.content, - ) - stream_ctx.part_end = TextMessageEndEvent( + Args: + stream_ctx: The request stream context to manage state. + agent_event: The agent event to process. + + Yields: + AG-UI Server-Sent Events (SSE) based on the agent event. + """ + if isinstance(agent_event, PartStartEvent): + if stream_ctx.part_end: + # End the previous part. + yield stream_ctx.part_end + stream_ctx.part_end = None + + part = agent_event.part + if isinstance(part, TextPart): + message_id = stream_ctx.new_message_id() + yield TextMessageStartEvent( + message_id=message_id, + ) + if part.content: # pragma: no branch + yield TextMessageContentEvent( message_id=message_id, + delta=part.content, ) - elif isinstance(part, ToolCallPart): # pragma: no branch - message_id = stream_ctx.message_id or stream_ctx.new_message_id() - yield ToolCallStartEvent( - tool_call_id=part.tool_call_id, - tool_call_name=part.tool_name, - parent_message_id=message_id, - ) - if part.args: - yield ToolCallArgsEvent( - tool_call_id=part.tool_call_id, - delta=part.args if isinstance(part.args, str) else json.dumps(part.args), - ) - stream_ctx.part_end = ToolCallEndEvent( + stream_ctx.part_end = TextMessageEndEvent( + message_id=message_id, + ) + elif isinstance(part, ToolCallPart): # pragma: no branch + message_id = stream_ctx.message_id or stream_ctx.new_message_id() + yield ToolCallStartEvent( + tool_call_id=part.tool_call_id, + tool_call_name=part.tool_name, + parent_message_id=message_id, + ) + if part.args: + yield ToolCallArgsEvent( tool_call_id=part.tool_call_id, + delta=part.args if isinstance(part.args, str) else json.dumps(part.args), ) + stream_ctx.part_end = ToolCallEndEvent( + 
tool_call_id=part.tool_call_id, + ) - elif isinstance(part, ThinkingPart): # pragma: no branch - yield ThinkingTextMessageStartEvent( - type=EventType.THINKING_TEXT_MESSAGE_START, - ) - # Always send the content even if it's empty, as it may be - # used to indicate the start of thinking. + elif isinstance(part, ThinkingPart): # pragma: no branch + yield ThinkingTextMessageStartEvent( + type=EventType.THINKING_TEXT_MESSAGE_START, + ) + # Always send the content even if it's empty, as it may be + # used to indicate the start of thinking. + yield ThinkingTextMessageContentEvent( + type=EventType.THINKING_TEXT_MESSAGE_CONTENT, + delta=part.content, + ) + stream_ctx.part_end = ThinkingTextMessageEndEvent( + type=EventType.THINKING_TEXT_MESSAGE_END, + ) + + elif isinstance(agent_event, PartDeltaEvent): + delta = agent_event.delta + if isinstance(delta, TextPartDelta): + yield TextMessageContentEvent( + message_id=stream_ctx.message_id, + delta=delta.content_delta, + ) + elif isinstance(delta, ToolCallPartDelta): # pragma: no branch + assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set' + yield ToolCallArgsEvent( + tool_call_id=delta.tool_call_id, + delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta), + ) + elif isinstance(delta, ThinkingPartDelta): # pragma: no branch + if delta.content_delta: # pragma: no branch yield ThinkingTextMessageContentEvent( type=EventType.THINKING_TEXT_MESSAGE_CONTENT, - delta=part.content, - ) - stream_ctx.part_end = ThinkingTextMessageEndEvent( - type=EventType.THINKING_TEXT_MESSAGE_END, - ) - - elif isinstance(agent_event, PartDeltaEvent): - delta = agent_event.delta - if isinstance(delta, TextPartDelta): - yield TextMessageContentEvent( - message_id=stream_ctx.message_id, delta=delta.content_delta, ) - elif isinstance(delta, ToolCallPartDelta): # pragma: no branch - assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set' - yield ToolCallArgsEvent( - 
tool_call_id=delta.tool_call_id, - delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta), - ) - elif isinstance(delta, ThinkingPartDelta): # pragma: no branch - if delta.content_delta: # pragma: no branch - yield ThinkingTextMessageContentEvent( - type=EventType.THINKING_TEXT_MESSAGE_CONTENT, - delta=delta.content_delta, - ) - async def _handle_tool_result_event( - self, - stream_ctx: _RequestStreamContext, - event: FunctionToolResultEvent, - ) -> AsyncGenerator[BaseEvent, None]: - """Convert a tool call result to AG-UI events. - Args: - stream_ctx: The request stream context to manage state. - event: The tool call result event to process. +async def _handle_tool_result_event( + stream_ctx: _RequestStreamContext, + event: FunctionToolResultEvent, +) -> AsyncIterator[BaseEvent]: + """Convert a tool call result to AG-UI events. - Yields: - AG-UI Server-Sent Events (SSE). - """ - result = event.result - if not isinstance(result, ToolReturnPart): - return - - message_id = stream_ctx.new_message_id() - yield ToolCallResultEvent( - message_id=message_id, - type=EventType.TOOL_CALL_RESULT, - role='tool', - tool_call_id=result.tool_call_id, - content=result.model_response_str(), - ) + Args: + stream_ctx: The request stream context to manage state. + event: The tool call result event to process. - # Now check for AG-UI events returned by the tool calls. - content = result.content - if isinstance(content, BaseEvent): - yield content - elif isinstance(content, (str, bytes)): # pragma: no branch - # Avoid iterable check for strings and bytes. - pass - elif isinstance(content, Iterable): # pragma: no branch - for item in content: # type: ignore[reportUnknownMemberType] - if isinstance(item, BaseEvent): # pragma: no branch - yield item + Yields: + AG-UI Server-Sent Events (SSE). 
+ """ + result = event.result + if not isinstance(result, ToolReturnPart): + return + + message_id = stream_ctx.new_message_id() + yield ToolCallResultEvent( + message_id=message_id, + type=EventType.TOOL_CALL_RESULT, + role='tool', + tool_call_id=result.tool_call_id, + content=result.model_response_str(), + ) + + # Now check for AG-UI events returned by the tool calls. + content = result.content + if isinstance(content, BaseEvent): + yield content + elif isinstance(content, (str, bytes)): # pragma: no branch + # Avoid iterable check for strings and bytes. + pass + elif isinstance(content, Iterable): # pragma: no branch + for item in content: # type: ignore[reportUnknownMemberType] + if isinstance(item, BaseEvent): # pragma: no branch + yield item def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 008e173375..271d0ebc7f 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1870,9 +1870,13 @@ def to_ag_ui( on_shutdown: Sequence[Callable[[], Any]] | None = None, lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, ) -> AGUIApp[AgentDepsT, OutputDataT]: - """Convert the agent to an AG-UI application. + """Returns an ASGI application that handles every AG-UI request by running the agent. - This allows you to use the agent with a compatible AG-UI frontend. + Note that the `deps` will be the same for each request, with the exception of the AG-UI state that's + injected into the `state` field of a `deps` object that implements the [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol. + To provide different `deps` for each request (e.g. based on the authenticated user), + use [`pydantic_ai.ag_ui.run_ag_ui`][pydantic_ai.ag_ui.run_ag_ui] or + [`pydantic_ai.ag_ui.handle_ag_ui_request`][pydantic_ai.ag_ui.handle_ag_ui_request] instead. 
Example: ```python @@ -1882,8 +1886,6 @@ def to_ag_ui( app = agent.to_ag_ui() ``` - The `app` is an ASGI application that can be used with any ASGI server. - To run the application, you can use the following command: ```bash @@ -1902,7 +1904,7 @@ def to_ag_ui( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. - toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + toolsets: Optional additional toolsets for this run. debug: Boolean indicating if debug tracebacks should be returned on errors. routes: A list of routes to serve incoming HTTP and WebSocket requests. diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 0da58aa218..8d42fa4c63 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -58,7 +58,7 @@ from pydantic_ai.ag_ui import ( SSE_CONTENT_TYPE, StateDeps, - _Adapter, # type: ignore[reportPrivateUsage] + run_ag_ui, ) has_ag_ui = True @@ -95,13 +95,12 @@ def simple_result() -> Any: ) -async def collect_events_from_adapter( - adapter: _Adapter[AgentDepsT, OutputDataT], *run_inputs: RunAgentInput, deps: AgentDepsT = None +async def run_and_collect_events( + agent: Agent[AgentDepsT, OutputDataT], *run_inputs: RunAgentInput, deps: AgentDepsT = None ) -> list[dict[str, Any]]: - """Helper function to collect events from an AG-UI adapter run.""" events = list[dict[str, Any]]() for run_input in run_inputs: - async for event in adapter.run(run_input, deps=deps): + async for event in run_ag_ui(agent, run_input, deps=deps): events.append(json.loads(event.removeprefix('data: '))) return events @@ -202,7 +201,7 @@ async def test_basic_user_message() -> None: agent = Agent( model=FunctionModel(stream_function=simple_stream), ) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', @@ 
-210,7 +209,7 @@ async def test_basic_user_message() -> None: ) ) - events = await collect_events_from_adapter(adapter, run_input) + events = await run_and_collect_events(agent, run_input) assert events == simple_result() @@ -227,9 +226,9 @@ async def stream_function( agent = Agent( model=FunctionModel(stream_function=stream_function), ) - adapter = _Adapter(agent=agent) + run_input = create_input() - events = await collect_events_from_adapter(adapter, run_input) + events = await run_and_collect_events(agent, run_input) assert events == snapshot( [ @@ -248,7 +247,7 @@ async def test_multiple_messages() -> None: agent = Agent( model=FunctionModel(stream_function=simple_stream), ) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', @@ -272,7 +271,7 @@ async def test_multiple_messages() -> None: ), ) - events = await collect_events_from_adapter(adapter, run_input) + events = await run_and_collect_events(agent, run_input) assert events == simple_result() @@ -282,7 +281,7 @@ async def test_messages_with_history() -> None: agent = Agent( model=FunctionModel(stream_function=simple_stream), ) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', @@ -294,7 +293,7 @@ async def test_messages_with_history() -> None: ), ) - events = await collect_events_from_adapter(adapter, run_input) + events = await run_and_collect_events(agent, run_input) assert events == simple_result() @@ -317,7 +316,7 @@ async def stream_function( model=FunctionModel(stream_function=stream_function), tools=[send_snapshot, send_custom, current_time], ) - adapter = _Adapter(agent=agent) + thread_id = uuid_str() run_inputs = [ create_input( @@ -355,7 +354,7 @@ async def stream_function( ), ] - events = await collect_events_from_adapter(adapter, *run_inputs) + events = await run_and_collect_events(agent, *run_inputs) assert events == snapshot( [ @@ -427,7 +426,7 @@ async def stream_function( agent = Agent( 
model=FunctionModel(stream_function=stream_function), ) - adapter = _Adapter(agent=agent) + tool_call_id1 = uuid_str() tool_call_id2 = uuid_str() run_inputs = [ @@ -486,7 +485,7 @@ async def stream_function( ), ] - events = await collect_events_from_adapter(adapter, *run_inputs) + events = await run_and_collect_events(agent, *run_inputs) assert events == snapshot( [ @@ -562,7 +561,7 @@ async def stream_function( yield '{"get_weather": "Tool result"}' agent = Agent(model=FunctionModel(stream_function=stream_function)) - adapter = _Adapter(agent=agent) + run_inputs = [ ( first_input := create_input( @@ -600,7 +599,7 @@ async def stream_function( thread_id=first_input.thread_id, ), ] - events = await collect_events_from_adapter(adapter, *run_inputs) + events = await run_and_collect_events(agent, *run_inputs) assert events == snapshot( [ @@ -675,14 +674,14 @@ async def stream_function( model=FunctionModel(stream_function=stream_function), tools=[send_snapshot], ) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', content='Please call send_snapshot', ), ) - events = await collect_events_from_adapter(adapter, run_input) + events = await run_and_collect_events(agent, run_input) assert events == snapshot( [ @@ -744,14 +743,14 @@ async def stream_function( model=FunctionModel(stream_function=stream_function), tools=[send_custom], ) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', content='Please call send_custom', ), ) - events = await collect_events_from_adapter(adapter, run_input) + events = await run_and_collect_events(agent, run_input) assert events == snapshot( [ @@ -812,7 +811,6 @@ async def stream_function( tools=[send_snapshot, send_custom, current_time], ) - adapter = _Adapter(agent=agent) run_input = create_input( UserMessage( id='msg_1', @@ -820,7 +818,7 @@ async def stream_function( ), ) - events = await collect_events_from_adapter(adapter, run_input) + events = await 
run_and_collect_events(agent, run_input) assert events == snapshot( [ @@ -873,7 +871,7 @@ async def stream_function( agent = Agent( model=FunctionModel(stream_function=stream_function), ) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', @@ -881,7 +879,7 @@ async def stream_function( ), ) - events = await collect_events_from_adapter(adapter, run_input) + events = await run_and_collect_events(agent, run_input) assert events == snapshot( [ @@ -933,7 +931,7 @@ async def stream_function( model=FunctionModel(stream_function=stream_function), tools=[current_time], ) - adapter = _Adapter(agent=agent) + run_inputs = [ ( first_input := create_input( @@ -989,7 +987,7 @@ async def stream_function( thread_id=first_input.thread_id, ), ] - events = await collect_events_from_adapter(adapter, *run_inputs) + events = await run_and_collect_events(agent, *run_inputs) assert events == snapshot( [ @@ -1068,7 +1066,7 @@ async def store_state( deps_type=StateDeps[StateInt], # type: ignore[reportUnknownArgumentType] prepare_tools=store_state, ) - adapter = _Adapter(agent=agent) + run_inputs = [ create_input( UserMessage( @@ -1102,7 +1100,7 @@ async def store_state( for run_input in run_inputs: events = list[dict[str, Any]]() - async for event in adapter.run(run_input, deps=deps): + async for event in run_ag_ui(agent, run_input, deps=deps): events.append(json.loads(event.removeprefix('data: '))) assert events == simple_result() @@ -1122,7 +1120,7 @@ async def store_state( async def test_request_with_state_without_handler() -> None: agent = Agent(model=FunctionModel(stream_function=simple_stream)) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', @@ -1135,7 +1133,7 @@ async def test_request_with_state_without_handler() -> None: UserError, match='AG-UI state is provided but `deps` of type `NoneType` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.', ): - 
async for _ in adapter.run(run_input): + async for _ in run_ag_ui(agent, run_input): pass @@ -1155,7 +1153,7 @@ async def store_state(ctx: RunContext[CustomStateDeps], tool_defs: list[ToolDefi deps_type=CustomStateDeps, prepare_tools=store_state, ) - adapter = _Adapter(agent=agent) + run_input = create_input( UserMessage( id='msg_1', @@ -1164,7 +1162,7 @@ async def store_state(ctx: RunContext[CustomStateDeps], tool_defs: list[ToolDefi state={'value': 42}, ) - async for _ in adapter.run(run_input, deps=CustomStateDeps(state={'value': 0})): + async for _ in run_ag_ui(agent, run_input, deps=CustomStateDeps(state={'value': 0})): pass assert seen_states[-1] == {'value': 42} @@ -1183,7 +1181,6 @@ async def test_concurrent_runs() -> None: async def get_state(ctx: RunContext[StateDeps[StateInt]]) -> int: return ctx.deps.state.value - adapter = _Adapter(agent=agent) concurrent_tasks: list[asyncio.Task[list[dict[str, Any]]]] = [] for i in range(5): # Test with 5 concurrent runs @@ -1196,7 +1193,7 @@ async def get_state(ctx: RunContext[StateDeps[StateInt]]) -> int: thread_id=f'test_thread_{i}', ) - task = asyncio.create_task(collect_events_from_adapter(adapter, run_input, deps=StateDeps(StateInt()))) + task = asyncio.create_task(run_and_collect_events(agent, run_input, deps=StateDeps(StateInt()))) concurrent_tasks.append(task) results = await asyncio.gather(*concurrent_tasks)