diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index b75ea82268..73d3af0c9a 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -251,6 +251,8 @@ async def _prepare_request_parameters( ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" run_context = build_run_context(ctx) + + # This will raise errors for any tool name conflicts ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) output_schema = ctx.deps.output_schema diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index 085b5da46f..47a316997c 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field, replace from typing import Any, Generic +from opentelemetry.trace import Tracer from pydantic import ValidationError from typing_extensions import assert_never @@ -21,41 +22,46 @@ class ToolManager(Generic[AgentDepsT]): """Manages tools for an agent run step. It caches the agent run's toolset's tool definitions and handles calling tools and retries.""" - ctx: RunContext[AgentDepsT] - """The agent run context for a specific run step.""" toolset: AbstractToolset[AgentDepsT] """The toolset that provides the tools for this run step.""" - tools: dict[str, ToolsetTool[AgentDepsT]] + ctx: RunContext[AgentDepsT] | None = None + """The agent run context for a specific run step.""" + tools: dict[str, ToolsetTool[AgentDepsT]] | None = None """The cached tools for this run step.""" failed_tools: set[str] = field(default_factory=set) """Names of tools that failed in this run step.""" - @classmethod - async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]: - """Build a new tool manager for a specific run step.""" - return cls( - ctx=ctx, - toolset=toolset, - tools=await toolset.get_tools(ctx), - ) - async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]: """Build a new tool manager for the next run step, carrying over the retries from the current run step.""" - if ctx.run_step == self.ctx.run_step: - return self - - retries = { - failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools - } - return await self.__class__.build(self.toolset, replace(ctx, retries=retries)) + if self.ctx is not None: + if ctx.run_step == self.ctx.run_step: + return self + + retries = { + failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 + for failed_tool_name in self.failed_tools + } + ctx = replace(ctx, retries=retries) + + return self.__class__( + toolset=self.toolset, + ctx=ctx, + tools=await self.toolset.get_tools(ctx), + ) @property def tool_defs(self) -> list[ToolDefinition]: """The tool definitions for the tools in this tool manager.""" + if self.tools is None: + raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover + return [tool.tool_def for tool in self.tools.values()] def get_tool_def(self, name: str) -> ToolDefinition | None: """Get the tool definition for a given tool name, or `None` if the tool is unknown.""" + if self.tools is None: + raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover + try: return self.tools[name].tool_def except KeyError: @@ -71,15 +77,25 @@ async def handle_call( allow_partial: Whether to allow partial validation of the tool arguments. wrap_validation_errors: Whether to wrap validation errors in a retry prompt part. """ + if self.tools is None or self.ctx is None: + raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover + if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output': # Output tool calls are not traced return await self._call_tool(call, allow_partial, wrap_validation_errors) else: - return await self._call_tool_traced(call, allow_partial, wrap_validation_errors) + return await self._call_tool_traced( + call, + allow_partial, + wrap_validation_errors, + self.ctx.tracer, + self.ctx.trace_include_content, + ) + + async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validation_errors: bool) -> Any: + if self.tools is None or self.ctx is None: + raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover - async def _call_tool( - self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True - ) -> Any: name = call.tool_name tool = self.tools.get(name) try: @@ -137,14 +153,19 @@ async def _call_tool( raise e async def _call_tool_traced( - self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True + self, + call: ToolCallPart, + allow_partial: bool, + wrap_validation_errors: bool, + tracer: Tracer, + include_content: bool = False, ) -> Any: """See .""" span_attributes = { 'gen_ai.tool.name': call.tool_name, # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai 'gen_ai.tool.call.id': call.tool_call_id, - **({'tool_arguments': call.args_as_json_str()} if self.ctx.trace_include_content else {}), + **({'tool_arguments': call.args_as_json_str()} if include_content else {}), 'logfire.msg': f'running tool: {call.tool_name}', # add the JSON schema so these attributes are formatted nicely in Logfire 'logfire.json_schema': json.dumps( @@ -156,7 +177,7 @@ async def _call_tool_traced( 'tool_arguments': {'type': 'object'}, 'tool_response': {'type': 'object'}, } - if self.ctx.trace_include_content + if include_content else {} ), 'gen_ai.tool.name': {}, @@ -165,16 +186,16 @@ async def _call_tool_traced( } ), } - with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span: + with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: try: tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors) except ToolRetryError as e: part = e.tool_retry - if self.ctx.trace_include_content and span.is_recording(): + if include_content and span.is_recording(): span.set_attribute('tool_response', part.model_response()) raise e - if self.ctx.trace_include_content and span.is_recording(): + if include_content and span.is_recording(): span.set_attribute( 'tool_response', tool_result diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index de96eabe24..fb4c10b632 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -566,6 +566,8 @@ async def main(): if output_toolset: output_toolset.max_retries = self._max_result_retries output_toolset.output_validators = output_validators + toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) + tool_manager = ToolManager[AgentDepsT](toolset) # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( @@ -581,6 +583,27 @@ async def main(): run_step=0, ) + # Merge model settings in order of precedence: run > agent > model + merged_settings = merge_model_settings(model_used.settings, self.model_settings) + model_settings = merge_model_settings(merged_settings, model_settings) + usage_limits = usage_limits or _usage.UsageLimits() + + async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: + parts = [ + self._instructions, + *[await func.run(run_context) for func in self._instructions_functions], + ] + + model_profile = model_used.profile + if isinstance(output_schema, _output.PromptedOutputSchema): + instructions = output_schema.instructions(model_profile.prompted_output_template) + parts.append(instructions) + + parts = [p for p in parts if p] + if not parts: + return None + return '\n\n'.join(parts).strip() + if isinstance(model_used, InstrumentedModel): instrumentation_settings = model_used.instrumentation_settings tracer = model_used.instrumentation_settings.tracer @@ -588,81 +611,45 @@ async def main(): instrumentation_settings = None tracer = NoOpTracer() - run_context = RunContext[AgentDepsT]( - deps=deps, - model=model_used, - usage=usage, + graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( + user_deps=deps, prompt=user_prompt, - messages=state.message_history, + new_message_index=new_message_index, + model=model_used, + model_settings=model_settings, + usage_limits=usage_limits, + max_result_retries=self._max_result_retries, + end_strategy=self.end_strategy, + output_schema=output_schema, + output_validators=output_validators, + history_processors=self.history_processors, + builtin_tools=list(self._builtin_tools), + tool_manager=tool_manager, tracer=tracer, - trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, - run_step=state.run_step, + get_instructions=get_instructions, + instrumentation_settings=instrumentation_settings, + ) + start_node = _agent_graph.UserPromptNode[AgentDepsT]( + user_prompt=user_prompt, + instructions=self._instructions, + instructions_functions=self._instructions_functions, + system_prompts=self._system_prompts, + system_prompt_functions=self._system_prompt_functions, + system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, ) - toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) - - async with toolset: - # This will raise errors for any name conflicts - tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context) - - # Merge model settings in order of precedence: run > agent > model - merged_settings = merge_model_settings(model_used.settings, self.model_settings) - model_settings = merge_model_settings(merged_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - agent_name = self.name or 'agent' - run_span = tracer.start_span( - 'agent run', - attributes={ - 'model_name': model_used.model_name if model_used else 'no-model', - 'agent_name': agent_name, - 'logfire.msg': f'{agent_name} run', - }, - ) - - async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: - parts = [ - self._instructions, - *[await func.run(run_context) for func in self._instructions_functions], - ] - - model_profile = model_used.profile - if isinstance(output_schema, _output.PromptedOutputSchema): - instructions = output_schema.instructions(model_profile.prompted_output_template) - parts.append(instructions) - - parts = [p for p in parts if p] - if not parts: - return None - return '\n\n'.join(parts).strip() - - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( - user_deps=deps, - prompt=user_prompt, - new_message_index=new_message_index, - model=model_used, - model_settings=model_settings, - usage_limits=usage_limits, - max_result_retries=self._max_result_retries, - end_strategy=self.end_strategy, - output_schema=output_schema, - output_validators=output_validators, - history_processors=self.history_processors, - builtin_tools=list(self._builtin_tools), - tool_manager=tool_manager, - tracer=tracer, - get_instructions=get_instructions, - instrumentation_settings=instrumentation_settings, - ) - start_node = _agent_graph.UserPromptNode[AgentDepsT]( - user_prompt=user_prompt, - instructions=self._instructions, - instructions_functions=self._instructions_functions, - system_prompts=self._system_prompts, - system_prompt_functions=self._system_prompt_functions, - system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, - ) + agent_name = self.name or 'agent' + run_span = tracer.start_span( + 'agent run', + attributes={ + 'model_name': model_used.model_name if model_used else 'no-model', + 'agent_name': agent_name, + 'logfire.msg': f'{agent_name} run', + }, + ) - try: + try: + async with toolset: async with graph.iter( start_node, state=state, @@ -682,12 +669,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: else json.dumps(InstrumentedModel.serialize_any(final_result.output)) ), ) + finally: + try: + if instrumentation_settings and run_span.is_recording(): + run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) finally: - try: - if instrumentation_settings and run_span.is_recording(): - run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings)) - finally: - run_span.end() + run_span.end() def _run_span_end_attributes( self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 8d42fa4c63..84139fbc96 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -1104,18 +1104,7 @@ async def store_state( events.append(json.loads(event.removeprefix('data: '))) assert events == simple_result() - assert seen_states == snapshot( - [ - 41, # run msg_1, prepare_tools call 1 - 42, # run msg_1, prepare_tools call 2 - 0, # run msg_2, prepare_tools call 1 - 1, # run msg_2, prepare_tools call 2 - 0, # run msg_3, prepare_tools call 1 - 1, # run msg_3, prepare_tools call 2 - 42, # run msg_4, prepare_tools call 1 - 43, # run msg_4, prepare_tools call 2 - ] - ) + assert seen_states == snapshot([41, 0, 0, 42]) async def test_request_with_state_without_handler() -> None: diff --git a/tests/test_agent.py b/tests/test_agent.py index bbe91d74c4..83bdcef5eb 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3768,11 +3768,7 @@ async def via_toolset_decorator_for_entire_run(ctx: RunContext[None]) -> Abstrac assert run_result._state.run_step == 3 # pyright: ignore[reportPrivateUsage] assert len(available_tools) == 3 assert toolset_creation_counts == snapshot( - { - 'via_toolsets_arg': 4, - 'via_toolset_decorator': 4, - 'via_toolset_decorator_for_entire_run': 1, - } + defaultdict(int, {'via_toolsets_arg': 3, 'via_toolset_decorator': 3, 'via_toolset_decorator_for_entire_run': 1}) ) diff --git a/tests/test_temporal.py b/tests/test_temporal.py index 1154e0fa7c..6df7fc19d4 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -322,8 +322,6 @@ async def test_complex_agent_run_in_workflow( 'RunWorkflow:ComplexAgentWorkflow', 'StartActivity:agent__complex_agent__mcp_server__mcp__get_tools', 'RunActivity:agent__complex_agent__mcp_server__mcp__get_tools', - 'StartActivity:agent__complex_agent__mcp_server__mcp__get_tools', - 'RunActivity:agent__complex_agent__mcp_server__mcp__get_tools', 'StartActivity:agent__complex_agent__model_request_stream', 'ctx.run_step=1', '{"index":0,"part":{"tool_name":"get_country","args":"","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_kind":"tool-call"},"event_kind":"part_start"}', diff --git a/tests/test_tools.py b/tests/test_tools.py index c88ce659c4..730f535ea0 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1226,10 +1226,9 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int: with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"): agent.run_sync('Begin infinite retry loop!') - # There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in. - assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5] - assert prepare_retries == [0, 0, 1, 2, 3, 4, 5] - assert call_retries == [0, 1, 2, 3, 4, 5] + assert prepare_tools_retries == snapshot([0, 1, 2, 3, 4, 5]) + assert prepare_retries == snapshot([0, 1, 2, 3, 4, 5]) + assert call_retries == snapshot([0, 1, 2, 3, 4, 5]) def test_deferred_tool(): diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index 2e256df5f3..81f6863bf5 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -61,7 +61,7 @@ def add(a: int, b: int) -> int: return a + b no_prefix_context = build_run_context(PrefixDeps()) - no_prefix_toolset = await ToolManager[PrefixDeps].build(toolset, no_prefix_context) + no_prefix_toolset = await ToolManager[PrefixDeps](toolset).for_run_step(no_prefix_context) assert no_prefix_toolset.tool_defs == snapshot( [ ToolDefinition( @@ -79,7 +79,7 @@ def add(a: int, b: int) -> int: assert await no_prefix_toolset.handle_call(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2})) == 3 foo_context = build_run_context(PrefixDeps(prefix='foo')) - foo_toolset = await ToolManager[PrefixDeps].build(toolset, foo_context) + foo_toolset = await ToolManager[PrefixDeps](toolset).for_run_step(foo_context) assert foo_toolset.tool_defs == snapshot( [ ToolDefinition( @@ -102,7 +102,7 @@ def subtract(a: int, b: int) -> int: return a - b # pragma: lax no cover bar_context = build_run_context(PrefixDeps(prefix='bar')) - bar_toolset = await ToolManager[PrefixDeps].build(toolset, bar_context) + bar_toolset = await ToolManager[PrefixDeps](toolset).for_run_step(bar_context) assert bar_toolset.tool_defs == snapshot( [ ToolDefinition( @@ -162,7 +162,7 @@ async def prepare_add_new_tool(ctx: RunContext[None], tool_defs: list[ToolDefini 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' ), ): - await ToolManager[None].build(prepared_toolset, context) + await ToolManager[None](prepared_toolset).for_run_step(context) async def test_prepared_toolset_user_error_change_tool_names(): @@ -198,7 +198,7 @@ async def prepare_change_names(ctx: RunContext[None], tool_defs: list[ToolDefini 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' ), ): - await ToolManager[None].build(prepared_toolset, context) + await ToolManager[None](prepared_toolset).for_run_step(context) async def test_comprehensive_toolset_composition(): @@ -285,7 +285,7 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef # Test with regular user context regular_deps = TestDeps(user_role='user', enable_advanced=True) regular_context = build_run_context(regular_deps) - final_toolset = await ToolManager[TestDeps].build(prepared_toolset, regular_context) + final_toolset = await ToolManager[TestDeps](prepared_toolset).for_run_step(regular_context) # Tool definitions should have role annotation assert final_toolset.tool_defs == snapshot( [ @@ -338,7 +338,7 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef # Test with admin user context (should have string tools) admin_deps = TestDeps(user_role='admin', enable_advanced=True) admin_context = build_run_context(admin_deps) - admin_final_toolset = await ToolManager[TestDeps].build(prepared_toolset, admin_context) + admin_final_toolset = await ToolManager[TestDeps](prepared_toolset).for_run_step(admin_context) assert admin_final_toolset.tool_defs == snapshot( [ ToolDefinition( @@ -421,7 +421,7 @@ async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDef # Test with advanced features disabled basic_deps = TestDeps(user_role='user', enable_advanced=False) basic_context = build_run_context(basic_deps) - basic_final_toolset = await ToolManager[TestDeps].build(prepared_toolset, basic_context) + basic_final_toolset = await ToolManager[TestDeps](prepared_toolset).for_run_step(basic_context) assert basic_final_toolset.tool_defs == snapshot( [ ToolDefinition( @@ -506,7 +506,7 @@ async def test_tool_manager_reuse_self(): run_context = build_run_context(None, run_step=1) - tool_manager = ToolManager[None](run_context, FunctionToolset[None](), tools={}) + tool_manager = await ToolManager[None](FunctionToolset()).for_run_step(run_context) same_tool_manager = await tool_manager.for_run_step(ctx=run_context) @@ -544,7 +544,7 @@ def other_tool(x: int) -> int: # Create initial context and tool manager initial_context = build_run_context(TestDeps()) - tool_manager = await ToolManager[TestDeps].build(toolset, initial_context) + tool_manager = await ToolManager[TestDeps](toolset).for_run_step(initial_context) # Initially no failed tools assert tool_manager.failed_tools == set() @@ -568,6 +568,7 @@ def other_tool(x: int) -> int: new_tool_manager = await tool_manager.for_run_step(new_context) # The new tool manager should have retry count for the failed tool + assert new_tool_manager.ctx is not None assert new_tool_manager.ctx.retries == {'failing_tool': 1} assert new_tool_manager.failed_tools == set() # reset for new run step @@ -591,6 +592,7 @@ def other_tool(x: int) -> int: another_tool_manager = await new_tool_manager.for_run_step(another_context) # Should now have retry count of 2 for failing_tool + assert another_tool_manager.ctx is not None assert another_tool_manager.ctx.retries == {'failing_tool': 2} assert another_tool_manager.failed_tools == set() @@ -625,7 +627,7 @@ def tool_c(x: int) -> int: # Create tool manager context = build_run_context(TestDeps()) - tool_manager = await ToolManager[TestDeps].build(toolset, context) + tool_manager = await ToolManager[TestDeps](toolset).for_run_step(context) # Call tool_a - should fail and be added to failed_tools with pytest.raises(ToolRetryError): @@ -646,6 +648,7 @@ def tool_c(x: int) -> int: new_context = build_run_context(TestDeps(), run_step=1) new_tool_manager = await tool_manager.for_run_step(new_context) + assert new_tool_manager.ctx is not None assert new_tool_manager.ctx.retries == {'tool_a': 1, 'tool_b': 1} assert new_tool_manager.failed_tools == set() # reset for new run step