Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 152 additions & 39 deletions temporalio/contrib/openai_agents/_mcp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import abc
import asyncio
import dataclasses
import functools
import inspect
import logging
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, Union, cast

from agents import AgentBase, RunContextWrapper
from agents.mcp import MCPServer
Expand All @@ -29,19 +31,45 @@
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class _StatelessListToolsArguments:
factory_argument: Optional[Any]


@dataclasses.dataclass
class _StatelessCallToolsArguments:
tool_name: str
arguments: Optional[dict[str, Any]]
factory_argument: Optional[Any]


@dataclasses.dataclass
class _StatelessListPromptsArguments:
factory_argument: Optional[Any]


@dataclasses.dataclass
class _StatelessGetPromptArguments:
name: str
arguments: Optional[dict[str, Any]]
factory_argument: Optional[Any]


class _StatelessMCPServerReference(MCPServer):
def __init__(
self,
server: str,
config: Optional[ActivityConfig],
cache_tools_list: bool,
factory_argument: Optional[Any] = None,
):
self._name = server + "-stateless"
self._config = config or ActivityConfig(
start_to_close_timeout=timedelta(minutes=1)
)
self._cache_tools_list = cache_tools_list
self._tools = None
self._factory_argument = factory_argument
super().__init__()

@property
Expand All @@ -63,7 +91,7 @@ async def list_tools(
return self._tools
tools = await workflow.execute_activity(
self.name + "-list-tools",
args=[],
_StatelessListToolsArguments(self._factory_argument),
result_type=list[MCPTool],
**self._config,
)
Expand All @@ -75,16 +103,16 @@ async def call_tool(
self, tool_name: str, arguments: Optional[dict[str, Any]]
) -> CallToolResult:
return await workflow.execute_activity(
self.name + "-call-tool",
args=[tool_name, arguments],
self.name + "-call-tool-v2",
_StatelessCallToolsArguments(tool_name, arguments, self._factory_argument),
result_type=CallToolResult,
**self._config,
)

async def list_prompts(self) -> ListPromptsResult:
return await workflow.execute_activity(
self.name + "-list-prompts",
args=[],
_StatelessListPromptsArguments(self._factory_argument),
result_type=ListPromptsResult,
**self._config,
)
Expand All @@ -93,8 +121,8 @@ async def get_prompt(
self, name: str, arguments: Optional[dict[str, Any]] = None
) -> GetPromptResult:
return await workflow.execute_activity(
self.name + "-get-prompt",
args=[name, arguments],
self.name + "-get-prompt-v2",
_StatelessGetPromptArguments(name, arguments, self._factory_argument),
result_type=GetPromptResult,
**self._config,
)
Expand All @@ -111,64 +139,107 @@ class StatelessMCPServerProvider:
function, this cannot be used.
"""

def __init__(self, server_factory: Callable[[], MCPServer]):
def __init__(
self,
name: str,
server_factory: Union[
Callable[[], MCPServer], Callable[[Optional[Any]], MCPServer]
],
):
"""Initialize the stateless temporal MCP server.

Args:
name: The name of the MCP server.
server_factory: A function which will produce MCPServer instances. It should return a new server each time
so that state is not shared between workflow runs
so that state is not shared between workflow runs.
"""
self._server_factory = server_factory
self._name = server_factory().name + "-stateless"

# Cache whether the server factory needs to be provided with arguments
sig = inspect.signature(self._server_factory)
self._server_accepts_arguments = len(sig.parameters) != 0

self._name = name + "-stateless"
super().__init__()

def _create_server(self, factory_argument: Optional[Any]) -> MCPServer:
if self._server_accepts_arguments:
return cast(Callable[[Optional[Any]], MCPServer], self._server_factory)(
factory_argument
)
else:
return cast(Callable[[], MCPServer], self._server_factory)()

@property
def name(self) -> str:
"""Get the server name."""
return self._name

def _get_activities(self) -> Sequence[Callable]:
@activity.defn(name=self.name + "-list-tools")
async def list_tools() -> list[MCPTool]:
server = self._server_factory()
async def list_tools(
args: Optional[_StatelessListToolsArguments] = None,
) -> list[MCPTool]:
server = self._create_server(args.factory_argument if args else None)
try:
await server.connect()
return await server.list_tools()
finally:
await server.cleanup()

@activity.defn(name=self.name + "-call-tool")
async def call_tool(
tool_name: str, arguments: Optional[dict[str, Any]]
) -> CallToolResult:
server = self._server_factory()
@activity.defn(name=self.name + "-call-tool-v2")
async def call_tool(args: _StatelessCallToolsArguments) -> CallToolResult:
server = self._create_server(args.factory_argument)
try:
await server.connect()
return await server.call_tool(tool_name, arguments)
return await server.call_tool(args.tool_name, args.arguments)
finally:
await server.cleanup()

@activity.defn(name=self.name + "-list-prompts")
async def list_prompts() -> ListPromptsResult:
server = self._server_factory()
async def list_prompts(
args: Optional[_StatelessListPromptsArguments] = None,
) -> ListPromptsResult:
server = self._create_server(args.factory_argument if args else None)
try:
await server.connect()
return await server.list_prompts()
finally:
await server.cleanup()

@activity.defn(name=self.name + "-get-prompt")
async def get_prompt(
name: str, arguments: Optional[dict[str, Any]]
) -> GetPromptResult:
server = self._server_factory()
@activity.defn(name=self.name + "-get-prompt-v2")
async def get_prompt(args: _StatelessGetPromptArguments) -> GetPromptResult:
server = self._create_server(args.factory_argument)
try:
await server.connect()
return await server.get_prompt(name, arguments)
return await server.get_prompt(args.name, args.arguments)
finally:
await server.cleanup()

return list_tools, call_tool, list_prompts, get_prompt
@activity.defn(name=self.name + "-call-tool")
async def call_tool_deprecated(
tool_name: str,
arguments: Optional[dict[str, Any]],
) -> CallToolResult:
return await call_tool(
_StatelessCallToolsArguments(tool_name, arguments, None)
)

@activity.defn(name=self.name + "-get-prompt")
async def get_prompt_deprecated(
name: str,
arguments: Optional[dict[str, Any]],
) -> GetPromptResult:
return await get_prompt(_StatelessGetPromptArguments(name, arguments, None))

return (
list_tools,
call_tool,
list_prompts,
get_prompt,
call_tool_deprecated,
get_prompt_deprecated,
)


def _handle_worker_failure(func):
Expand Down Expand Up @@ -202,12 +273,30 @@ async def wrapper(*args, **kwargs):
return wrapper


@dataclasses.dataclass
class _StatefulCallToolsArguments:
tool_name: str
arguments: Optional[dict[str, Any]]


@dataclasses.dataclass
class _StatefulGetPromptArguments:
name: str
arguments: Optional[dict[str, Any]]


@dataclasses.dataclass
class _StatefulServerSessionArguments:
factory_argument: Optional[Any]


class _StatefulMCPServerReference(MCPServer, AbstractAsyncContextManager):
def __init__(
self,
server: str,
config: Optional[ActivityConfig],
server_session_config: Optional[ActivityConfig],
factory_argument: Optional[Any],
):
self._name = server + "-stateful"
self._config = config or ActivityConfig(
Expand All @@ -218,6 +307,7 @@ def __init__(
start_to_close_timeout=timedelta(hours=1),
)
self._connect_handle: Optional[ActivityHandle] = None
self._factory_argument = factory_argument
super().__init__()

@property
Expand All @@ -228,7 +318,7 @@ async def connect(self) -> None:
self._config["task_queue"] = self.name + "@" + workflow.info().run_id
self._connect_handle = workflow.start_activity(
self.name + "-server-session",
args=[],
_StatefulServerSessionArguments(self._factory_argument),
**self._server_session_config,
)

Expand Down Expand Up @@ -276,8 +366,8 @@ async def call_tool(
"Stateful MCP Server not connected. Call connect first."
)
return await workflow.execute_activity(
self.name + "-call-tool",
args=[tool_name, arguments],
self.name + "-call-tool-v2",
_StatefulCallToolsArguments(tool_name, arguments),
result_type=CallToolResult,
**self._config,
)
Expand All @@ -304,8 +394,8 @@ async def get_prompt(
"Stateful MCP Server not connected. Call connect first."
)
return await workflow.execute_activity(
self.name + "-get-prompt",
args=[name, arguments],
self.name + "-get-prompt-v2",
_StatefulGetPromptArguments(name, arguments),
result_type=GetPromptResult,
**self._config,
)
Expand All @@ -329,16 +419,18 @@ class StatefulMCPServerProvider:

def __init__(
self,
server_factory: Callable[[], MCPServer],
name: str,
server_factory: Callable[[Optional[Any]], MCPServer],
):
"""Initialize the stateful temporal MCP server.

Args:
name: The name of the MCP server.
server_factory: A function which will produce MCPServer instances. It should return a new server each time
so that state is not shared between workflow runs
"""
self._server_factory = server_factory
self._name = server_factory().name + "-stateful"
self._name = name + "-stateful"
self._connect_handle: Optional[ActivityHandle] = None
self._servers: dict[str, MCPServer] = {}
super().__init__()
Expand All @@ -357,37 +449,51 @@ async def list_tools() -> list[MCPTool]:
return await self._servers[_server_id()].list_tools()

@activity.defn(name=self.name + "-call-tool")
async def call_tool(
async def call_tool_deprecated(
tool_name: str, arguments: Optional[dict[str, Any]]
) -> CallToolResult:
return await self._servers[_server_id()].call_tool(tool_name, arguments)

@activity.defn(name=self.name + "-call-tool-v2")
async def call_tool(args: _StatefulCallToolsArguments) -> CallToolResult:
return await self._servers[_server_id()].call_tool(
args.tool_name, args.arguments
)

@activity.defn(name=self.name + "-list-prompts")
async def list_prompts() -> ListPromptsResult:
return await self._servers[_server_id()].list_prompts()

@activity.defn(name=self.name + "-get-prompt")
async def get_prompt(
async def get_prompt_deprecated(
name: str, arguments: Optional[dict[str, Any]]
) -> GetPromptResult:
return await self._servers[_server_id()].get_prompt(name, arguments)

@activity.defn(name=self.name + "-get-prompt-v2")
async def get_prompt(args: _StatefulGetPromptArguments) -> GetPromptResult:
return await self._servers[_server_id()].get_prompt(
args.name, args.arguments
)

async def heartbeat_every(delay: float, *details: Any) -> None:
"""Heartbeat every so often while not cancelled"""
while True:
await asyncio.sleep(delay)
activity.heartbeat(*details)

@activity.defn(name=self.name + "-server-session")
async def connect() -> None:
async def connect(
args: Optional[_StatefulServerSessionArguments] = None,
) -> None:
heartbeat_task = asyncio.create_task(heartbeat_every(30))

server_id = self.name + "@" + activity.info().workflow_run_id
if server_id in self._servers:
raise ApplicationError(
"Cannot connect to an already running server. Use a distinct name if running multiple servers in one workflow."
)
server = self._server_factory()
server = self._server_factory(args.factory_argument if args else None)
try:
self._servers[server_id] = server
try:
Expand All @@ -396,7 +502,14 @@ async def connect() -> None:
worker = Worker(
activity.client(),
task_queue=server_id,
activities=[list_tools, call_tool, list_prompts, get_prompt],
activities=[
list_tools,
call_tool,
list_prompts,
get_prompt,
call_tool_deprecated,
get_prompt_deprecated,
],
activity_task_poller_behavior=PollerBehaviorSimpleMaximum(1),
)

Expand Down
Loading
Loading