diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py index f3d3d362dc..1e7905b802 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -2,7 +2,7 @@ from .combined import CombinedToolset from .deferred import DeferredToolset from .filtered import FilteredToolset -from .function import FunctionToolset +from .function import FunctionToolset, FunctionToolsetTool from .prefixed import PrefixedToolset from .prepared import PreparedToolset from .renamed import RenamedToolset @@ -15,6 +15,7 @@ 'DeferredToolset', 'FilteredToolset', 'FunctionToolset', + 'FunctionToolsetTool', 'PrefixedToolset', 'RenamedToolset', 'PreparedToolset', diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py index 63f44a1f0c..cb90b42f96 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/function.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -20,7 +20,7 @@ @dataclass -class _FunctionToolsetTool(ToolsetTool[AgentDepsT]): +class FunctionToolsetTool(ToolsetTool[AgentDepsT]): """A tool definition for a function toolset tool that keeps track of the function to call.""" call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]] @@ -222,7 +222,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ else: raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') - tools[new_name] = _FunctionToolsetTool( + tools[new_name] = FunctionToolsetTool( toolset=self, tool_def=tool_def, max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries, @@ -234,5 +234,5 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[ async def call_tool( self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] ) -> Any: - assert isinstance(tool, _FunctionToolsetTool) + assert isinstance(tool, FunctionToolsetTool) return await tool.call_func(tool_args, ctx)