2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -251,6 +251,8 @@ async def _prepare_request_parameters(
) -> models.ModelRequestParameters:
"""Build tools and create an agent model."""
run_context = build_run_context(ctx)

# This will raise errors for any tool name conflicts
ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)

output_schema = ctx.deps.output_schema
79 changes: 50 additions & 29 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -5,6 +5,7 @@
from dataclasses import dataclass, field, replace
from typing import Any, Generic

from opentelemetry.trace import Tracer
from pydantic import ValidationError
from typing_extensions import assert_never

@@ -21,41 +22,46 @@
class ToolManager(Generic[AgentDepsT]):
"""Manages tools for an agent run step. It caches the agent run's toolset's tool definitions and handles calling tools and retries."""

ctx: RunContext[AgentDepsT]
"""The agent run context for a specific run step."""
toolset: AbstractToolset[AgentDepsT]
"""The toolset that provides the tools for this run step."""
tools: dict[str, ToolsetTool[AgentDepsT]]
ctx: RunContext[AgentDepsT] | None = None
"""The agent run context for a specific run step."""
tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
"""The cached tools for this run step."""
failed_tools: set[str] = field(default_factory=set)
"""Names of tools that failed in this run step."""

@classmethod
async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
"""Build a new tool manager for a specific run step."""
return cls(
ctx=ctx,
toolset=toolset,
tools=await toolset.get_tools(ctx),
)

async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
"""Build a new tool manager for the next run step, carrying over the retries from the current run step."""
if ctx.run_step == self.ctx.run_step:
return self

retries = {
failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
}
return await self.__class__.build(self.toolset, replace(ctx, retries=retries))
if self.ctx is not None:
if ctx.run_step == self.ctx.run_step:
return self

retries = {
failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1
for failed_tool_name in self.failed_tools
}
ctx = replace(ctx, retries=retries)

return self.__class__(
toolset=self.toolset,
ctx=ctx,
tools=await self.toolset.get_tools(ctx),
)

@property
def tool_defs(self) -> list[ToolDefinition]:
"""The tool definitions for the tools in this tool manager."""
if self.tools is None:
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover

return [tool.tool_def for tool in self.tools.values()]

def get_tool_def(self, name: str) -> ToolDefinition | None:
"""Get the tool definition for a given tool name, or `None` if the tool is unknown."""
if self.tools is None:
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover

try:
return self.tools[name].tool_def
except KeyError:
@@ -71,15 +77,25 @@ async def handle_call(
allow_partial: Whether to allow partial validation of the tool arguments.
wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
"""
if self.tools is None or self.ctx is None:
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover

if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
# Output tool calls are not traced
return await self._call_tool(call, allow_partial, wrap_validation_errors)
else:
return await self._call_tool_traced(call, allow_partial, wrap_validation_errors)
return await self._call_tool_traced(
call,
allow_partial,
wrap_validation_errors,
self.ctx.tracer,
self.ctx.trace_include_content,
)

async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validation_errors: bool) -> Any:
if self.tools is None or self.ctx is None:
raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover

async def _call_tool(
self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
) -> Any:
name = call.tool_name
tool = self.tools.get(name)
try:
@@ -137,14 +153,19 @@ async def _call_tool(
raise e

async def _call_tool_traced(
self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
self,
call: ToolCallPart,
allow_partial: bool,
wrap_validation_errors: bool,
tracer: Tracer,
include_content: bool = False,
) -> Any:
"""See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
span_attributes = {
'gen_ai.tool.name': call.tool_name,
# NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
'gen_ai.tool.call.id': call.tool_call_id,
**({'tool_arguments': call.args_as_json_str()} if self.ctx.trace_include_content else {}),
**({'tool_arguments': call.args_as_json_str()} if include_content else {}),
'logfire.msg': f'running tool: {call.tool_name}',
# add the JSON schema so these attributes are formatted nicely in Logfire
'logfire.json_schema': json.dumps(
Expand All @@ -156,7 +177,7 @@ async def _call_tool_traced(
'tool_arguments': {'type': 'object'},
'tool_response': {'type': 'object'},
}
if self.ctx.trace_include_content
if include_content
else {}
),
'gen_ai.tool.name': {},
@@ -165,16 +186,16 @@
}
),
}
with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
try:
tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
except ToolRetryError as e:
part = e.tool_retry
if self.ctx.trace_include_content and span.is_recording():
if include_content and span.is_recording():
span.set_attribute('tool_response', part.model_response())
raise e

if self.ctx.trace_include_content and span.is_recording():
if include_content and span.is_recording():
span.set_attribute(
'tool_response',
tool_result
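Taken together, the hunks above turn `ToolManager` from an object built eagerly with a `RunContext` (the removed `ToolManager.build(toolset, ctx)`) into one constructed from just the toolset, with `ctx` and `tools` filled in lazily by `for_run_step` once per run step and retry counts carried forward for tools that failed in the previous step. Below is a minimal, self-contained sketch of that lifecycle; `LazyToolManager` and `RunStepContext` are illustrative stand-ins, not the library's classes.

```python
from __future__ import annotations

import asyncio
from dataclasses import dataclass, field, replace


@dataclass
class RunStepContext:
    """Illustrative stand-in for pydantic-ai's RunContext."""

    run_step: int
    retries: dict[str, int] = field(default_factory=dict)


@dataclass
class LazyToolManager:
    """Illustrative stand-in: tools are resolved only once a run step context exists."""

    toolset: dict[str, str]  # stand-in for an AbstractToolset
    ctx: RunStepContext | None = None
    tools: dict[str, str] | None = None
    failed_tools: set[str] = field(default_factory=set)

    async def for_run_step(self, ctx: RunStepContext) -> LazyToolManager:
        if self.ctx is not None:
            if ctx.run_step == self.ctx.run_step:
                return self  # already prepared for this step
            # Carry retry counts forward only for tools that failed last step.
            retries = {name: self.ctx.retries.get(name, 0) + 1 for name in self.failed_tools}
            ctx = replace(ctx, retries=retries)
        # "get_tools" step: here just a copy of the static toolset.
        return LazyToolManager(toolset=self.toolset, ctx=ctx, tools=dict(self.toolset))

    @property
    def tool_defs(self) -> list[str]:
        if self.tools is None:
            raise ValueError('manager has not been prepared for a run step yet')
        return list(self.tools.values())


async def main() -> None:
    manager = LazyToolManager(toolset={'get_country': 'returns a country name'})  # no ctx/tools yet
    manager = await manager.for_run_step(RunStepContext(run_step=1))
    print(manager.tool_defs)  # ['returns a country name']


asyncio.run(main())
```

With this shape the manager can be stored on the graph deps before the run starts and swapped for a freshly prepared instance at the top of every step, which is what the `_agent_graph.py` hunk does with `ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)`.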
137 changes: 62 additions & 75 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
@@ -566,6 +566,8 @@ async def main():
if output_toolset:
output_toolset.max_retries = self._max_result_retries
output_toolset.output_validators = output_validators
toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
tool_manager = ToolManager[AgentDepsT](toolset)

# Build the graph
graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
@@ -581,88 +583,73 @@
run_step=0,
)

# Merge model settings in order of precedence: run > agent > model
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
model_settings = merge_model_settings(merged_settings, model_settings)
usage_limits = usage_limits or _usage.UsageLimits()

async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
parts = [
self._instructions,
*[await func.run(run_context) for func in self._instructions_functions],
]

model_profile = model_used.profile
if isinstance(output_schema, _output.PromptedOutputSchema):
instructions = output_schema.instructions(model_profile.prompted_output_template)
parts.append(instructions)

parts = [p for p in parts if p]
if not parts:
return None
return '\n\n'.join(parts).strip()

if isinstance(model_used, InstrumentedModel):
instrumentation_settings = model_used.instrumentation_settings
tracer = model_used.instrumentation_settings.tracer
else:
instrumentation_settings = None
tracer = NoOpTracer()

run_context = RunContext[AgentDepsT](
deps=deps,
model=model_used,
usage=usage,
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
user_deps=deps,
prompt=user_prompt,
messages=state.message_history,
new_message_index=new_message_index,
model=model_used,
model_settings=model_settings,
usage_limits=usage_limits,
max_result_retries=self._max_result_retries,
end_strategy=self.end_strategy,
output_schema=output_schema,
output_validators=output_validators,
history_processors=self.history_processors,
builtin_tools=list(self._builtin_tools),
tool_manager=tool_manager,
tracer=tracer,
trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content,
run_step=state.run_step,
get_instructions=get_instructions,
instrumentation_settings=instrumentation_settings,
)
start_node = _agent_graph.UserPromptNode[AgentDepsT](
user_prompt=user_prompt,
instructions=self._instructions,
instructions_functions=self._instructions_functions,
system_prompts=self._system_prompts,
system_prompt_functions=self._system_prompt_functions,
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
)

toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)

async with toolset:
# This will raise errors for any name conflicts
tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context)

# Merge model settings in order of precedence: run > agent > model
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
model_settings = merge_model_settings(merged_settings, model_settings)
usage_limits = usage_limits or _usage.UsageLimits()
agent_name = self.name or 'agent'
run_span = tracer.start_span(
'agent run',
attributes={
'model_name': model_used.model_name if model_used else 'no-model',
'agent_name': agent_name,
'logfire.msg': f'{agent_name} run',
},
)

async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
parts = [
self._instructions,
*[await func.run(run_context) for func in self._instructions_functions],
]

model_profile = model_used.profile
if isinstance(output_schema, _output.PromptedOutputSchema):
instructions = output_schema.instructions(model_profile.prompted_output_template)
parts.append(instructions)

parts = [p for p in parts if p]
if not parts:
return None
return '\n\n'.join(parts).strip()

graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
user_deps=deps,
prompt=user_prompt,
new_message_index=new_message_index,
model=model_used,
model_settings=model_settings,
usage_limits=usage_limits,
max_result_retries=self._max_result_retries,
end_strategy=self.end_strategy,
output_schema=output_schema,
output_validators=output_validators,
history_processors=self.history_processors,
builtin_tools=list(self._builtin_tools),
tool_manager=tool_manager,
tracer=tracer,
get_instructions=get_instructions,
instrumentation_settings=instrumentation_settings,
)
start_node = _agent_graph.UserPromptNode[AgentDepsT](
user_prompt=user_prompt,
instructions=self._instructions,
instructions_functions=self._instructions_functions,
system_prompts=self._system_prompts,
system_prompt_functions=self._system_prompt_functions,
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
)
agent_name = self.name or 'agent'
run_span = tracer.start_span(
'agent run',
attributes={
'model_name': model_used.model_name if model_used else 'no-model',
'agent_name': agent_name,
'logfire.msg': f'{agent_name} run',
},
)

try:
try:
async with toolset:
async with graph.iter(
start_node,
state=state,
@@ -682,12 +669,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
else json.dumps(InstrumentedModel.serialize_any(final_result.output))
),
)
finally:
try:
if instrumentation_settings and run_span.is_recording():
run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
finally:
try:
if instrumentation_settings and run_span.is_recording():
run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
finally:
run_span.end()
run_span.end()

def _run_span_end_attributes(
self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings
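The agent now resolves the tracer and `trace_include_content` once, stores them on `GraphAgentDeps`, and the tool manager receives them as explicit arguments to `_call_tool_traced` instead of reading them off a stored `RunContext`. Below is a minimal sketch of that explicit-parameter tracing pattern using only the OpenTelemetry API; `run_tool_traced` and its arguments are illustrative — only the `gen_ai.*` attribute keys and the span name come from the diff.

```python
import json

from opentelemetry import trace
from opentelemetry.trace import Tracer


def run_tool_traced(tracer: Tracer, tool_name: str, call_id: str, args: dict, include_content: bool) -> str:
    # Attribute keys follow the gen-ai "execute tool" span conventions referenced in the diff.
    attributes = {
        'gen_ai.tool.name': tool_name,
        'gen_ai.tool.call.id': call_id,
        **({'tool_arguments': json.dumps(args)} if include_content else {}),
    }
    with tracer.start_as_current_span('running tool', attributes=attributes) as span:
        result = f'called {tool_name}'  # stand-in for the real tool invocation
        # Only attach the response when content tracing is enabled and the span is sampled.
        if include_content and span.is_recording():
            span.set_attribute('tool_response', result)
        return result


# Without an SDK configured this uses a no-op tracer, so it runs but records nothing.
print(run_tool_traced(trace.get_tracer(__name__), 'get_country', 'call_1', {'code': 'NL'}, include_content=True))
```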
13 changes: 1 addition & 12 deletions tests/test_ag_ui.py
@@ -1104,18 +1104,7 @@ async def store_state(
events.append(json.loads(event.removeprefix('data: ')))

assert events == simple_result()
assert seen_states == snapshot(
[
41, # run msg_1, prepare_tools call 1
42, # run msg_1, prepare_tools call 2
0, # run msg_2, prepare_tools call 1
1, # run msg_2, prepare_tools call 2
0, # run msg_3, prepare_tools call 1
1, # run msg_3, prepare_tools call 2
42, # run msg_4, prepare_tools call 1
43, # run msg_4, prepare_tools call 2
]
)
assert seen_states == snapshot([41, 0, 0, 42])


async def test_request_with_state_without_handler() -> None:
6 changes: 1 addition & 5 deletions tests/test_agent.py
@@ -3768,11 +3768,7 @@ async def via_toolset_decorator_for_entire_run(ctx: RunContext[None]) -> Abstrac
assert run_result._state.run_step == 3 # pyright: ignore[reportPrivateUsage]
assert len(available_tools) == 3
assert toolset_creation_counts == snapshot(
{
'via_toolsets_arg': 4,
'via_toolset_decorator': 4,
'via_toolset_decorator_for_entire_run': 1,
}
defaultdict(int, {'via_toolsets_arg': 3, 'via_toolset_decorator': 3, 'via_toolset_decorator_for_entire_run': 1})
)


2 changes: 0 additions & 2 deletions tests/test_temporal.py
@@ -322,8 +322,6 @@ async def test_complex_agent_run_in_workflow(
'RunWorkflow:ComplexAgentWorkflow',
'StartActivity:agent__complex_agent__mcp_server__mcp__get_tools',
'RunActivity:agent__complex_agent__mcp_server__mcp__get_tools',
'StartActivity:agent__complex_agent__mcp_server__mcp__get_tools',
'RunActivity:agent__complex_agent__mcp_server__mcp__get_tools',
'StartActivity:agent__complex_agent__model_request_stream',
'ctx.run_step=1',
'{"index":0,"part":{"tool_name":"get_country","args":"","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_kind":"tool-call"},"event_kind":"part_start"}',
7 changes: 3 additions & 4 deletions tests/test_tools.py
@@ -1226,10 +1226,9 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int:
with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"):
agent.run_sync('Begin infinite retry loop!')

# There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in.
assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5]
assert prepare_retries == [0, 0, 1, 2, 3, 4, 5]
assert call_retries == [0, 1, 2, 3, 4, 5]
assert prepare_tools_retries == snapshot([0, 1, 2, 3, 4, 5])
assert prepare_retries == snapshot([0, 1, 2, 3, 4, 5])
assert call_retries == snapshot([0, 1, 2, 3, 4, 5])


def test_deferred_tool():