diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 33299a7766..7a2ce36f37 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -12,6 +12,7 @@ Of course, you can combine multiple strategies in a single application. ## Agent delegation "Agent delegation" refers to the scenario where an agent delegates work to another agent, then takes back control when the delegate agent (the agent called from within a tool) finishes. +If you want to hand off control to another agent completely, without coming back to the first agent, you can use an [output function](output.md#output-functions). Since agents are stateless and designed to be global, you do not need to include the agent itself in agent [dependencies](dependencies.md). diff --git a/docs/output.md b/docs/output.md index 09eeb27580..7eac789c4d 100644 --- a/docs/output.md +++ b/docs/output.md @@ -1,9 +1,13 @@ -"Output" refers to the final value returned from [running an agent](agents.md#running-agents) these can be either plain text or structured data. +"Output" refers to the final value returned from [running an agent](agents.md#running-agents). This can be either plain text, [structured data](#structured-output), or the result of a [function](#output-functions) called with arguments provided by the model. -The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) +The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so that you can access other data, like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results). Both `AgentRunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. +A run ends when a plain text response is received (assuming no output type is specified or `str` is one of the allowed options), or when the model responds with one of the structured output types by calling a special output tool. A run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits). + +Here's an example using a Pydantic model as the `output_type`, forcing the model to respond with data matching our specification: + ```python {title="olympics.py" line_length="90"} from pydantic import BaseModel @@ -25,27 +29,32 @@ print(result.usage()) _(This example is complete, it can be run "as is")_ -Runs end when either a plain text response is received or the model calls a tool associated with one of the structured output types (run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits)). - ## Output data {#structured-output} -When the output type is `str`, or a union including `str`, plain text responses are enabled on the model, and the raw text response from the model is used as the response data. +The [`Agent`][pydantic_ai.Agent] class constructor takes an `output_type` argument that takes one or more types or [output functions](#output-functions). It supports both type unions and lists of types and functions. + +When no output type is specified, or when the output type is `str` or a union or list of types including `str`, the model is allowed to respond with plain text, and this text is used as the output data. +If `str` is not among the allowed output types, the model is not allowed to respond with plain text and is forced to return structured data (or arguments to an output function). -If the output type is a union with multiple members (after removing `str` from the members), each member is registered as a separate tool with the model in order to reduce the complexity of the tool schemas and maximise the chances a model will respond correctly. +If the output type is a union or list with multiple members, each member (except for `str`, if it is a member) is registered with the model as a separate output tool in order to reduce the complexity of the tool schemas and maximise the chances a model will respond correctly. If the output type schema is not of type `"object"` (e.g. it's `int` or `list[int]`), the output type is wrapped in a single element object, so the schema of all tools registered with the model are object schemas. Structured outputs (like tools) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model. -!!! note "Bring on PEP-747" - Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands, unions are not valid as `type`s in Python. +!!! note "Type checking considerations" + The Agent class is generic in its output type, and this type is carried through to `AgentRunResult.output` and `StreamedRunResult.output` so that your IDE or static type checker can warn you when your code doesn't properly take into account all the possible values those outputs could have. - When creating the agent we need to `# type: ignore` the `output_type` argument, and add a type hint to tell type checkers about the type of the agent. + Static type checkers like pyright and mypy will do their best the infer the agent's output type from the `output_type` you've specified, but they're not always able to do so correctly when you provide functions or multiple types in a union or list, even though PydanticAI will behave correctly. When this happens, your type checker will complain even when you're confident you've passed a valid `output_type`, and you'll need to help the type checker by explicitly specifying the generic parameters on the `Agent` constructor. This is shown in the second example below and the output functions example further down. -Here's an example of returning either text or a structured value + Specifically, there are three valid uses of `output_type` where you'll need to do this: + 1. When using a union of types, e.g. `output_type=Foo | Bar` or in older Python, `output_type=Union[Foo, Bar]`. Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands in Python 3.15, type checkers do not consider these a valid value for `output_type`. In addition to the generic parameters on the `Agent` constructor, you'll need to add `# type: ignore` to the line that passes the union to `output_type`. + 2. With mypy: When using a list, as a functionally equivalent alternative to a union, or because you're passing in [output functions](#output-functions). Pyright does handle this correctly, and we've filed [an issue](https://github.com/python/mypy/issues/19142) with mypy to try and get this fixed. + 3. With mypy: when using an async output function. Pyright does handle this correctly, and we've filed [an issue](https://github.com/python/mypy/issues/19143) with mypy to try and get this fixed. + +Here's an example of returning either text or structured data: ```python {title="box_or_error.py"} -from typing import Union from pydantic import BaseModel @@ -59,9 +68,9 @@ class Box(BaseModel): units: str -agent: Agent[None, Union[Box, str]] = Agent( +agent = Agent( 'openai:gpt-4o-mini', - output_type=Union[Box, str], # type: ignore + output_type=[Box, str], system_prompt=( "Extract me the dimensions of a box, " "if you can't extract all data, ask the user to try again." @@ -79,14 +88,14 @@ print(result.output) _(This example is complete, it can be run "as is")_ -Here's an example of using a union return type which registers multiple tools, and wraps non-object schemas in an object: +Here's an example of using a union return type, for which PydanticAI will register multiple tools and wraps non-object schemas in an object: ```python {title="colors_or_sizes.py"} from typing import Union from pydantic_ai import Agent -agent: Agent[None, Union[list[str], list[int]]] = Agent( +agent = Agent[None, Union[list[str], list[int]]]( 'openai:gpt-4o-mini', output_type=Union[list[str], list[int]], # type: ignore system_prompt='Extract either colors or sizes from the shapes provided.', @@ -103,10 +112,135 @@ print(result.output) _(This example is complete, it can be run "as is")_ -### Output validator functions +### Output functions + +Instead of plain text or structured data, you may want the output of your agent run to be the result of a function called with arguments provided by the model, for example to further process or validate the data provided through the arguments (with the option to tell the model to try again), or to hand off to another agent. + +Output functions are similar to [function tools](tools.md), but the model is forced to call one of them, the call ends the agent run, and the result is not passed back to the model. + +As with tool functions, output function arguments provided by the model are validated using Pydantic, they can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and they can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type). + +To specify output functions, you set the agent's `output_type` to either a single function (or bound instance method), or a list of functions. The list can also contain other output types like simple scalars or entire Pydantic models. +You typically do not want to also register your output function as a tool (using the `@agent.tool` decorator or `tools` argument), as this could confuse the model about which it should be calling. + +Here's an example of all of these features in action: + +```python {title="output_functions.py"} +import re +from typing import Union + +from pydantic import BaseModel + +from pydantic_ai import Agent, ModelRetry, RunContext +from pydantic_ai._output import ToolRetryError +from pydantic_ai.exceptions import UnexpectedModelBehavior + + +class Row(BaseModel): + name: str + country: str + + +tables = { + 'capital_cities': [ + Row(name='Amsterdam', country='Netherlands'), + Row(name='Mexico City', country='Mexico'), + ] +} + + +class SQLFailure(BaseModel): + """An unrecoverable failure. Only use this when you can't change the query to make it work.""" + + explanation: str + + +def run_sql_query(query: str) -> list[Row]: + """Run a SQL query on the database.""" + + select_table = re.match(r'SELECT (.+) FROM (\w+)', query) + if select_table: + column_names = select_table.group(1) + if column_names != '*': + raise ModelRetry("Only 'SELECT *' is supported, you'll have to do column filtering manually.") + + table_name = select_table.group(2) + if table_name not in tables: + raise ModelRetry( + f"Unknown table '{table_name}' in query '{query}'. Available tables: {', '.join(tables.keys())}." + ) + + return tables[table_name] + + raise ModelRetry(f"Unsupported query: '{query}'.") + + +sql_agent = Agent[None, Union[list[Row], SQLFailure]]( + 'openai:gpt-4o', + output_type=[run_sql_query, SQLFailure], + instructions='You are a SQL agent that can run SQL queries on a database.', +) + + +async def hand_off_to_sql_agent(ctx: RunContext, query: str) -> list[Row]: + """I take natural language queries, turn them into SQL, and run them on a database.""" + + # Drop the final message with the output tool call, as it shouldn't be passed on to the SQL agent + messages = ctx.messages[:-1] + try: + result = await sql_agent.run(query, message_history=messages) + output = result.output + if isinstance(output, SQLFailure): + raise ModelRetry(f'SQL agent failed: {output.explanation}') + return output + except UnexpectedModelBehavior as e: + # Bubble up potentially retryable errors to the router agent + if (cause := e.__cause__) and isinstance(cause, ToolRetryError): + raise ModelRetry(f'SQL agent failed: {cause.tool_retry.content}') from e + else: + raise + + +class RouterFailure(BaseModel): + """Use me when no appropriate agent is found or the used agent failed.""" + + explanation: str + + +router_agent = Agent[None, Union[list[Row], RouterFailure]]( + 'openai:gpt-4o', + output_type=[hand_off_to_sql_agent, RouterFailure], + instructions='You are a router to other agents. Never try to solve a problem yourself, just pass it on.', +) + +result = router_agent.run_sync('Select the names and countries of all capitals') +print(result.output) +""" +[ + Row(name='Amsterdam', country='Netherlands'), + Row(name='Mexico City', country='Mexico'), +] +""" + +result = router_agent.run_sync('Select all pets') +print(result.output) +""" +explanation = "The requested table 'pets' does not exist in the database. The only available table is 'capital_cities', which does not contain data about pets." +""" + +result = router_agent.run_sync('How do I fly from Amsterdam to Mexico City?') +print(result.output) +""" +explanation = 'I am not equipped to provide travel information, such as flights from Amsterdam to Mexico City.' +""" +``` + +### Output validators {#output-validator-functions} Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. +If you want to implement separate validation logic for different output types, it's recommended to use [output functions](#output-functions) instead, to save you from having to do `isinstance` checks inside the output validator. + Here's a simplified variant of the [SQL Generation example](examples/sql-gen.md): ```python {title="sql_gen.py"} @@ -127,7 +261,7 @@ class InvalidRequest(BaseModel): Output = Union[Success, InvalidRequest] -agent: Agent[DatabaseConn, Output] = Agent( +agent = Agent[DatabaseConn, Output]( 'google-gla:gemini-1.5-flash', output_type=Output, # type: ignore deps_type=DatabaseConn, diff --git a/docs/tools.md b/docs/tools.md index 33bfdeb3a0..921edd95e1 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -2,7 +2,9 @@ Function tools provide a mechanism for models to retrieve extra information to help them generate a response. -They're useful when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. +They're useful when you want to enable the model to take some action and use the result, when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. + +If you want a model to be able to call a function as its final action, without the result being sent back to the model, you can use an [output function](output.md#output-functions) instead. !!! info "Function tools vs. RAG" Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. diff --git a/examples/pydantic_ai_examples/sql_gen.py b/examples/pydantic_ai_examples/sql_gen.py index 28b5459fb7..fdf8c5ff3d 100644 --- a/examples/pydantic_ai_examples/sql_gen.py +++ b/examples/pydantic_ai_examples/sql_gen.py @@ -92,7 +92,7 @@ class InvalidRequest(BaseModel): Response: TypeAlias = Union[Success, InvalidRequest] -agent: Agent[Deps, Response] = Agent( +agent = Agent[Deps, Response]( 'google-gla:gemini-1.5-flash', # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else output_type=Response, # type: ignore diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 4d3ace8b12..bf83971742 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -24,7 +24,7 @@ result, usage as _usage, ) -from .result import OutputDataT, ToolOutput +from .result import OutputDataT from .settings import ModelSettings, merge_model_settings from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc @@ -64,12 +64,14 @@ class GraphAgentState: retries: int run_step: int - def increment_retries(self, max_result_retries: int) -> None: + def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None: self.retries += 1 if self.retries > max_result_retries: - raise exceptions.UnexpectedModelBehavior( - f'Exceeded maximum retries ({max_result_retries}) for result validation' - ) + message = f'Exceeded maximum retries ({max_result_retries}) for result validation' + if error: + raise exceptions.UnexpectedModelBehavior(message) from error + else: + raise exceptions.UnexpectedModelBehavior(message) @dataclasses.dataclass @@ -264,7 +266,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None: output_schema = ctx.deps.output_schema return models.ModelRequestParameters( function_tools=function_tool_defs, - allow_text_output=allow_text_output(output_schema), + allow_text_output=_output.allow_text_output(output_schema), output_tools=output_schema.tool_defs() if output_schema is not None else [], ) @@ -450,7 +452,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # when the model has already returned text along side tool calls # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any - if allow_text_output(ctx.deps.output_schema): + if _output.allow_text_output(ctx.deps.output_schema): for message in reversed(ctx.state.message_history): if isinstance(message, _messages.ModelResponse): last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)] @@ -471,6 +473,7 @@ async def _handle_tool_calls( tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: output_schema = ctx.deps.output_schema + run_context = build_run_context(ctx) # first, look for the output tool call final_result: result.FinalResult[NodeRunEndT] | None = None @@ -478,12 +481,12 @@ async def _handle_tool_calls( if output_schema is not None: for call, output_tool in output_schema.find_tool(tool_calls): try: - result_data = output_tool.validate(call) + result_data = await output_tool.process(call, run_context) result_data = await _validate_output(result_data, ctx, call) except _output.ToolRetryError as e: # TODO: Should only increment retry stuff once per node execution, not for each tool call # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries) + ctx.state.increment_retries(ctx.deps.max_result_retries, e) parts.append(e.tool_retry) else: final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) @@ -505,7 +508,6 @@ async def _handle_tool_calls( else: if tool_responses: parts.extend(tool_responses) - run_context = build_run_context(ctx) instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( _messages.ModelRequest(parts=parts, instructions=instructions) @@ -533,27 +535,22 @@ async def _handle_text_response( output_schema = ctx.deps.output_schema text = '\n\n'.join(texts) - if allow_text_output(output_schema): - # The following cast is safe because we know `str` is an allowed result type - result_data_input = cast(NodeRunEndT, text) - try: - result_data = await _validate_output(result_data_input, ctx, None) - except _output.ToolRetryError as e: - ctx.state.increment_retries(ctx.deps.max_result_retries) - return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) + try: + if _output.allow_text_output(output_schema): + # The following cast is safe because we know `str` is an allowed result type + result_data = cast(NodeRunEndT, text) else: - return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), []) - else: - ctx.state.increment_retries(ctx.deps.max_result_retries) - return ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest( - parts=[ - _messages.RetryPromptPart( - content='Plain text responses are not permitted, please include your response in a tool call', - ) - ] + m = _messages.RetryPromptPart( + content='Plain text responses are not permitted, please include your response in a tool call', ) - ) + raise _output.ToolRetryError(m) + + result_data = await _validate_output(result_data, ctx, None) + except _output.ToolRetryError as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) + else: + return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), []) def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: @@ -795,11 +792,6 @@ async def _validate_output( return result_data -def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool: - """Check if the result schema allows text results.""" - return output_schema is None or output_schema.allow_text_output - - @dataclasses.dataclass class _RunMessages: messages: list[_messages.ModelMessage] @@ -849,7 +841,9 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( - name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT] + name: str | None, + deps_type: type[DepsT], + output_type: _output.OutputType[OutputT], ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py similarity index 83% rename from pydantic_ai_slim/pydantic_ai/_pydantic.py rename to pydantic_ai_slim/pydantic_ai/_function_schema.py index f439357469..facca89aa6 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -5,6 +5,9 @@ from __future__ import annotations as _annotations +import inspect +from collections.abc import Awaitable +from dataclasses import dataclass from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Callable, cast @@ -15,10 +18,12 @@ from pydantic.json_schema import GenerateJsonSchema from pydantic.plugin._schema_validator import create_schema_validator from pydantic_core import SchemaValidator, core_schema -from typing_extensions import TypedDict, get_origin +from typing_extensions import get_origin + +from pydantic_ai.tools import RunContext from ._griffe import doc_descriptions -from ._utils import check_object_json_schema, is_model_like +from ._utils import check_object_json_schema, is_model_like, run_in_executor if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema @@ -27,24 +32,53 @@ __all__ = ('function_schema',) -class FunctionSchema(TypedDict): +@dataclass +class FunctionSchema: """Internal information about a function schema.""" + function: Callable[..., Any] description: str validator: SchemaValidator json_schema: ObjectJsonSchema # if not None, the function takes a single by that name (besides potentially `info`) + takes_ctx: bool + is_async: bool single_arg_name: str | None positional_fields: list[str] var_positional_field: str | None + async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any: + args, kwargs = self._call_args(args_dict, ctx) + if self.is_async: + function = cast(Callable[[Any], Awaitable[str]], self.function) + return await function(*args, **kwargs) + else: + function = cast(Callable[[Any], str], self.function) + return await run_in_executor(function, *args, **kwargs) + + def _call_args( + self, + args_dict: dict[str, Any], + ctx: RunContext[Any], + ) -> tuple[list[Any], dict[str, Any]]: + if self.single_arg_name: + args_dict = {self.single_arg_name: args_dict} + + args = [ctx] if self.takes_ctx else [] + for positional_field in self.positional_fields: + args.append(args_dict.pop(positional_field)) # pragma: no cover + if self.var_positional_field: + args.extend(args_dict.pop(self.var_positional_field)) + + return args, args_dict + def function_schema( # noqa: C901 function: Callable[..., Any], - takes_ctx: bool, - docstring_format: DocstringFormat, - require_parameter_descriptions: bool, schema_generator: type[GenerateJsonSchema], + takes_ctx: bool | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, ) -> FunctionSchema: """Build a Pydantic validator and JSON schema from a tool function. @@ -58,6 +92,9 @@ def function_schema( # noqa: C901 Returns: A `FunctionSchema` instance. """ + if takes_ctx is None: + takes_ctx = _takes_ctx(function) + config = ConfigDict(title=function.__name__, use_attribute_docstrings=True) config_wrapper = ConfigWrapper(config) gen_schema = _generate_schema.GenerateSchema(config_wrapper) @@ -176,10 +213,13 @@ def function_schema( # noqa: C901 single_arg_name=single_arg_name, positional_fields=positional_fields, var_positional_field=var_positional_field, + takes_ctx=takes_ctx, + is_async=inspect.iscoroutinefunction(function), + function=function, ) -def takes_ctx(function: Callable[..., Any]) -> bool: +def _takes_ctx(function: Callable[..., Any]) -> bool: """Check if a function takes a `RunContext` first argument. Args: @@ -196,7 +236,7 @@ def takes_ctx(function: Callable[..., Any]) -> bool: else: type_hints = _typing_extra.get_function_type_hints(function) annotation = type_hints[first_param_name] - return annotation is not sig.empty and _is_call_ctx(annotation) + return True is not sig.empty and _is_call_ctx(annotation) def _build_schema( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 1641bf6981..cfb0f8f0cb 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,22 +1,55 @@ from __future__ import annotations as _annotations import inspect -from collections.abc import Awaitable, Iterable, Iterator +from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, Union, cast from pydantic import TypeAdapter, ValidationError -from typing_extensions import TypedDict, TypeVar, get_args, get_origin +from pydantic_core import SchemaValidator +from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin -from . import _utils, messages as _messages +from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry -from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput -from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition +from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition T = TypeVar('T') """An invariant TypeVar.""" +OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) +""" +An invariant type variable for the result data of a model. + +We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used +in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types +possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and +changing it would have negative consequences for the ergonomics of the library. + +At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would +resolve these potential variance issues. +""" +OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) +"""Covariant type variable for the result data type of a run.""" + +OutputValidatorFunc = Union[ + Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv], + Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]], + Callable[[OutputDataT_inv], OutputDataT_inv], + Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]], +] +""" +A function that always takes and returns the same type of data (which is the result type of an agent run), and: + +* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument +* may or may not be async + +Usage `OutputValidatorFunc[AgentDepsT, T]`. +""" + + +DEFAULT_OUTPUT_TOOL_NAME = 'final_result' +DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' @dataclass @@ -76,69 +109,135 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): super().__init__() +@dataclass(init=False) +class ToolOutput(Generic[OutputDataT]): + """Marker class to use tools for outputs, and customize the tool.""" + + output_type: SimpleOutputType[OutputDataT] + name: str | None + description: str | None + max_retries: int | None + strict: bool | None + + def __init__( + self, + type_: SimpleOutputType[OutputDataT], + *, + name: str | None = None, + description: str | None = None, + max_retries: int | None = None, + strict: bool | None = None, + ): + self.output_type = type_ + self.name = name + self.description = description + self.max_retries = max_retries + self.strict = strict + + +T_co = TypeVar('T_co', covariant=True) +# output_type=Type or output_type=function or output_type=object.method +SimpleOutputType = TypeAliasType( + 'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,) +) +# output_type=ToolOutput() or +SimpleOutputTypeOrMarker = TypeAliasType( + 'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,) +) +# output_type= or [, ...] +OutputType = TypeAliasType( + 'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,) +) + + @dataclass class OutputSchema(Generic[OutputDataT]): - """Model the final response from an agent run. + """Model the final output from an agent run. Similar to `Tool` but for the final output of running an agent. """ - tools: dict[str, OutputSchemaTool[OutputDataT]] + tools: dict[str, OutputTool[OutputDataT]] allow_text_output: bool @classmethod def build( - cls: type[OutputSchema[T]], - output_type: type[T] | ToolOutput[T], + cls: type[OutputSchema[OutputDataT]], + output_type: OutputType[OutputDataT], name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> OutputSchema[T] | None: - """Build an OutputSchema dataclass from a response type.""" + ) -> OutputSchema[OutputDataT] | None: + """Build an OutputSchema dataclass from an output type.""" if output_type is str: return None - if isinstance(output_type, ToolOutput): - # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads - name = output_type.name - description = output_type.description - output_type_ = output_type.output_type - strict = output_type.strict + output_types: Sequence[SimpleOutputTypeOrMarker[OutputDataT]] + if isinstance(output_type, Sequence): + output_types = output_type else: - output_type_ = output_type + output_types = (output_type,) - if output_type_option := extract_str_from_union(output_type): - output_type_ = output_type_option.value + output_types_flat: list[SimpleOutputTypeOrMarker[OutputDataT]] = [] + for output_type in output_types: + if union_types := get_union_args(output_type): + output_types_flat.extend(union_types) + else: + output_types_flat.append(output_type) + + allow_text_output = False + if str in output_types_flat: allow_text_output = True - else: - allow_text_output = False - - tools: dict[str, OutputSchemaTool[T]] = {} - if args := get_union_args(output_type_): - for i, arg in enumerate(args, start=1): - tool_name = raw_tool_name = union_tool_name(name, arg) - while tool_name in tools: - tool_name = f'{raw_tool_name}_{i}' - tools[tool_name] = cast( - OutputSchemaTool[T], - OutputSchemaTool( - output_type=arg, name=tool_name, description=description, multiple=True, strict=strict - ), - ) - else: - name = name or DEFAULT_OUTPUT_TOOL_NAME - tools[name] = cast( - OutputSchemaTool[T], - OutputSchemaTool( - output_type=output_type_, name=name, description=description, multiple=False, strict=strict - ), + output_types_flat = [t for t in output_types_flat if t is not str] + + multiple = len(output_types_flat) > 1 + + default_tool_name = name or DEFAULT_OUTPUT_TOOL_NAME + default_tool_description = description + default_tool_strict = strict + + tools: dict[str, OutputTool[OutputDataT]] = {} + for output_type in output_types_flat: + tool_name = None + tool_description = None + tool_strict = None + if isinstance(output_type, ToolOutput): + tool_output_type = output_type.output_type + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads + tool_name = output_type.name + tool_description = output_type.description + tool_strict = output_type.strict + else: + tool_output_type = output_type + + if tool_name is None: + tool_name = default_tool_name + if multiple: + tool_name += f'_{tool_output_type.__name__}' + + i = 1 + original_tool_name = tool_name + while tool_name in tools: + i += 1 + tool_name = f'{original_tool_name}_{i}' + + tool_description = tool_description or default_tool_description + if tool_strict is None: + tool_strict = default_tool_strict + + parameters_schema = OutputObjectSchema( + output_type=tool_output_type, description=tool_description, strict=tool_strict ) + tools[tool_name] = OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple) - return cls(tools=tools, allow_text_output=allow_text_output) + return cls( + tools=tools, + allow_text_output=allow_text_output, + ) def find_named_tool( self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None: + ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: """Find a tool that matches one of the calls, with a specific name.""" for part in parts: # pragma: no branch if isinstance(part, _messages.ToolCallPart): # pragma: no branch @@ -148,7 +247,7 @@ def find_named_tool( def find_tool( self, parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]: + ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: """Find a tool that matches one of the calls.""" for part in parts: if isinstance(part, _messages.ToolCallPart): # pragma: no branch @@ -164,64 +263,138 @@ def tool_defs(self) -> list[ToolDefinition]: return [t.tool_def for t in self.tools.values()] -DEFAULT_DESCRIPTION = 'The final response which ends this conversation' +def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool: + return output_schema is None or output_schema.allow_text_output + + +@dataclass +class OutputObjectDefinition: + name: str + json_schema: ObjectJsonSchema + description: str | None = None + strict: bool | None = None @dataclass(init=False) -class OutputSchemaTool(Generic[OutputDataT]): - tool_def: ToolDefinition - type_adapter: TypeAdapter[Any] +class OutputObjectSchema(Generic[OutputDataT]): + definition: OutputObjectDefinition + validator: SchemaValidator + function_schema: _function_schema.FunctionSchema | None = None + outer_typed_dict_key: str | None = None def __init__( - self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None + self, + *, + output_type: SimpleOutputType[OutputDataT], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, ): - """Build a OutputSchemaTool from a response type.""" - if _utils.is_model_like(output_type): - self.type_adapter = TypeAdapter(output_type) - outer_typed_dict_key: str | None = None - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) + if inspect.isfunction(output_type) or inspect.ismethod(output_type): + self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) + self.validator = self.function_schema.validator + json_schema = self.function_schema.json_schema + json_schema['description'] = self.function_schema.description else: - response_data_typed_dict = TypedDict( # noqa: UP013 - 'response_data_typed_dict', - {'response': output_type}, # pyright: ignore[reportInvalidTypeForm] - ) - self.type_adapter = TypeAdapter(response_data_typed_dict) - outer_typed_dict_key = 'response' - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) + type_adapter: TypeAdapter[Any] + if _utils.is_model_like(output_type): + type_adapter = TypeAdapter(output_type) + else: + self.outer_typed_dict_key = 'response' + response_data_typed_dict = TypedDict( # noqa: UP013 + 'response_data_typed_dict', + {'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm] + ) + type_adapter = TypeAdapter(response_data_typed_dict) + + # Really a PluggableSchemaValidator, but it's API-compatible + self.validator = cast(SchemaValidator, type_adapter.validator) + json_schema = _utils.check_object_json_schema( + type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) - # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM - parameters_json_schema.pop('title') - if json_schema_description := parameters_json_schema.pop('description', None): + if self.outer_typed_dict_key: + # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM + json_schema.pop('title') + + if json_schema_description := json_schema.pop('description', None): if description is None: - tool_description = json_schema_description + description = json_schema_description else: - tool_description = f'{description}. {json_schema_description}' # pragma: no cover + description = f'{description}. {json_schema_description}' + + self.definition = OutputObjectDefinition( + name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME), + description=description, + json_schema=json_schema, + strict=strict, + ) + + async def process( + self, + data: str | dict[str, Any] | None, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + ) -> OutputDataT: + """Process an output message, performing validation and (if necessary) calling the output function. + + Args: + data: The output data to validate. + run_context: The current run context. + allow_partial: If true, allow partial validation. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + output = self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) else: - tool_description = description or DEFAULT_DESCRIPTION + output = self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + + if self.function_schema: + output = await self.function_schema.call(output, run_context) + + if k := self.outer_typed_dict_key: + output = output[k] + return output + + +@dataclass(init=False) +class OutputTool(Generic[OutputDataT]): + parameters_schema: OutputObjectSchema[OutputDataT] + tool_def: ToolDefinition + + def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool): + self.parameters_schema = parameters_schema + definition = parameters_schema.definition + + description = definition.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION if multiple: - tool_description = f'{union_arg_name(output_type)}: {tool_description}' + description = f'{definition.name}: {description}' self.tool_def = ToolDefinition( name=name, - description=tool_description, - parameters_json_schema=parameters_json_schema, - outer_typed_dict_key=outer_typed_dict_key, - strict=strict, + description=description, + parameters_json_schema=definition.json_schema, + strict=definition.strict, + outer_typed_dict_key=parameters_schema.outer_typed_dict_key, ) - def validate( - self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True + async def process( + self, + tool_call: _messages.ToolCallPart, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, ) -> OutputDataT: - """Validate an output message. + """Process an output message. Args: tool_call: The tool call from the LLM to validate. + run_context: The current run context. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -229,15 +402,7 @@ def validate( Either the validated output data (left) or a retry message (right). """ try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(tool_call.args, str): - output = self.type_adapter.validate_json( - tool_call.args or '{}', experimental_allow_partial=pyd_allow_partial - ) - else: - output = self.type_adapter.validate_python( - tool_call.args or {}, experimental_allow_partial=pyd_allow_partial - ) + output = await self.parameters_schema.process(tool_call.args, run_context, allow_partial=allow_partial) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -248,38 +413,20 @@ def validate( raise ToolRetryError(m) from e else: raise # pragma: lax no cover + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + tool_name=tool_call.tool_name, + content=r.message, + tool_call_id=tool_call.tool_call_id, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: lax no cover else: - if k := self.tool_def.outer_typed_dict_key: - output = output[k] return output -def union_tool_name(base_name: str | None, union_arg: Any) -> str: - return f'{base_name or DEFAULT_OUTPUT_TOOL_NAME}_{union_arg_name(union_arg)}' - - -def union_arg_name(union_arg: Any) -> str: - return union_arg.__name__ - - -def extract_str_from_union(output_type: Any) -> _utils.Option[Any]: - """Extract the string type from a Union, return the remaining union or remaining type.""" - union_args = get_union_args(output_type) - if any(t is str for t in union_args): - remain_args: list[Any] = [] - includes_str = False - for arg in union_args: - if arg is str: - includes_str = True - else: - remain_args.append(arg) - if includes_str: # pragma: no branch - if len(remain_args) == 1: - return _utils.Some(remain_args[0]) - else: - return _utils.Some(Union[tuple(remain_args)]) # pragma: no cover - - def get_union_args(tp: Any) -> tuple[Any, ...]: """Extract the arguments of a Union type if `output_type` is a union, otherwise return an empty tuple.""" if typing_objects.is_typealiastype(tp): diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 749652d5c9..57010f5f8e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -29,7 +29,7 @@ usage as _usage, ) from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model -from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput +from .result import FinalResult, OutputDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -127,7 +127,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ - output_type: type[OutputDataT] | ToolOutput[OutputDataT] + output_type: _output.OutputType[OutputDataT] """ The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. """ @@ -162,7 +162,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str, + output_type: _output.OutputType[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] @@ -199,7 +199,7 @@ def __init__( name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, - result_tool_name: str = 'final_result', + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, result_tool_description: str | None = None, result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), @@ -214,7 +214,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - # TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads + # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads output_type: Any = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] @@ -374,7 +374,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -404,7 +404,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -492,7 +492,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -524,7 +524,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -770,7 +770,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -800,7 +800,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -883,7 +883,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent], *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -914,7 +914,7 @@ async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -994,7 +994,7 @@ async def stream_to_final( if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, _messages.TextPart): - if _agent_graph.allow_text_output(output_schema): + if _output.allow_text_output(output_schema): return FinalResult(s, None, None) elif isinstance(new_part, _messages.ToolCallPart) and output_schema: for call, _ in output_schema.find_tool([new_part]): @@ -1628,7 +1628,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( - self, output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None + self, output_type: _output.OutputType[RunOutputDataT] | None ) -> _output.OutputSchema[RunOutputDataT] | None: if output_type is not None: if self._output_validators: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 96c2f22d8b..443e98b328 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,100 +5,35 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Union, cast +from typing import Generic, cast from typing_extensions import TypeVar, assert_type, deprecated, overload -from . import _utils, exceptions, messages as _messages, models +from . import _output, _utils, exceptions, messages as _messages, models +from ._output import ( + OutputDataT, + OutputDataT_inv, + OutputSchema, + OutputValidator, + OutputValidatorFunc, + ToolOutput, +) from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -if TYPE_CHECKING: - from . import _output - __all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'OutputValidatorFunc' T = TypeVar('T') """An invariant TypeVar.""" -OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) -""" -An invariant type variable for the result data of a model. - -We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used -in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types -possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and -changing it would have negative consequences for the ergonomics of the library. - -At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would -resolve these potential variance issues. -""" -OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) -"""Covariant type variable for the result data type of a run.""" - -OutputValidatorFunc = Union[ - Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv], - Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]], - Callable[[OutputDataT_inv], OutputDataT_inv], - Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]], -] -""" -A function that always takes and returns the same type of data (which is the result type of an agent run), and: - -* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument -* may or may not be async - -Usage `OutputValidatorFunc[AgentDepsT, T]`. -""" - -DEFAULT_OUTPUT_TOOL_NAME = 'final_result' - - -@dataclass(init=False) -class ToolOutput(Generic[OutputDataT]): - """Marker class to use tools for structured outputs, and customize the tool.""" - - output_type: type[OutputDataT] - # TODO: Add `output_call` support, for calling a function to get the output - # output_call: Callable[..., OutputDataT] | None - name: str - description: str | None - max_retries: int | None - strict: bool | None - - def __init__( - self, - *, - type_: type[OutputDataT], - # call: Callable[..., OutputDataT] | None = None, - name: str = 'final_result', - description: str | None = None, - max_retries: int | None = None, - strict: bool | None = None, - ): - self.output_type = type_ - self.name = name - self.description = description - self.max_retries = max_retries - self.strict = strict - - # TODO: add support for call and make type_ optional, with the following logic: - # if type_ is None and call is None: - # raise ValueError('Either type_ or call must be provided') - # if call is not None: - # if type_ is None: - # type_ = get_type_hints(call).get('return') - # if type_ is None: - # raise ValueError('Unable to determine type_ from call signature; please provide it explicitly') - # self.output_call = call @dataclass class AgentStream(Generic[AgentDepsT, OutputDataT]): _raw_stream_response: models.StreamedResponse - _output_schema: _output.OutputSchema[OutputDataT] | None - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] + _output_schema: OutputSchema[OutputDataT] | None + _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None @@ -144,6 +79,7 @@ async def _validate_response( self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" + call = None if self._output_schema is not None and output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, output_tool_name) if match is None: @@ -152,21 +88,17 @@ async def _validate_response( ) call, output_tool = match - result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) - - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data + result_data = await output_tool.process( + call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - for validator in self._output_validators: - text = await validator.validate( - text, - None, - self._run_ctx, - ) - # Since there is no output tool, we can assume that str is compatible with OutputDataT - return cast(OutputDataT, text) + # The following cast is safe because we know `str` is an allowed output type + result_data = cast(OutputDataT, text) + + for validator in self._output_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) + return result_data def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -180,7 +112,6 @@ def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: async def aiter(): output_schema = self._output_schema - allow_text_output = output_schema is None or output_schema.allow_text_output def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None: """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" @@ -192,7 +123,7 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. return _messages.FinalResultEvent( tool_name=call.tool_name, tool_call_id=call.tool_call_id ) - elif allow_text_output: # pragma: no branch + elif _output.allow_text_output(output_schema): # pragma: no branch assert_type(e, _messages.PartStartEvent) return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) @@ -224,9 +155,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _usage_limits: UsageLimits | None _stream_response: models.StreamedResponse - _output_schema: _output.OutputSchema[OutputDataT] | None + _output_schema: OutputSchema[OutputDataT] | None _run_ctx: RunContext[AgentDepsT] - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] + _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] @@ -458,6 +389,7 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" + call = None if self._output_schema is not None and self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: @@ -466,17 +398,16 @@ async def validate_structured_output( ) call, output_tool = match - result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) - - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover - return result_data + result_data = await output_tool.process( + call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - for validator in self._output_validators: - text = await validator.validate(text, None, self._run_ctx) # pragma: no cover - # Since there is no output tool, we can assume that str is compatible with OutputDataT - return cast(OutputDataT, text) + result_data = cast(OutputDataT, text) + + for validator in self._output_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover + return result_data async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 7d174ba7d6..db232ca75f 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,22 +1,22 @@ from __future__ import annotations as _annotations import dataclasses -import inspect import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union from opentelemetry.trace import Tracer from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue -from pydantic_core import SchemaValidator, core_schema +from pydantic_core import core_schema from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar -from . import _pydantic, _utils, messages as _messages, models +from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry, UnexpectedModelBehavior if TYPE_CHECKING: + from .models import Model from .result import Usage __all__ = ( @@ -45,7 +45,7 @@ class RunContext(Generic[AgentDepsT]): deps: AgentDepsT """Dependencies for the agent.""" - model: models.Model + model: Model """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" @@ -208,12 +208,7 @@ class Tool(Generic[AgentDepsT]): docstring_format: DocstringFormat require_parameter_descriptions: bool strict: bool | None - _is_async: bool = field(init=False) - _single_arg_name: str | None = field(init=False) - _positional_fields: list[str] = field(init=False) - _var_positional_field: str | None = field(init=False) - _validator: SchemaValidator = field(init=False, repr=False) - _base_parameters_json_schema: ObjectJsonSchema = field(init=False) + function_schema: _function_schema.FunctionSchema """ The base JSON schema for the tool's parameters. @@ -237,6 +232,7 @@ def __init__( require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, + function_schema: _function_schema.FunctionSchema | None = None, ): """Create a new tool instance. @@ -289,28 +285,24 @@ async def prep_my_tool( schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + function_schema: The function schema to use for the tool. If not provided, it will be generated. """ - if takes_ctx is None: - takes_ctx = _pydantic.takes_ctx(function) - - f = _pydantic.function_schema( - function, takes_ctx, docstring_format, require_parameter_descriptions, schema_generator - ) self.function = function - self.takes_ctx = takes_ctx + self.function_schema = function_schema or _function_schema.function_schema( + function, + schema_generator, + takes_ctx=takes_ctx, + docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions, + ) + self.takes_ctx = self.function_schema.takes_ctx self.max_retries = max_retries self.name = name or function.__name__ - self.description = description or f['description'] + self.description = description or self.function_schema.description self.prepare = prepare self.docstring_format = docstring_format self.require_parameter_descriptions = require_parameter_descriptions self.strict = strict - self._is_async = inspect.iscoroutinefunction(self.function) - self._single_arg_name = f['single_arg_name'] - self._positional_fields = f['positional_fields'] - self._var_positional_field = f['var_positional_field'] - self._validator = f['validator'] - self._base_parameters_json_schema = f['json_schema'] async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. @@ -324,7 +316,7 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition tool_def = ToolDefinition( name=self.name, description=self.description, - parameters_json_schema=self._base_parameters_json_schema, + parameters_json_schema=self.function_schema.json_schema, strict=self.strict, ) if self.prepare is not None: @@ -366,21 +358,22 @@ async def _run( self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT] ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: try: + validator = self.function_schema.validator if isinstance(message.args, str): - args_dict = self._validator.validate_json(message.args or '{}') + args_dict = validator.validate_json(message.args or '{}') else: - args_dict = self._validator.validate_python(message.args or {}) + args_dict = validator.validate_python(message.args or {}) except ValidationError as e: return self._on_error(e, message) - args, kwargs = self._call_args(args_dict, message, run_context) + ctx = dataclasses.replace( + run_context, + retry=self.current_retry, + tool_name=message.tool_name, + tool_call_id=message.tool_call_id, + ) try: - if self._is_async: - function = cast(Callable[[Any], Awaitable[str]], self.function) - response_content = await function(*args, **kwargs) - else: - function = cast(Callable[[Any], str], self.function) - response_content = await _utils.run_in_executor(function, *args, **kwargs) + response_content = await self.function_schema.call(args_dict, ctx) except ModelRetry as e: return self._on_error(e, message) @@ -391,29 +384,6 @@ async def _run( tool_call_id=message.tool_call_id, ) - def _call_args( - self, - args_dict: dict[str, Any], - message: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - ) -> tuple[list[Any], dict[str, Any]]: - if self._single_arg_name: - args_dict = {self._single_arg_name: args_dict} - - ctx = dataclasses.replace( - run_context, - retry=self.current_retry, - tool_name=message.tool_name, - tool_call_id=message.tool_call_id, - ) - args = [ctx] if self.takes_ctx else [] - for positional_field in self._positional_fields: - args.append(args_dict.pop(positional_field)) # pragma: no cover - if self._var_positional_field: - args.extend(args_dict.pop(self._var_positional_field)) - - return args, args_dict - def _on_error( self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart ) -> _messages.RetryPromptPart: diff --git a/tests/test_agent.py b/tests/test_agent.py index 215ae4bf95..68804fe52e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -10,6 +10,7 @@ from inline_snapshot import snapshot from pydantic import BaseModel, TypeAdapter, field_validator from pydantic_core import to_json +from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai.agent import AgentRunResult @@ -29,7 +30,7 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import Usage +from pydantic_ai.result import ToolOutput, Usage from pydantic_ai.tools import ToolDefinition from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -389,8 +390,8 @@ def test_response_tuple(): @pytest.mark.parametrize( 'input_union_callable', - [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str], - ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'], + [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str, lambda: [Foo, str]], + ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str', '[Foo, str]'], ) def test_response_union_allow_str(input_union_callable: Callable[[], Any]): try: @@ -446,6 +447,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'union_code', [ pytest.param('OutputType = Union[Foo, Bar]'), + pytest.param('OutputType = [Foo, Bar]'), pytest.param('OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')), pytest.param( 'OutputType: TypeAlias = Foo | Bar', @@ -461,6 +463,7 @@ def test_response_multiple_return_tools(create_module: Callable[[str], Any], uni from pydantic import BaseModel from typing import Union from typing_extensions import TypeAlias +from pydantic_ai import ToolOutput class Foo(BaseModel): a: int @@ -531,6 +534,601 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert got_tool_call_name == snapshot('final_result_Bar') +def test_output_type_with_two_descriptions(): + class MyOutput(BaseModel): + """Description from docstring""" + + valid: bool + + m = TestModel() + agent = Agent(m, output_type=ToolOutput(MyOutput, description='Description from ToolOutput')) + result = agent.run_sync('Hello') + assert result.output == snapshot(MyOutput(valid=False)) + assert m.last_model_request_parameters is not None + assert m.last_model_request_parameters.output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='Description from ToolOutput. Description from docstring', + parameters_json_schema={ + 'properties': {'valid': {'type': 'boolean'}}, + 'required': ['valid'], + 'title': 'MyOutput', + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_tool_output_union(): + class Foo(BaseModel): + a: int + b: str + + class Bar(BaseModel): + c: bool + + m = TestModel() + marker: ToolOutput[Union[Foo, Bar]] = ToolOutput(Union[Foo, Bar], strict=False) # type: ignore + agent = Agent(m, output_type=marker) + result = agent.run_sync('Hello') + assert result.output == snapshot(Foo(a=0, b='a')) + assert m.last_model_request_parameters is not None + assert m.last_model_request_parameters.output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + '$defs': { + 'Bar': { + 'properties': {'c': {'type': 'boolean'}}, + 'required': ['c'], + 'title': 'Bar', + 'type': 'object', + }, + 'Foo': { + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}}, + 'required': ['a', 'b'], + 'title': 'Foo', + 'type': 'object', + }, + }, + 'additionalProperties': False, + 'properties': {'response': {'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}]}}, + 'required': ['response'], + 'type': 'object', + }, + outer_typed_dict_key='response', + strict=False, + ) + ] + ) + + +def test_output_type_function(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_function_with_run_context(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(ctx: RunContext[None], city: str) -> Weather: + assert ctx is not None + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_bound_instance_method(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, city: str) -> Self: + return self + + weather = Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=weather.get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_bound_instance_method_with_run_context(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, ctx: RunContext[None], city: str) -> Self: + assert ctx is not None + return self + + weather = Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=weather.get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_function_with_retry(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + if city != 'Mexico City': + raise ModelRetry('City not found, I only know Mexico City') + return Weather(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + args_json = '{"city": "New York City"}' + else: + args_json = '{"city": "Mexico City"}' + + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('New York City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='New York City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "New York City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=53, response_tokens=7, total_tokens=60), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='City not found, I only know Mexico City', + tool_name='final_result', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=68, response_tokens=13, total_tokens=81), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +def test_output_type_async_function(): + class Weather(BaseModel): + temperature: float + description: str + + async def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_function_with_custom_tool_name(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=ToolOutput(get_weather, name='get_weather')) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='get_weather', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_function_or_model(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=[get_weather, Weather]) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result_get_weather', + description='get_weather: The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ), + ToolDefinition( + name='final_result_Weather', + description='Weather: The final response which ends this conversation', + parameters_json_schema={ + 'properties': {'temperature': {'type': 'number'}, 'description': {'type': 'string'}}, + 'required': ['temperature', 'description'], + 'title': 'Weather', + 'type': 'object', + }, + ), + ] + ) + + +def test_output_type_handoff_to_agent(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + + handoff_result = None + + async def handoff(city: str) -> Weather: + result = await agent.run(f'Get me the weather in {city}') + nonlocal handoff_result + handoff_result = result + return result.output + + def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + supervisor_agent = Agent(FunctionModel(call_handoff_tool), output_type=handoff) + + result = supervisor_agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Mexico City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=52, response_tokens=6, total_tokens=58), + model_name='function:call_handoff_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + assert handoff_result is not None + assert handoff_result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Get me the weather in Mexico City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=57, response_tokens=6, total_tokens=63), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +def test_output_type_multiple_custom_tools(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent( + FunctionModel(call_tool), + output_type=[ + ToolOutput(get_weather, name='get_weather'), + ToolOutput(Weather, name='return_weather'), + ], + ) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='get_weather', + description='get_weather: The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ), + ToolDefinition( + name='return_weather', + description='Weather: The final response which ends this conversation', + parameters_json_schema={ + 'properties': {'temperature': {'type': 'number'}, 'description': {'type': 'string'}}, + 'required': ['temperature', 'description'], + 'title': 'Weather', + 'type': 'object', + }, + ), + ] + ) + + def test_run_with_history_new(): m = TestModel() diff --git a/tests/test_examples.py b/tests/test_examples.py index ff74ff80cb..ad377bedbf 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -72,8 +72,12 @@ def find_filter_examples() -> Iterable[ParameterSet]: for ex in find_examples('docs', 'pydantic_ai_slim', 'pydantic_graph', 'pydantic_evals'): if ex.path.name != '_utils.py': + try: + path = ex.path.relative_to(Path.cwd()) + except ValueError: + path = ex.path + test_id = f'{path}:{ex.start_line}' prefix_settings = ex.prefix_settings() - test_id = str(ex) if opt_title := prefix_settings.get('title'): test_id += f':{opt_title}' yield pytest.param(ex, id=test_id) @@ -401,6 +405,32 @@ async def list_tools() -> list[None]: args={'numerator': '123', 'denominator': '456'}, tool_call_id='pyd_ai_2e0e396768a14fe482df90a29a78dc7b', ), + 'Select the names and countries of all capitals': ToolCallPart( + tool_name='final_result_hand_off_to_sql_agent', + args={'query': 'SELECT name, country FROM capitals;'}, + ), + 'SELECT name, country FROM capitals;': ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT name, country FROM capitals;'}, + ), + 'SELECT * FROM capital_cities;': ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT * FROM capital_cities;'}, + ), + 'Select all pets': ToolCallPart( + tool_name='final_result_hand_off_to_sql_agent', + args={'query': 'SELECT * FROM pets;'}, + ), + 'SELECT * FROM pets;': ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT * FROM pets;'}, + ), + 'How do I fly from Amsterdam to Mexico City?': ToolCallPart( + tool_name='final_result_RouterFailure', + args={ + 'explanation': 'I am not equipped to provide travel information, such as flights from Amsterdam to Mexico City.' + }, + ), } tool_responses: dict[tuple[str, str], str] = { @@ -582,6 +612,69 @@ async def model_logic( # noqa: C901 return ModelResponse( parts=[ToolCallPart(tool_name='get_document', args={}, tool_call_id='pyd_ai_tool_call_id')] ) + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_run_sql_query' + and m.content == "Only 'SELECT *' is supported, you'll have to do column filtering manually." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT * FROM capitals;'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_hand_off_to_sql_agent' + and m.content + == "SQL agent failed: Unknown table 'capitals' in query 'SELECT * FROM capitals;'. Available tables: capital_cities." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_hand_off_to_sql_agent', + args={'query': 'SELECT * FROM capital_cities;'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_run_sql_query' + and m.content == "Unknown table 'pets' in query 'SELECT * FROM pets;'. Available tables: capital_cities." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_SQLFailure', + args={ + 'explanation': "The table 'pets' does not exist in the database. Only the table 'capital_cities' is available." + }, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) + # SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available. + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_hand_off_to_sql_agent' + and m.content + == "SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_RouterFailure', + args={ + 'explanation': "The requested table 'pets' does not exist in the database. The only available table is 'capital_cities', which does not contain data about pets." + }, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') diff --git a/tests/test_tools.py b/tests/test_tools.py index ec2d9cfaf4..218c564e44 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -572,7 +572,7 @@ def test_tool_return_conflict(): Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) # this raises an error with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"): - Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(type_=int, name='ctx_tool')) + Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) def test_init_ctx_tool_invalid(): diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 676cb34229..180ce2b0dc 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -1,16 +1,19 @@ """This file is used to test static typing, it's analyzed with pyright and mypy.""" -from collections.abc import Awaitable, Iterator -from contextlib import contextmanager +from collections.abc import Awaitable from dataclasses import dataclass from typing import Callable, TypeAlias, Union from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool +from pydantic_ai._output import ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.tools import ToolDefinition +# Define here so we can check `if MYPY` below. This will not be executed, MYPY will always set it to True +MYPY = False + @dataclass class MyDeps: @@ -37,16 +40,6 @@ def system_prompt_ok2() -> str: assert_type(system_prompt_ok2, Callable[[], str]) -@contextmanager -def expect_error(error_type: type[Exception]) -> Iterator[None]: - try: - yield None - except Exception as e: - assert isinstance(e, error_type), f'Expected {error_type}, got {type(e)}' - else: - raise AssertionError('Expected an error') - - @typed_agent.tool async def ok_tool(ctx: RunContext[MyDeps], x: str) -> str: assert_type(ctx.deps, MyDeps) @@ -108,13 +101,6 @@ async def bad_tool2(ctx: RunContext[int], x: str) -> str: return f'{x} {ctx.deps}' -with expect_error(ValueError): - - @typed_agent.tool # type: ignore[arg-type] - async def bad_tool3(x: str) -> str: - return x - - @typed_agent.output_validator def ok_validator_simple(data: str) -> str: return data @@ -187,8 +173,51 @@ def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str: return f'{x} {y}' -def foobar_plain(x: str, y: int) -> str: - return f'{x} {y}' +async def foobar_plain(x: int, y: int) -> int: + return x * y + + +class MyClass: + def my_method(self) -> bool: + return True + + +str_function_agent = Agent(output_type=foobar_ctx) +assert_type(str_function_agent, Agent[None, str]) + +bool_method_agent = Agent(output_type=MyClass().my_method) +assert_type(bool_method_agent, Agent[None, bool]) + +if MYPY: + # mypy requires the generic parameters to be specified explicitly to be happy here + async_int_function_agent = Agent[None, int](output_type=foobar_plain) + assert_type(async_int_function_agent, Agent[None, int]) + + two_models_output_agent = Agent[None, Foo | Bar](output_type=[Foo, Bar]) + assert_type(two_models_output_agent, Agent[None, Foo | Bar]) + + two_scalars_output_agent = Agent[None, int | str](output_type=[int, str]) + assert_type(two_scalars_output_agent, Agent[None, int | str]) + + marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore + complex_output_agent = Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]( + output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker] + ) + assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) +else: + # pyright is able to correctly infer the type here + async_int_function_agent = Agent(output_type=foobar_plain) + assert_type(async_int_function_agent, Agent[None, int]) + + two_models_output_agent = Agent(output_type=[Foo, Bar]) + assert_type(two_models_output_agent, Agent[None, Foo | Bar]) + + two_scalars_output_agent = Agent(output_type=[int, str]) + assert_type(two_scalars_output_agent, Agent[None, int | str]) + + marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore + complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker]) + assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) Tool(foobar_ctx, takes_ctx=True) @@ -235,7 +264,6 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD result = greet_agent.run_sync('testing...', deps='human') assert result.output == '{"greet":"hello a"}' -MYPY = False if not MYPY: default_agent = Agent() assert_type(default_agent, Agent[None, str])