diff --git a/docs/mcp-transports.md b/docs/mcp-transports.md new file mode 100644 index 00000000..eb9f5bee --- /dev/null +++ b/docs/mcp-transports.md @@ -0,0 +1,368 @@ +# MCP Transports + +This guide explains the different transport options for connecting to MCP servers with the Agents SDK. + +For a primer on MCP Servers and how they are implemented in the Agents SDK with `McpAgent`[here](docs/mcp-servers.md) + +## Streamable HTTP Transport (Recommended) + +The **Streamable HTTP** transport is the recommended way to connect to MCP servers. + +### How it works + +When a client connects to your MCP server: + +1. The client makes an HTTP request to your Worker with a JSON-RPC message in the body +2. Your Worker upgrades the connection to a WebSocket +3. The WebSocket connects to your `McpAgent` Durable Object which manages connection state +4. JSON-RPC messages flow bidirectionally over the WebSocket +5. Your Worker streams responses back to the client using Server-Sent Events (SSE) + +This is all handled automatically by the `McpAgent.serve()` method: + +```typescript +import { McpAgent } from "agents/mcp"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; + +export class MyMCP extends McpAgent { + server = new McpServer({ name: "Demo", version: "1.0.0" }); + + async init() { + // Define your tools, resources, prompts + } +} + +// Serve with Streamable HTTP transport +export default MyMCP.serve("/mcp"); +``` + +The `serve()` method returns a Worker with a `fetch` handler that: + +- Handles CORS preflight requests +- Manages WebSocket upgrades +- Routes messages to your Durable Object + +### Connection from clients + +Clients connect using the `streamable-http` transport: + +```typescript +await agent.addMcpServer("my-server", "https://your-worker.workers.dev/mcp"); +``` + +## SSE Transport (Deprecated) + +We also support the legacy **SSE (Server-Sent Events)** transport, but it's deprecated in favor of Streamable HTTP. + +If you need SSE transport for compatibility: + +```typescript +// Server +export default MyMCP.serveSSE("/sse"); + +// Client +await agent.addMcpServer("my-server", url, callbackHost); +``` + +## RPC Transport (Experimental) + +The **RPC transport** is a custom transport designed for internal applications where your MCP server and agent are both running on Cloudflare. They can even run in the same Worker! It sends JSON-RPC messages directly over Cloudflare's RPC bindings without going over the public internet. + +### Why use RPC transport? + +✅ **Faster**: No network overhead - direct function calls +✅ **Simpler**: No HTTP endpoints, no connection management +✅ **Internal only**: Perfect for agents calling MCP servers within the same Worker + +⚠️ **No authentication**: Not suitable for public APIs - use HTTP/SSE for external connections + +### Connecting an Agent to an McpAgent via RPC + +The RPC transport uses [Cloudflare Service Bindings](https://developers.cloudflare.com/workers/runtime-apis/bindings/service-bindings/) to connect your `Agent` (MCP client) directly to your `McpAgent` (MCP server) using Durable Object RPC calls. + +#### Step 1: Define your MCP server + +First, create your `McpAgent` with the tools you want to expose: + +```typescript +import { McpAgent } from "agents/mcp"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { z } from "zod"; + +export class MyMCP extends McpAgent { + server = new McpServer({ + name: "MyMCP", + version: "1.0.0" + }); + + initialState: State = { + counter: 0 + }; + + async init() { + // Define a tool + this.server.tool( + "add", + "Add to the counter", + { amount: z.number() }, + async ({ amount }) => { + this.setState({ counter: this.state.counter + amount }); + return { + content: [ + { + type: "text", + text: `Added ${amount}, total is now ${this.state.counter}` + } + ] + }; + } + ); + } +} +``` + +#### Step 2: Connect your Agent to the MCP server + +In your `Agent`, call `addMcpServer()` with RPC transport in the `onStart()` method: + +```typescript +import { AIChatAgent } from "agents/ai-chat-agent"; + +export class Chat extends AIChatAgent { + async onStart(): Promise { + // Connect to MyMCP via RPC using binding directly + await this.addMcpServer("my-mcp", this.env.MyMCP, { + transport: { type: "rpc" } + }); + + // Or using binding name string + await this.addMcpServer("my-mcp", "MyMCP", { + transport: { type: "rpc" } + }); + // ▲ ▲ + // │ └─ Binding name (from wrangler.jsonc) or namespace + // └─ Server ID (any unique string) + } + + async onChatMessage(onFinish) { + // MCP tools are now available! + const allTools = this.mcp.getAITools(); + + const result = streamText({ + model, + tools: allTools + // ... + }); + + return createUIMessageStreamResponse({ stream: result }); + } +} +``` + +Note that in production you would not connect to MCP servers in `onStart` but in standalone method you could add error handling. See this [MCP client example](examples/mcp-client) + +#### Step 3: Configure Durable Object bindings + +In your `wrangler.jsonc`, define bindings for both Durable Objects: + +```jsonc +{ + "durable_objects": { + "bindings": [ + { + "name": "Chat", + "class_name": "Chat" + }, + { + "name": "MyMCP", // This is the binding name you pass to addMcpServer + "class_name": "MyMCP" + } + ] + }, + "migrations": [ + { + "new_sqlite_classes": ["MyMCP", "Chat"], + "tag": "v1" + } + ] +} +``` + +#### Step 4: Set up your Worker fetch handler + +Route requests to your Chat agent: + +```typescript +import { routeAgentRequest } from "agents"; + +export default { + async fetch(request: Request, env: Env, ctx: ExecutionContext) { + const url = new URL(request.url); + + // Serve MCP server via streamable-http on /mcp endpoint + if (url.pathname.startsWith("/mcp")) { + return MyMCP.serve("/mcp").fetch(request, env, ctx); + } + + // Route other requests to agents + const response = await routeAgentRequest(request, env); + if (response) return response; + + return new Response("Not found", { status: 404 }); + } +}; +``` + +Optionally, you can also expose your MCP server via streamable-http. + +That's it! When your agent makes an MCP call, it: + +1. Serializes the JSON-RPC message +2. Calls `stub.handleMcpMessage(message)` over RPC +3. The `McpAgent` processes it and returns the response +4. Your agent receives the result - all without any network calls + +### Passing props from client to server + +Since RPC transport doesn't have an OAuth flow, you can pass user context (like userId, role, etc.) directly as props: + +```typescript +// Pass props to provide user context to the MCP server +await this.addMcpServer("my-mcp", this.env.MyMCP, { + transport: { type: "rpc", props: { userId: "user-123", role: "admin" } } +}); +``` + +Your `McpAgent` can then access these props: + +```typescript +export class MyMCP extends McpAgent< + Env, + State, + { userId?: string; role?: string } +> { + async init() { + this.server.tool("whoami", "Get current user info", {}, async () => { + const userId = this.props?.userId || "anonymous"; + const role = this.props?.role || "guest"; + + return { + content: [{ type: "text", text: `User ID: ${userId}, Role: ${role}` }] + }; + }); + } +} +``` + +The props are: + +- **Type-safe**: TypeScript extracts the Props type from your McpAgent generic +- **Persistent**: Stored in Durable Object storage via `updateProps()` +- **Available immediately**: Set before any tool calls are made + +This is useful for: + +- User authentication context +- Tenant/organization IDs +- Feature flags or permissions +- Any per-connection configuration + +### How RPC transport works under the hood + +When you call `addMcpServer()` with RPC transport, the SDK creates an RPC transport that calls the `handleMcpMessage()` method on your `McpAgent`: + +```typescript +// Built into the McpAgent base class +async handleMcpMessage( + message: JSONRPCMessage +): Promise { + // Recreate transport if needed (e.g., after hibernation) + if (!this._transport) { + const server = await this.server; + this._transport = new RPCServerTransport(); + await server.connect(this._transport); + } + + return await this._transport.handle(message); +} +``` + +This happens entirely within your Worker's execution context using Cloudflare's RPC mechanism - no HTTP, no WebSockets, no public internet. + +The RPC transport is minimal by design (~350 lines) and fully supports: + +- JSON-RPC 2.0 validation (with helpful error messages) +- Batch requests +- Notifications (messages without `id` field) +- Automatic reconnection after Durable Object hibernation + +### Configuring RPC Transport Server Timeout + +The RPC transport has a configurable timeout for waiting for tool responses. By default, the server will wait **60 seconds** for a tool handler to call `send()`. You can customize this by overriding the `getRpcTransportOptions()` method in your `McpAgent`: + +```typescript +export class MyMCP extends McpAgent { + server = new McpServer({ + name: "MyMCP", + version: "1.0.0" + }); + + // Configure RPC transport timeout + protected getRpcTransportOptions() { + return { + timeout: 120000 // 2 minutes (default is 60000) + }; + } + + async init() { + this.server.tool( + "long-running-task", + "A tool that takes a while to complete", + { input: z.string() }, + async ({ input }) => { + // This tool has up to 2 minutes to complete + await longRunningOperation(input); + return { + content: [{ type: "text", text: "Task completed" }] + }; + } + ); + } +} +``` + +The timeout ensures that if a tool handler fails to respond (e.g., due to an infinite loop or forgotten `send()` call), the request will fail with a clear timeout error rather than hanging indefinitely. + +### Advanced: Custom RPC function names + +By default, the RPC transport calls the `handleMcpMessage` function. You can customize this: + +```typescript +await this.addMcpServer("my-server", "MyMCP", { + transport: { type: "rpc", functionName: "customHandler" } +}); +``` + +Your `McpAgent` would then need to implement: + +```typescript +async customHandler( + message: JSONRPCMessage +): Promise { + // Your custom logic +} +``` + +## Choosing a transport + +| Transport | Use when | Pros | Cons | +| ------------------- | ------------------------------------- | ---------------------------------------- | ------------------------------- | +| **Streamable HTTP** | External MCP servers, production apps | Standard protocol, secure, supports auth | Slight network overhead | +| **RPC** | Internal agents | Fastest, simplest setup | No auth, Service Bindings only | +| **SSE** | Legacy compatibility | Backwards compatible | Deprecated, use Streamable HTTP | + +## Examples + +- **Streamable HTTP**: See [`examples/mcp`](../examples/mcp) +- **RPC Transport**: See [`examples/mcp-rpc-transport`](../examples/mcp-rpc-transport) +- **MCP Client**: See [`examples/mcp-client`](../examples/mcp-client) diff --git a/examples/mcp-rpc-transport/.env_example b/examples/mcp-rpc-transport/.env_example new file mode 100644 index 00000000..121e76c4 --- /dev/null +++ b/examples/mcp-rpc-transport/.env_example @@ -0,0 +1 @@ +OPENAI_API_KEY=your_openai_api_key_here \ No newline at end of file diff --git a/examples/mcp-rpc-transport/README.md b/examples/mcp-rpc-transport/README.md new file mode 100644 index 00000000..04e2cf87 --- /dev/null +++ b/examples/mcp-rpc-transport/README.md @@ -0,0 +1,87 @@ +# RPC Transport for MCP + +Example showing an `Agent` calling an `McpAgent` within the same worker using a custom RPC transport. + +## Why RPC Transport? + +If your MCP server and your agent/client are both deployed to the Cloudflare developer platform, our RPC transport is the fastest way to connect them: + +- **No network overhead** - Direct function calls instead of HTTP +- **Simpler** - No endpoints to configure, no connection management, no authentication. + +This is very useful for internal applications. You can define `tools`, `prompts` and `resources` in your MCP server, expose that publically to your users, and also power your own `Agent` from the same `McpAgent`. + +## How it works + +Both the agent (MCP client) and MCP server can exist in the same Worker. + +The MCP server is just a regular `McpAgent`: + +```typescript +export class MyMCP extends McpAgent { + server = new McpServer({ + name: "Demo", + version: "1.0.0" + }); + + async init() { + this.server.tool( + "add", + "Add to the counter, stored in the MCP", + { a: z.number() }, + async ({ a }) => { + this.setState({ ...this.state, counter: this.state.counter + a }); + return { + content: [ + { + text: `Added ${a}, total is now ${this.state.counter}`, + type: "text" + } + ] + }; + } + ); + } +} +``` + +The agent calls out to the MCP server using Cloudflare's RPC bindings: + +```typescript +export class Chat extends AIChatAgent { + async onStart(): Promise { + // Connect to MyMCP server via RPC + await this.addMcpServer("test-server", this.env.MyMCP, { + transport: { type: "rpc" } + }); + // Or pass the binding name as a string: + // await this.addMcpServer("test-server", "MyMCP", { transport: { type: "rpc" } }); + } + + async onChatMessage(onFinish: StreamTextOnFinishCallback) { + // MCP tools are now available + const allTools = this.mcp.getAITools(); + + const result = streamText({ + model, + tools: allTools + // ... + }); + } +} +``` + +## Instructions + +1. Copy `.dev.vars.example` to `.dev.vars` and add your OpenAI API key +2. Run `npm install` +3. Run `npm start` +4. Open the UI in your browser + +Try asking the AI to add numbers to the counter! + +## More Info + +Sevice bindings over RPC are commonly used in Workers to call out to other Cloudflare services. You can find out more [in the docs](https://developers.cloudflare.com/workers/runtime-apis/bindings/). + +The Model Context Protocol supports [pluggable transports](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports). The code for this custom RPC transport can be found [here](packages/agents/src/mcp/rpc.ts) diff --git a/examples/mcp-rpc-transport/index.html b/examples/mcp-rpc-transport/index.html new file mode 100644 index 00000000..c66282dc --- /dev/null +++ b/examples/mcp-rpc-transport/index.html @@ -0,0 +1,10 @@ + + + + MCP Example + + +
+ + + diff --git a/examples/mcp-rpc-transport/package.json b/examples/mcp-rpc-transport/package.json new file mode 100644 index 00000000..d453ed3d --- /dev/null +++ b/examples/mcp-rpc-transport/package.json @@ -0,0 +1,11 @@ +{ + "author": "Matt Carey ", + "keywords": [], + "name": "@cloudflare/agents-mcp-rpc-transport-demo", + "private": true, + "scripts": { + "start": "vite dev", + "deploy": "vite build && wrangler deploy" + }, + "type": "module" +} diff --git a/examples/mcp-rpc-transport/public/favicon.ico b/examples/mcp-rpc-transport/public/favicon.ico new file mode 100644 index 00000000..6d647f9d Binary files /dev/null and b/examples/mcp-rpc-transport/public/favicon.ico differ diff --git a/examples/mcp-rpc-transport/public/normalize.css b/examples/mcp-rpc-transport/public/normalize.css new file mode 100644 index 00000000..7f612e4b --- /dev/null +++ b/examples/mcp-rpc-transport/public/normalize.css @@ -0,0 +1,351 @@ +/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */ + +/* Document + ========================================================================== */ + +/** + * 1. Correct the line height in all browsers. + * 2. Prevent adjustments of font size after orientation changes in iOS. + */ + +html { + line-height: 1.15; /* 1 */ + -webkit-text-size-adjust: 100%; /* 2 */ +} + +/* Sections + ========================================================================== */ + +/** + * Remove the margin in all browsers. + */ + +body { + margin: 0; +} + +/** + * Render the `main` element consistently in IE. + */ + +main { + display: block; +} + +/** + * Correct the font size and margin on `h1` elements within `section` and + * `article` contexts in Chrome, Firefox, and Safari. + */ + +h1 { + font-size: 2em; + margin: 0.67em 0; +} + +/* Grouping content + ========================================================================== */ + +/** + * 1. Add the correct box sizing in Firefox. + * 2. Show the overflow in Edge and IE. + */ + +hr { + box-sizing: content-box; /* 1 */ + height: 0; /* 1 */ + overflow: visible; /* 2 */ +} + +/** + * 1. Correct the inheritance and scaling of font size in all browsers. + * 2. Correct the odd `em` font sizing in all browsers. + */ + +pre { + font-family: monospace; /* 1 */ + font-size: 1em; /* 2 */ +} + +/* Text-level semantics + ========================================================================== */ + +/** + * Remove the gray background on active links in IE 10. + */ + +a { + background-color: transparent; +} + +/** + * 1. Remove the bottom border in Chrome 57- + * 2. Add the correct text decoration in Chrome, Edge, IE, Opera, and Safari. + */ + +abbr[title] { + border-bottom: none; /* 1 */ + text-decoration: underline; /* 2 */ + text-decoration: underline dotted; /* 2 */ +} + +/** + * Add the correct font weight in Chrome, Edge, and Safari. + */ + +b, +strong { + font-weight: bolder; +} + +/** + * 1. Correct the inheritance and scaling of font size in all browsers. + * 2. Correct the odd `em` font sizing in all browsers. + */ + +code, +kbd, +samp { + font-family: monospace; /* 1 */ + font-size: 1em; /* 2 */ +} + +/** + * Add the correct font size in all browsers. + */ + +small { + font-size: 80%; +} + +/** + * Prevent `sub` and `sup` elements from affecting the line height in + * all browsers. + */ + +sub, +sup { + font-size: 75%; + line-height: 0; + position: relative; + vertical-align: baseline; +} + +sub { + bottom: -0.25em; +} + +sup { + top: -0.5em; +} + +/* Embedded content + ========================================================================== */ + +/** + * Remove the border on images inside links in IE 10. + */ + +img { + border-style: none; +} + +/* Forms + ========================================================================== */ + +/** + * 1. Change the font styles in all browsers. + * 2. Remove the margin in Firefox and Safari. + */ + +button, +input, +optgroup, +select, +textarea { + font-family: inherit; /* 1 */ + font-size: 100%; /* 1 */ + line-height: 1.15; /* 1 */ + margin: 0; /* 2 */ +} + +/** + * Show the overflow in IE. + * 1. Show the overflow in Edge. + */ + +button, +input { + /* 1 */ + overflow: visible; +} + +/** + * Remove the inheritance of text transform in Edge, Firefox, and IE. + * 1. Remove the inheritance of text transform in Firefox. + */ + +button, +select { + /* 1 */ + text-transform: none; +} + +/** + * Correct the inability to style clickable types in iOS and Safari. + */ + +button, +[type="button"], +[type="reset"], +[type="submit"] { + -webkit-appearance: button; +} + +/** + * Remove the inner border and padding in Firefox. + */ + +button::-moz-focus-inner, +[type="button"]::-moz-focus-inner, +[type="reset"]::-moz-focus-inner, +[type="submit"]::-moz-focus-inner { + border-style: none; + padding: 0; +} + +/** + * Restore the focus styles unset by the previous rule. + */ + +button:-moz-focusring, +[type="button"]:-moz-focusring, +[type="reset"]:-moz-focusring, +[type="submit"]:-moz-focusring { + outline: 1px dotted ButtonText; +} + +/** + * Correct the padding in Firefox. + */ + +fieldset { + padding: 0.35em 0.75em 0.625em; +} + +/** + * 1. Correct the text wrapping in Edge and IE. + * 2. Correct the color inheritance from `fieldset` elements in IE. + * 3. Remove the padding so developers are not caught out when they zero out + * `fieldset` elements in all browsers. + */ + +legend { + box-sizing: border-box; /* 1 */ + color: inherit; /* 2 */ + display: table; /* 1 */ + max-width: 100%; /* 1 */ + padding: 0; /* 3 */ + white-space: normal; /* 1 */ +} + +/** + * Add the correct vertical alignment in Chrome, Firefox, and Opera. + */ + +progress { + vertical-align: baseline; +} + +/** + * Remove the default vertical scrollbar in IE 10+. + */ + +textarea { + overflow: auto; +} + +/** + * 1. Add the correct box sizing in IE 10. + * 2. Remove the padding in IE 10. + */ + +[type="checkbox"], +[type="radio"] { + box-sizing: border-box; /* 1 */ + padding: 0; /* 2 */ +} + +/** + * Correct the cursor style of increment and decrement buttons in Chrome. + */ + +[type="number"]::-webkit-inner-spin-button, +[type="number"]::-webkit-outer-spin-button { + height: auto; +} + +/** + * 1. Correct the odd appearance in Chrome and Safari. + * 2. Correct the outline style in Safari. + */ + +[type="search"] { + -webkit-appearance: textfield; /* 1 */ + outline-offset: -2px; /* 2 */ +} + +/** + * Remove the inner padding in Chrome and Safari on macOS. + */ + +[type="search"]::-webkit-search-decoration { + -webkit-appearance: none; +} + +/** + * 1. Correct the inability to style clickable types in iOS and Safari. + * 2. Change font properties to `inherit` in Safari. + */ + +::-webkit-file-upload-button { + -webkit-appearance: button; /* 1 */ + font: inherit; /* 2 */ +} + +/* Interactive + ========================================================================== */ + +/* + * Add the correct display in Edge, IE 10+, and Firefox. + */ + +details { + display: block; +} + +/* + * Add the correct display in all browsers. + */ + +summary { + display: list-item; +} + +/* Misc + ========================================================================== */ + +/** + * Add the correct display in IE 10+. + */ + +template { + display: none; +} + +/** + * Add the correct display in IE 10. + */ + +[hidden] { + display: none; +} diff --git a/examples/mcp-rpc-transport/src/client.tsx b/examples/mcp-rpc-transport/src/client.tsx new file mode 100644 index 00000000..45894511 --- /dev/null +++ b/examples/mcp-rpc-transport/src/client.tsx @@ -0,0 +1,241 @@ +import "./styles.css"; +import { useAgent } from "agents/react"; +import { useAgentChat } from "agents/ai-react"; +import { createRoot } from "react-dom/client"; +import { useCallback, useEffect, useRef, useState } from "react"; +import type { MCPServersState } from "agents"; + +function App() { + const [theme, setTheme] = useState<"dark" | "light">(() => { + const savedTheme = localStorage.getItem("theme"); + return (savedTheme as "dark" | "light") || "dark"; + }); + + const [input, setInput] = useState(""); + const messagesEndRef = useRef(null); + const [mcpState, setMcpState] = useState({ + prompts: [], + resources: [], + servers: {}, + tools: [] + }); + const [showMcpServers, setShowMcpServers] = useState(false); + + const scrollToBottom = useCallback(() => { + messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); + }, []); + + useEffect(() => { + if (theme === "dark") { + document.documentElement.classList.add("dark"); + document.documentElement.classList.remove("light"); + } else { + document.documentElement.classList.remove("dark"); + document.documentElement.classList.add("light"); + } + localStorage.setItem("theme", theme); + }, [theme]); + + const toggleTheme = () => { + setTheme(theme === "dark" ? "light" : "dark"); + }; + + const openPopup = (authUrl: string) => { + window.open( + authUrl, + "popupWindow", + "width=600,height=800,resizable=yes,scrollbars=yes,toolbar=yes,menubar=no,location=no,directories=no,status=yes" + ); + }; + + const agent = useAgent({ + agent: "chat", + onMcpUpdate: (mcpServers: MCPServersState) => { + setMcpState(mcpServers); + } + }); + + const { messages, sendMessage, clearHistory } = useAgentChat({ + agent + }); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + if (!input.trim()) return; + + const message = input; + setInput(""); + + await sendMessage({ + role: "user", + parts: [{ type: "text", text: message }] + }); + }; + + useEffect(() => { + if (messages.length > 0) { + scrollToBottom(); + } + }, [messages, scrollToBottom]); + + return ( +
+ {/* Header */} +
+

Chat Agent

+
+ + + +
+
+ + {/* MCP Servers Panel */} + {showMcpServers && ( +
+

MCP Servers

+
+ {Object.entries(mcpState.servers).map(([id, server]) => ( +
+
+
{server.name}
+
+ {server.server_url} +
+
+
+ {server.state} +
+
+ {server.state === "authenticating" && server.auth_url && ( + + )} +
+ ))} +
+
+ )} + + {/* Messages */} +
+ {messages.length === 0 && ( +
+
+
💬
+

Welcome to Chat

+

+ Start a conversation with your AI assistant +

+
+
+ )} + + {messages.map((message) => ( +
+
+ {message.parts + ?.filter((part) => part.type === "text") + .map((part, i) => ( +
+ {part.text} +
+ ))} +
+
+ ))} +
+
+ + {/* Input */} +
+
+ setInput(e.target.value)} + placeholder="Type your message..." + className={`flex-1 p-3 rounded-lg border ${ + theme === "dark" + ? "bg-gray-700 border-gray-600 text-white placeholder-gray-400" + : "bg-white border-gray-300 text-black placeholder-gray-500" + } focus:outline-none focus:ring-2 focus:ring-blue-500`} + /> + +
+
+
+ ); +} + +createRoot(document.getElementById("root")!).render(); diff --git a/examples/mcp-rpc-transport/src/server.ts b/examples/mcp-rpc-transport/src/server.ts new file mode 100644 index 00000000..ffb5f741 --- /dev/null +++ b/examples/mcp-rpc-transport/src/server.ts @@ -0,0 +1,172 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { McpAgent } from "agents/mcp"; +import { z } from "zod"; +import { AIChatAgent } from "agents/ai-chat-agent"; +import { + streamText, + type StreamTextOnFinishCallback, + stepCountIs, + createUIMessageStream, + convertToModelMessages, + type ToolSet, + createUIMessageStreamResponse +} from "ai"; +import { openai } from "@ai-sdk/openai"; +import { cleanupMessages } from "./utils"; +import { routeAgentRequest } from "agents"; + +type State = { + counter: number; +}; + +type Props = { + userId: string; + role: string; +}; + +type Env = { + MyMCP: DurableObjectNamespace; +}; + +export class MyMCP extends McpAgent { + server = new McpServer({ + name: "Demo", + version: "1.0.0" + }); + + initialState: State = { + counter: 1 + }; + + async init() { + this.server.tool( + "add", + "Add to the counter, stored in the MCP", + { a: z.number() }, + async ({ a }) => { + this.setState({ ...this.state, counter: this.state.counter + a }); + + return { + content: [ + { + text: String(`Added ${a}, total is now ${this.state.counter}`), + type: "text" + } + ] + }; + } + ); + + this.server.tool( + "whoami", + "Get information about the current user from props", + {}, + async () => { + const userId = this.props?.userId || "anonymous"; + const role = this.props?.role || "guest"; + + return { + content: [ + { + text: `User ID: ${userId}, Role: ${role}`, + type: "text" + } + ] + }; + } + ); + } + + onStateUpdate(state: State) { + console.log({ stateUpdate: state }); + } + + onError(_: unknown, error?: unknown): void | Promise { + console.error("MyMCP initialization error:", error); + } +} + +const model = openai("gpt-4o-2024-11-20"); + +/** + * Chat Agent implementation that handles real-time AI chat interactions + */ +export class Chat extends AIChatAgent { + async onStart(): Promise { + // Connect to MCP server via RPC with props + // In a real app, you'd get userId/role from authentication + await this.addMcpServer("test-server", this.env.MyMCP, { + transport: { + type: "rpc", + props: { userId: "demo-user-123", role: "admin" } + } + }); + } + + /** + * Handles incoming chat messages and manages the response stream + */ + async onChatMessage( + onFinish: StreamTextOnFinishCallback, + _onFinish?: { abortSignal?: AbortSignal } + ) { + const allTools = this.mcp.getAITools(); + console.log("Available tools:", Object.keys(allTools)); + + const stream = createUIMessageStream({ + execute: async ({ writer }) => { + // Clean up incomplete tool calls to prevent API errors + const cleanedMessages = cleanupMessages(this.messages); + + const result = streamText({ + system: `You are a helpful assistant. The current date and time is ${new Date().toISOString()}.\n`, + messages: convertToModelMessages(cleanedMessages), + model, + tools: allTools, + onFinish: onFinish as unknown as StreamTextOnFinishCallback< + typeof allTools + >, + stopWhen: stepCountIs(10) + }); + + writer.merge(result.toUIMessageStream()); + } + }); + + return createUIMessageStreamResponse({ stream }); + } +} + +export default { + async fetch(request: Request, env: Env, ctx: ExecutionContext) { + const url = new URL(request.url); + console.log("Incoming request:", url.pathname); + + if (url.pathname === "/check-open-ai-key") { + const hasOpenAIKey = !!process.env.OPENAI_API_KEY; + return Response.json({ + success: hasOpenAIKey + }); + } + + if (!process.env.OPENAI_API_KEY) { + console.error( + "OPENAI_API_KEY is not set, don't forget to set it locally in .dev.vars, and use `wrangler secret bulk .dev.vars` to upload it to production" + ); + } + + // external mcp inspector route + if (url.pathname.startsWith("/mcp")) { + return MyMCP.serve("/mcp", { binding: "MyMCP" }).fetch(request, env, ctx); + } + + const response = await routeAgentRequest(request, env); + if (response) { + console.log("Agent handled request"); + return response; + } + + console.log("No route matched, returning 404"); + return new Response("Not found", { status: 404 }); + } +}; diff --git a/examples/mcp-rpc-transport/src/styles.css b/examples/mcp-rpc-transport/src/styles.css new file mode 100644 index 00000000..88956bf2 --- /dev/null +++ b/examples/mcp-rpc-transport/src/styles.css @@ -0,0 +1,265 @@ +/* Reset and base styles */ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: + -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen, Ubuntu, + Cantarell, sans-serif; + line-height: 1.5; + overflow: hidden; +} + +/* Utility classes */ +.h-screen { + height: 100vh; +} +.flex { + display: flex; +} +.flex-col { + flex-direction: column; +} +.flex-1 { + flex: 1; +} +.items-center { + align-items: center; +} +.justify-between { + justify-content: space-between; +} +.justify-center { + justify-content: center; +} +.justify-end { + justify-content: flex-end; +} +.justify-start { + justify-content: flex-start; +} +.gap-2 { + gap: 0.5rem; +} +.gap-4 { + gap: 1rem; +} +.space-y-4 > * + * { + margin-top: 1rem; +} + +.p-2 { + padding: 0.5rem; +} +.p-3 { + padding: 0.75rem; +} +.p-4 { + padding: 1rem; +} +.p-8 { + padding: 2rem; +} +.px-4 { + padding-left: 1rem; + padding-right: 1rem; +} +.px-6 { + padding-left: 1.5rem; + padding-right: 1.5rem; +} +.py-2 { + padding-top: 0.5rem; + padding-bottom: 0.5rem; +} +.py-3 { + padding-top: 0.75rem; + padding-bottom: 0.75rem; +} + +.mb-2 { + margin-bottom: 0.5rem; +} +.mb-4 { + margin-bottom: 1rem; +} + +.text-xl { + font-size: 1.25rem; +} +.text-4xl { + font-size: 2.25rem; +} +.font-semibold { + font-weight: 600; +} + +.rounded-lg { + border-radius: 0.5rem; +} +.border { + border-width: 1px; +} +.border-b { + border-bottom-width: 1px; +} +.border-t { + border-top-width: 1px; +} + +.overflow-y-auto { + overflow-y: auto; +} +.overflow-hidden { + overflow: hidden; +} + +.max-w-xs { + max-width: 20rem; +} +.max-w-md { + max-width: 28rem; +} + +.text-center { + text-align: center; +} +.whitespace-pre-wrap { + white-space: pre-wrap; +} + +.focus\:outline-none:focus { + outline: none; +} +.focus\:ring-2:focus { + box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.5); +} +.focus\:ring-blue-500:focus { + box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.5); +} + +.disabled\:opacity-50:disabled { + opacity: 0.5; +} +.disabled\:cursor-not-allowed:disabled { + cursor: not-allowed; +} + +/* Color themes */ +.bg-white { + background-color: white; +} +.bg-gray-50 { + background-color: #f9fafb; +} +.bg-gray-100 { + background-color: #f3f4f6; +} +.bg-gray-200 { + background-color: #e5e7eb; +} +.bg-gray-300 { + background-color: #d1d5db; +} +.bg-gray-600 { + background-color: #4b5563; +} +.bg-gray-700 { + background-color: #374151; +} +.bg-gray-800 { + background-color: #1f2937; +} +.bg-gray-900 { + background-color: #111827; +} + +.bg-blue-500 { + background-color: #3b82f6; +} +.bg-blue-600 { + background-color: #2563eb; +} +.bg-blue-700 { + background-color: #1d4ed8; +} + +.text-white { + color: white; +} +.text-black { + color: black; +} +.text-gray-400 { + color: #9ca3af; +} +.text-gray-500 { + color: #6b7280; +} +.text-gray-600 { + color: #4b5563; +} + +.border-gray-200 { + border-color: #e5e7eb; +} +.border-gray-300 { + border-color: #d1d5db; +} +.border-gray-600 { + border-color: #4b5563; +} +.border-gray-700 { + border-color: #374151; +} + +.placeholder-gray-400::placeholder { + color: #9ca3af; +} +.placeholder-gray-500::placeholder { + color: #6b7280; +} + +/* Hover states */ +.hover\:bg-gray-300:hover { + background-color: #d1d5db; +} +.hover\:bg-gray-600:hover { + background-color: #4b5563; +} +.hover\:bg-blue-700:hover { + background-color: #1d4ed8; +} + +/* Dark mode support */ +.dark .bg-gray-900 { + background-color: #111827; +} +.dark .bg-gray-800 { + background-color: #1f2937; +} +.dark .bg-gray-700 { + background-color: #374151; +} +.dark .text-white { + color: white; +} +.dark .text-gray-400 { + color: #9ca3af; +} +.dark .border-gray-700 { + border-color: #374151; +} +.dark .border-gray-600 { + border-color: #4b5563; +} + +/* Responsive */ +@media (min-width: 1024px) { + .lg\:max-w-md { + max-width: 28rem; + } +} diff --git a/examples/mcp-rpc-transport/src/utils.ts b/examples/mcp-rpc-transport/src/utils.ts new file mode 100644 index 00000000..eae2ff3e --- /dev/null +++ b/examples/mcp-rpc-transport/src/utils.ts @@ -0,0 +1,24 @@ +import type { UIMessage } from "ai"; +import { isToolUIPart } from "ai"; + +/** + * Clean up incomplete tool calls from messages before sending to API + * Prevents API errors from interrupted or failed tool executions + */ +export function cleanupMessages(messages: UIMessage[]): UIMessage[] { + return messages.filter((message) => { + if (!message.parts) return true; + + // Filter out messages with incomplete tool calls + const hasIncompleteToolCall = message.parts.some((part) => { + if (!isToolUIPart(part)) return false; + // Remove tool calls that are still streaming or awaiting input without results + return ( + part.state === "input-streaming" || + (part.state === "input-available" && !part.output && !part.errorText) + ); + }); + + return !hasIncompleteToolCall; + }); +} diff --git a/examples/mcp-rpc-transport/tsconfig.json b/examples/mcp-rpc-transport/tsconfig.json new file mode 100644 index 00000000..9536a0f4 --- /dev/null +++ b/examples/mcp-rpc-transport/tsconfig.json @@ -0,0 +1,3 @@ +{ + "extends": "../../tsconfig.base.json" +} diff --git a/examples/mcp-rpc-transport/vite.config.ts b/examples/mcp-rpc-transport/vite.config.ts new file mode 100644 index 00000000..a606eafd --- /dev/null +++ b/examples/mcp-rpc-transport/vite.config.ts @@ -0,0 +1,10 @@ +import { cloudflare } from "@cloudflare/vite-plugin"; +import react from "@vitejs/plugin-react"; +import { defineConfig } from "vite"; + +export default defineConfig({ + plugins: [react(), cloudflare()], + server: { + port: 5174 + } +}); diff --git a/examples/mcp-rpc-transport/wrangler.jsonc b/examples/mcp-rpc-transport/wrangler.jsonc new file mode 100644 index 00000000..2b57fbb7 --- /dev/null +++ b/examples/mcp-rpc-transport/wrangler.jsonc @@ -0,0 +1,31 @@ +{ + "compatibility_date": "2025-09-24", + "compatibility_flags": ["nodejs_compat"], + "assets": { + "not_found_handling": "single-page-application", + "run_worker_first": ["/agents/*"] + }, + "durable_objects": { + "bindings": [ + { + "class_name": "MyMCP", + "name": "MyMCP" + }, + { + "class_name": "Chat", + "name": "Chat" + } + ] + }, + "main": "src/server.ts", + "migrations": [ + { + "new_sqlite_classes": ["MyMCP", "Chat"], + "tag": "v1" + } + ], + "name": "mcp-rpc-transport-demo", + "observability": { + "enabled": true + } +} diff --git a/package-lock.json b/package-lock.json index b516e752..f25d67c0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -492,6 +492,9 @@ "node": "^18 || >=20" } }, + "examples/mcp-rpc-transport": { + "name": "@cloudflare/agents-mcp-rpc-transport-demo" + }, "examples/playground": { "name": "@cloudflare/agents-playground", "version": "0.0.0", @@ -519,6 +522,10 @@ "node": "^18 || >=20" } }, + "examples/rpc-transport": { + "name": "@cloudflare/agents-rpc-transport-demo", + "extraneous": true + }, "examples/tictactoe": { "name": "@cloudflare/agents-tictactoe", "version": "0.0.0", @@ -1851,6 +1858,10 @@ "resolved": "examples/mcp", "link": true }, + "node_modules/@cloudflare/agents-mcp-rpc-transport-demo": { + "resolved": "examples/mcp-rpc-transport", + "link": true + }, "node_modules/@cloudflare/agents-openai-basic": { "resolved": "openai-sdk/basic", "link": true diff --git a/packages/agents/src/index.ts b/packages/agents/src/index.ts index 89bee0c0..112fb45b 100644 --- a/packages/agents/src/index.ts +++ b/packages/agents/src/index.ts @@ -23,12 +23,25 @@ import { } from "partyserver"; import { camelCaseToKebabCase } from "./client"; import { MCPClientManager, type MCPClientOAuthResult } from "./mcp/client"; -import type { MCPConnectionState } from "./mcp/client-connection"; +import type { + MCPConnectionState, + MCPTransportOptions +} from "./mcp/client-connection"; import { DurableObjectOAuthClientProvider } from "./mcp/do-oauth-client-provider"; -import type { TransportType } from "./mcp/types"; +import type { + McpConnectionConfig, + McpHttpConnectionConfig, + McpRpcConnectionConfig, + TransportType +} from "./mcp/types"; import { genericObservability, type Observability } from "./observability"; import { DisposableStore } from "./core/events"; import { MessageType } from "./ai-types"; +import type { + McpAgent, + RpcConnectionOptions, + HttpConnectionOptions +} from "./mcp"; export type { Connection, ConnectionContext, WSMessage } from "partyserver"; @@ -665,18 +678,19 @@ export class Agent< }); servers.forEach((server) => { - this._connectToMcpServerInternal( - server.name, - server.server_url, - server.callback_url, - server.server_options + this._connectToMcpServerInternal({ + type: "http", + serverName: server.name, + url: server.server_url, + callbackUrl: server.callback_url, + options: server.server_options ? JSON.parse(server.server_options) : undefined, - { + reconnect: { id: server.id, oauthClientId: server.client_id ?? undefined } - ) + }) .then(() => { // Broadcast updated MCP servers state after each server connects this.broadcastMcpServers(); @@ -1405,29 +1419,148 @@ export class Agent< } /** - * Connect to a new MCP Server + * Validates and resolves a Durable Object binding from env + * @returns The namespace and binding name for storage + */ + private _resolveRpcBinding( + binding: DurableObjectNamespace | string + ): DurableObjectNamespace { + if (!binding) { + throw new Error(`Expected binding, received: '${binding}`); + } + + let namespace: DurableObjectNamespace; + if (typeof binding === "string") { + namespace = this.env[ + binding as keyof typeof this.env + ] as unknown as DurableObjectNamespace; + } else { + namespace = binding; + } + + if (!namespace || typeof namespace.get !== "function") { + throw new Error( + `Expected DurableObjectNamespace or binding name string, received: ${namespace}` + ); + } + + return namespace; + } + + /** + * Connect to an MCP server via RPC or HTTP/SSE transport + * + * RPC Transport (internal Durable Object communication): + * @param serverName Name of the MCP server + * @param binding Durable Object binding (e.g., env.MyMCP) or binding name (e.g., "MyMCP") + * @param options Options with transport.type = "rpc" + * @returns id (authUrl is undefined for RPC connections) * + * HTTP/SSE Transport (remote servers): * @param serverName Name of the MCP server - * @param url MCP Server SSE URL - * @param callbackHost Base host for the agent, used for the redirect URI. If not provided, will be derived from the current request. - * @param agentsPrefix agents routing prefix if not using `agents` + * @param url MCP Server URL (e.g., "https://example.com/mcp") + * @param callbackHost Base host for the agent, used for OAuth redirect URI + * @param agentsPrefix agents routing prefix (default: "agents") * @param options MCP client and transport options - * @returns authUrl + * @returns id and authUrl for OAuth flow + * + * @example + * ```typescript + * // RPC transport + * await agent.addMcpServer("my-server", env.MyMCP, { transport: { type: "rpc" } }); + * + * // RPC transport with props (pass user context to MCP server) + * await agent.addMcpServer("my-server", env.MyMCP, { + * transport: { type: "rpc" }, + * props: { userId: "user-123", role: "admin" } + * }); + * + * // HTTP/SSE transport + * await agent.addMcpServer("my-server", "https://example.com/mcp", "https://my-app.com"); + * ``` */ + async addMcpServer< + T extends McpAgent> = McpAgent + >( + serverName: string, + binding: DurableObjectNamespace | string, + options: RpcConnectionOptions + ): Promise<{ id: string; authUrl: undefined }>; + + // Overload for HTTP/SSE transport async addMcpServer( serverName: string, url: string, callbackHost?: string, + agentsPrefix?: string, + options?: HttpConnectionOptions + ): Promise<{ id: string; authUrl: string | undefined }>; + + async addMcpServer< + T extends McpAgent> = McpAgent + >( + serverName: string, + urlOrBinding: string | DurableObjectNamespace, + callbackHostOrOptions?: string | RpcConnectionOptions, agentsPrefix = "agents", - options?: { - client?: ConstructorParameters[1]; - transport?: { - headers?: HeadersInit; - type?: TransportType; - }; - } + options?: HttpConnectionOptions ): Promise<{ id: string; authUrl: string | undefined }> { - // If callbackHost is not provided, derive it from the current request + // Determine if this is RPC or HTTP based on parameters + const isRpc = + typeof callbackHostOrOptions === "object" && + callbackHostOrOptions.transport?.type === "rpc"; + + if (isRpc) { + // RPC transport path + const rpcOptions = callbackHostOrOptions as RpcConnectionOptions; + + const namespace = this._resolveRpcBinding( + urlOrBinding as DurableObjectNamespace | string + ); + + const normalizedName = serverName.toLowerCase().replace(/\s+/g, "-"); + const url = `rpc://${normalizedName}`; + + // Check if server already exists in database for reconnection + const existingServer = this.sql` + SELECT id FROM cf_agents_mcp_servers WHERE server_url = ${url} LIMIT 1; + `.at(0); + + const reconnect = existingServer ? { id: existingServer.id } : undefined; + + // Connect to server + const result = await this._connectToMcpServerInternal({ + type: "rpc", + namespace, + url, + normalizedName, + options: rpcOptions, + reconnect + }); + + // Persist to database for reconnection purposes + this.sql` + INSERT OR REPLACE INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + VALUES ( + ${result.id}, + ${normalizedName}, + ${url}, + ${result.clientId ?? null}, + ${result.authUrl ?? null}, + ${"rpc://internal"}, + ${rpcOptions ? JSON.stringify(rpcOptions) : null} + ); + `; + + this.broadcastMcpServers(); + + return result as { id: string; authUrl: undefined }; + } + + // HTTP/SSE transport path (original implementation) + const url = urlOrBinding as string; + const callbackHost = callbackHostOrOptions as string | undefined; + // Resolve callback host if not provided let resolvedCallbackHost = callbackHost; if (!resolvedCallbackHost) { const { request } = getCurrentAgent(); @@ -1436,49 +1569,34 @@ export class Agent< "callbackHost is required when not called within a request context" ); } - - // Extract the origin from the request const requestUrl = new URL(request.url); resolvedCallbackHost = `${requestUrl.protocol}//${requestUrl.host}`; } - const callbackUrl = `${resolvedCallbackHost}/${agentsPrefix}/${camelCaseToKebabCase(this._ParentClass.name)}/${this.name}/callback`; - // Generate a serverId upfront - const serverId = nanoid(8); + // Connect to server + const result = await this._connectToMcpServerInternal({ + type: "http", + serverName, + url, + callbackUrl, + options + }); - // Persist to database BEFORE starting OAuth flow to survive DO hibernation + // Persist to database for reconnection purposes this.sql` - INSERT OR REPLACE INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) + INSERT OR REPLACE INTO cf_agents_mcp_servers (id, name, server_url, client_id, auth_url, callback_url, server_options) VALUES ( - ${serverId}, + ${result.id}, ${serverName}, ${url}, - ${null}, - ${null}, + ${result.clientId ?? null}, + ${result.authUrl ?? null}, ${callbackUrl}, ${options ? JSON.stringify(options) : null} - ); + ); `; - // _connectToMcpServerInternal will call mcp.connect which registers the callback URL - const result = await this._connectToMcpServerInternal( - serverName, - url, - callbackUrl, - options, - { id: serverId } - ); - - // Update database with OAuth client info if auth is required - if (result.clientId || result.authUrl) { - this.sql` - UPDATE cf_agents_mcp_servers - SET client_id = ${result.clientId ?? null}, auth_url = ${result.authUrl ?? null} - WHERE id = ${serverId} - `; - } - this.broadcastMcpServers(); return result; @@ -1572,16 +1690,17 @@ export class Agent< ); } - await this._connectToMcpServerInternal( - server.name, - server.server_url, - server.callback_url, - parsedOptions, - { + await this._connectToMcpServerInternal({ + type: "http", + serverName: server.name, + url: server.server_url, + callbackUrl: server.callback_url, + options: parsedOptions, + reconnect: { id: server.id, oauthClientId: server.client_id ?? undefined } - ); + }); } // Now process the OAuth callback @@ -1620,34 +1739,29 @@ export class Agent< return this.handleOAuthCallbackResponse(result, request); } - private async _connectToMcpServerInternal( - _serverName: string, - url: string, - callbackUrl: string, - // it's important that any options here are serializable because we put them into our sqlite DB for reconnection purposes - options?: { - client?: ConstructorParameters[1]; - /** - * We don't expose the normal set of transport options because: - * 1) we can't serialize things like the auth provider or a fetch function into the DB for reconnection purposes - * 2) We probably want these options to be agnostic to the transport type (SSE vs Streamable) - * - * This has the limitation that you can't override fetch, but I think headers should handle nearly all cases needed (i.e. non-standard bearer auth). - */ - transport?: { - headers?: HeadersInit; - type?: TransportType; - }; - }, - reconnect?: { - id: string; - oauthClientId?: string; - } - ): Promise<{ - id: string; - authUrl: string | undefined; - clientId: string | undefined; - }> { + private async _buildRpcTransportOptions( + config: McpRpcConnectionConfig + ): Promise { + const { normalizedName, namespace, options } = config; + + const doName = `rpc:${normalizedName}`; + const doId = namespace.idFromName(doName); + const stub = namespace.get(doId) as unknown as DurableObjectStub; + + return { + type: "rpc", + stub, + functionName: options?.transport?.functionName, + doName, + props: options?.transport?.props + }; + } + + private _buildHttpTransportOptions( + config: McpHttpConnectionConfig + ): MCPTransportOptions { + const { callbackUrl, options, reconnect } = config; + const authProvider = new DurableObjectOAuthClientProvider( this.ctx.storage, this.name, @@ -1661,44 +1775,51 @@ export class Agent< } } - // Use the transport type specified in options, or default to "auto" const transportType: TransportType = options?.transport?.type ?? "auto"; - - // allows passing through transport headers if necessary - // this handles some non-standard bearer auth setups (i.e. MCP server behind CF access instead of OAuth) - let headerTransportOpts: SSEClientTransportOptions = {}; - if (options?.transport?.headers) { - headerTransportOpts = { - eventSourceInit: { - fetch: (url, init) => - fetch(url, { - ...init, - headers: options?.transport?.headers - }) - }, - requestInit: { - headers: options?.transport?.headers + const headerTransportOpts: SSEClientTransportOptions = options?.transport + ?.headers + ? { + eventSourceInit: { + fetch: (url, init) => + fetch(url, { + ...init, + headers: options.transport!.headers + }) + }, + requestInit: { + headers: options.transport!.headers + } } - }; - } - - const { id, authUrl, clientId } = await this.mcp.connect(url, { - client: options?.client, - reconnect, - transport: { - ...headerTransportOpts, - authProvider, - type: transportType - } - }); + : {}; return { - authUrl, - clientId, - id + ...headerTransportOpts, + authProvider, + type: transportType }; } + private async _connectToMcpServerInternal( + config: McpConnectionConfig + ): Promise<{ + id: string; + authUrl: string | undefined; + clientId: string | undefined; + }> { + const transportOptions = + config.type === "rpc" + ? await this._buildRpcTransportOptions(config) + : this._buildHttpTransportOptions(config); + + const { id, authUrl, clientId } = await this.mcp.connect(config.url, { + client: config.options?.client, + reconnect: config.reconnect, + transport: transportOptions + }); + + return { id, authUrl, clientId }; + } + async removeMcpServer(id: string) { this.mcp.closeConnection(id); this.mcp.unregisterCallbackUrl(id); diff --git a/packages/agents/src/mcp/client-connection.ts b/packages/agents/src/mcp/client-connection.ts index 2536c241..85d263b9 100644 --- a/packages/agents/src/mcp/client-connection.ts +++ b/packages/agents/src/mcp/client-connection.ts @@ -33,7 +33,13 @@ import { } from "./errors"; import { SSEEdgeClientTransport } from "./sse-edge"; import { StreamableHTTPEdgeClientTransport } from "./streamable-http-edge"; -import type { BaseTransportType, TransportType } from "./types"; +import { RPCClientTransport, type RPCClientTransportOptions } from "./rpc"; +import type { + BaseTransportType, + TransportType, + McpClientOptions, + HttpTransportType +} from "./types"; /** * Connection state for MCP client connections @@ -45,11 +51,18 @@ export type MCPConnectionState = | "discovering" | "failed"; +/** + * Transport options for MCP client connections. + * Combines transport-specific options with auth provider and type selection. + */ export type MCPTransportOptions = ( | SSEClientTransportOptions | StreamableHTTPClientTransportOptions + | RPCClientTransportOptions ) & { + /** Optional OAuth provider for authenticating with the MCP server */ authProvider?: AgentsOAuthProvider; + /** The transport type to use. "auto" will try streamable-http, then fall back to SSE */ type?: TransportType; }; @@ -73,7 +86,7 @@ export class MCPClientConnection { info: ConstructorParameters[0], public options: { transport: MCPTransportOptions; - client: ConstructorParameters[1]; + client: McpClientOptions; } = { client: {}, transport: {} } ) { const clientOptions = { @@ -142,11 +155,20 @@ export class MCPClientConnection { throw new Error("Transport type must be specified"); } - const finishAuth = async (base: BaseTransportType) => { + const finishAuth = async (base: HttpTransportType) => { const transport = this.getTransport(base); - await transport.finishAuth(code); + if ( + "finishAuth" in transport && + typeof transport.finishAuth === "function" + ) { + await transport.finishAuth(code); + } }; + if (configuredType === "rpc") { + throw new Error("RPC transport does not support authentication"); + } + if (configuredType === "sse" || configuredType === "streamable-http") { await finishAuth(configuredType); return; @@ -434,6 +456,10 @@ export class MCPClientConnection { this.url, this.options.transport as SSEClientTransportOptions ); + case "rpc": + return new RPCClientTransport( + this.options.transport as RPCClientTransportOptions + ); default: throw new Error(`Unsupported transport type: ${transportType}`); } diff --git a/packages/agents/src/mcp/client.ts b/packages/agents/src/mcp/client.ts index 202383ab..66988d40 100644 --- a/packages/agents/src/mcp/client.ts +++ b/packages/agents/src/mcp/client.ts @@ -476,10 +476,11 @@ export class MCPClientManager { | typeof CompatibilityCallToolResultSchema, options?: RequestOptions ) { - const unqualifiedName = params.name.replace(`${params.serverId}.`, ""); - return this.mcpConnections[params.serverId].client.callTool( + const { serverId, ...mcpParams } = params; + const unqualifiedName = mcpParams.name.replace(`${serverId}.`, ""); + return this.mcpConnections[serverId].client.callTool( { - ...params, + ...mcpParams, name: unqualifiedName }, resultSchema, diff --git a/packages/agents/src/mcp/index.ts b/packages/agents/src/mcp/index.ts index 87e51de4..212d7b81 100644 --- a/packages/agents/src/mcp/index.ts +++ b/packages/agents/src/mcp/index.ts @@ -20,6 +20,7 @@ import { MCP_MESSAGE_HEADER } from "./utils"; import { McpSSETransport, StreamableHTTPServerTransport } from "./transport"; +import { RPCServerTransport, type RPCServerTransportOptions } from "./rpc"; export abstract class McpAgent< Env = unknown, @@ -45,8 +46,8 @@ export abstract class McpAgent< } /** Read the transport type for this agent. - * This relies on the naming scheme being `sse:${sessionId}` - * or `streamable-http:${sessionId}`. + * This relies on the naming scheme being `sse:${sessionId}`, + * `streamable-http:${sessionId}`, or `rpc:${sessionId}`. */ getTransportType(): BaseTransportType { const [t, ..._] = this.name.split(":"); @@ -55,6 +56,8 @@ export abstract class McpAgent< return "sse"; case "streamable-http": return "streamable-http"; + case "rpc": + return "rpc"; default: throw new Error( "Invalid transport type. McpAgent must be addressed with a valid protocol." @@ -85,6 +88,23 @@ export abstract class McpAgent< return websockets[0]; } + /** + * Returns options for configuring the RPC server transport. + * Override this method to customize RPC transport behavior (e.g., timeout). + * + * @example + * ```typescript + * class MyMCP extends McpAgent { + * protected getRpcTransportOptions() { + * return { timeout: 120000 }; // 2 minutes + * } + * } + * ``` + */ + protected getRpcTransportOptions(): RPCServerTransportOptions { + return {}; + } + /** Returns a new transport matching the type of the Agent. */ private initTransport() { switch (this.getTransportType()) { @@ -94,6 +114,9 @@ export abstract class McpAgent< case "streamable-http": { return new StreamableHTTPServerTransport({}); } + case "rpc": { + return new RPCServerTransport(this.getRpcTransportOptions()); + } } } @@ -114,7 +137,7 @@ export abstract class McpAgent< } /* - * Base Agent / Parykit Server overrides + * Base Agent / Partykit Server overrides */ /** Sets up the MCP transport and server every time the Agent is started.*/ @@ -127,7 +150,12 @@ export abstract class McpAgent< const server = await this.server; // Connect to the MCP server this._transport = this.initTransport(); + + if (!this._transport) { + throw new Error("Failed to initialize transport"); + } await server.connect(this._transport); + await this.reinitializeServer(); } @@ -350,6 +378,40 @@ export abstract class McpAgent< return false; } + /** + * Handle an RPC message for MCP + * This method is called by the RPC stub to process MCP messages + * @param message The JSON-RPC message(s) to handle + * @returns The response message(s) or undefined + */ + async handleMcpMessage( + message: JSONRPCMessage | JSONRPCMessage[] + ): Promise { + if (!this._transport) { + this.props = await this.ctx.storage.get("props"); + + // Re-run init() to register tools on the server + await this.init(); + const server = await this.server; + + this._transport = this.initTransport(); + + if (!this._transport) { + throw new Error("Failed to initialize transport"); + } + await server.connect(this._transport); + + // Reinitialize the server with any stored initialize request + await this.reinitializeServer(); + } + + if (!(this._transport instanceof RPCServerTransport)) { + throw new Error("Expected RPC transport"); + } + + return await this._transport.handle(message); + } + /** Return a handler for the given path for this MCP. * Defaults to Streamable HTTP transport. */ @@ -436,6 +498,13 @@ export abstract class McpAgent< // Export client transport classes export { SSEEdgeClientTransport } from "./sse-edge"; export { StreamableHTTPEdgeClientTransport } from "./streamable-http-edge"; +export { + RPCClientTransport, + RPCServerTransport, + type MCPStub, + type RPCClientTransportOptions, + type RPCServerTransportOptions +} from "./rpc"; // Export elicitation types and schemas export { @@ -449,3 +518,12 @@ export type { MCPClientOAuthResult, MCPClientOAuthCallbackConfig } from "./client"; + +// Export connection configuration types +export type { + RpcConnectionOptions, + HttpConnectionOptions, + McpClientOptions, + RpcTransportOptions, + HttpTransportOptions +} from "./types"; diff --git a/packages/agents/src/mcp/rpc.ts b/packages/agents/src/mcp/rpc.ts new file mode 100644 index 00000000..38015b99 --- /dev/null +++ b/packages/agents/src/mcp/rpc.ts @@ -0,0 +1,608 @@ +import type { + Transport, + TransportSendOptions +} from "@modelcontextprotocol/sdk/shared/transport.js"; +import type { + JSONRPCMessage, + MessageExtraInfo +} from "@modelcontextprotocol/sdk/types.js"; +import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; + +/** + * Validates a JSON-RPC 2.0 batch request + * @see JSON-RPC 2.0 spec section 6 + */ +function validateJSONRPCBatch(batch: unknown): batch is JSONRPCMessage[] { + if (!Array.isArray(batch)) { + throw new Error("Invalid JSON-RPC batch: must be an array"); + } + + // Spec: "an Array with at least one value" + if (batch.length === 0) { + throw new Error("Invalid JSON-RPC batch: array must not be empty"); + } + + // Validate each message in the batch + for (let i = 0; i < batch.length; i++) { + try { + validateJSONRPCMessage(batch[i]); + } catch (error) { + throw new Error( + `Invalid JSON-RPC batch: message at index ${i} is invalid: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + + return true; +} + +/** + * Validates that a message conforms to JSON-RPC 2.0 specification + * @see https://www.jsonrpc.org/specification + * @see /packages/agents/src/mcp/json-rpc-spec.md + */ +function validateJSONRPCMessage(message: unknown): message is JSONRPCMessage { + if (!message || typeof message !== "object") { + throw new Error("Invalid JSON-RPC message: must be an object"); + } + + const msg = message as Record; + + // Spec line 26: jsonrpc MUST be exactly "2.0" + if (msg.jsonrpc !== "2.0") { + throw new Error('Invalid JSON-RPC message: jsonrpc field must be "2.0"'); + } + + // Check if it's a request/notification (has method field) + if ("method" in msg) { + // Spec line 27-28: method MUST be a String + if (typeof msg.method !== "string") { + throw new Error("Invalid JSON-RPC request: method must be a string"); + } + + // Spec line 28: Method names starting with "rpc." are reserved + if (msg.method.startsWith("rpc.")) { + throw new Error( + 'Invalid JSON-RPC request: method names starting with "rpc." are reserved for internal methods' + ); + } + + // Spec line 31-32: id MAY be omitted (notification), but if included MUST be String, Number, or NULL + if ( + "id" in msg && + msg.id !== null && + typeof msg.id !== "string" && + typeof msg.id !== "number" + ) { + throw new Error( + "Invalid JSON-RPC request: id must be string, number, or null" + ); + } + + // Spec line 32: Warn about fractional numbers in id (SHOULD NOT have fractional parts) + if (typeof msg.id === "number" && !Number.isInteger(msg.id)) { + console.warn("JSON-RPC warning: id should not contain fractional parts"); + } + + // Spec line 29-30, 45-48: params MAY be omitted, but if present MUST be Array or Object (Structured value) + if ("params" in msg && msg.params !== undefined) { + const params = msg.params; + if (params !== null && typeof params !== "object") { + throw new Error( + "Invalid JSON-RPC request: params must be an array or object" + ); + } + // params can be an object or array, but not other types + if ( + params === null || + (typeof params === "object" && + !Array.isArray(params) && + Object.getPrototypeOf(params) !== Object.prototype) + ) { + throw new Error( + "Invalid JSON-RPC request: params must be an array or object" + ); + } + } + + return true; + } + + // Check if it's a response (has id but no method) + if ("id" in msg) { + // Spec line 63: id is REQUIRED in responses + // Spec line 64-65: id MUST be same as request, or NULL on parse/invalid request error + if ( + msg.id !== null && + typeof msg.id !== "string" && + typeof msg.id !== "number" + ) { + throw new Error( + "Invalid JSON-RPC response: id must be string, number, or null" + ); + } + + // Spec line 66: Either result or error MUST be included, but both MUST NOT be included + const hasResult = "result" in msg; + const hasError = "error" in msg; + + if (!hasResult && !hasError) { + throw new Error( + "Invalid JSON-RPC response: must have either result or error" + ); + } + + if (hasResult && hasError) { + throw new Error( + "Invalid JSON-RPC response: cannot have both result and error" + ); + } + + // Spec line 68-80: Validate error object structure if present + if (hasError) { + const error = msg.error as Record; + if (!error || typeof error !== "object") { + throw new Error("Invalid JSON-RPC error: error must be an object"); + } + // Spec line 71-73: code MUST be a Number (integer) + if (typeof error.code !== "number") { + throw new Error("Invalid JSON-RPC error: code must be a number"); + } + if (!Number.isInteger(error.code)) { + throw new Error("Invalid JSON-RPC error: code must be an integer"); + } + // Spec line 74-76: message MUST be a String + if (typeof error.message !== "string") { + throw new Error("Invalid JSON-RPC error: message must be a string"); + } + // Spec line 77-80: data MAY be omitted, but if present can be any Primitive or Structured value + // (no validation needed - any type is allowed) + } + + return true; + } + + throw new Error( + "Invalid JSON-RPC message: must have either method (request/notification) or id (response)" + ); +} + +/** + * Type for RPC handler functions that can process MCP messages + */ +export type MCPMessageHandler = ( + message: JSONRPCMessage | JSONRPCMessage[] +) => Promise; + +/** + * Base interface for objects that can handle MCP messages via RPC + */ +export interface MCPStub extends Record { + handleMcpMessage: MCPMessageHandler; + setName?(name: string): Promise; + updateProps?(props: Record): Promise; +} + +export interface RPCClientTransportOptions { + stub: MCPStub; + functionName?: string; + doName?: string; + props?: Record; +} + +export class RPCClientTransport implements Transport { + private _stub: MCPStub; + private _functionName: string; + private _doName?: string; + private _props?: Record; + private _propsInitialized = false; + private _started = false; + private _protocolVersion?: string; + + sessionId?: string; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + + constructor(options: RPCClientTransportOptions) { + this._stub = options.stub; + this._functionName = options.functionName ?? "handleMcpMessage"; + this._doName = options.doName; + this._props = options.props; + } + + setProtocolVersion(version: string): void { + this._protocolVersion = version; + } + + getProtocolVersion(): string | undefined { + return this._protocolVersion; + } + + async start(): Promise { + if (this._started) { + throw new Error("Transport already started"); + } + this._started = true; + } + + async close(): Promise { + this._started = false; + this.onclose?.(); + } + + async send( + message: JSONRPCMessage | JSONRPCMessage[], + options?: TransportSendOptions + ): Promise { + if (!this._started) { + throw new Error("Transport not started"); + } + + // Validate batch or single message + if (Array.isArray(message)) { + validateJSONRPCBatch(message); + } else { + validateJSONRPCMessage(message); + } + + // Set the name if the stub is a DO + if (this._doName && this._stub.setName) { + await this._stub.setName(this._doName); + } + + // Initialize props on first send + if (this._props && !this._propsInitialized && this._stub.updateProps) { + await this._stub.updateProps(this._props); + this._propsInitialized = true; + } + + try { + const handler = this._stub[this._functionName] as MCPMessageHandler; + const result = await handler(message); + + if (!result) { + return; + } + + // Prepare MessageExtraInfo if relatedRequestId is provided + const extra: MessageExtraInfo | undefined = options?.relatedRequestId + ? { requestInfo: { headers: {} } } + : undefined; + + if (Array.isArray(result)) { + for (const msg of result) { + validateJSONRPCMessage(msg); + this.onmessage?.(msg, extra); + } + } else { + validateJSONRPCMessage(result); + this.onmessage?.(result, extra); + } + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + } +} + +/** + * Configuration options for RPCServerTransport + * + * Session Management: + * - Stateless mode (default): No sessionIdGenerator provided. All requests are accepted without validation. + * - Stateful mode: When sessionIdGenerator is provided, the server enforces session initialization. + * - Clients must send an initialize request first to establish a session + * - All subsequent requests are validated to ensure session is initialized + * - Session ID is generated during initialization and available via transport.sessionId + */ +export interface RPCServerTransportOptions { + /** + * Function that generates a session ID for the transport. + * The session ID SHOULD be globally unique and cryptographically secure (e.g., a securely generated UUID, a JWT, or a cryptographic hash) + * + * When provided, enables stateful session management: + * - Session is created during MCP initialization request + * - Non-initialization requests will be rejected until session is initialized + * - Session can be terminated via terminateSession() or transport.close() + * + * When omitted, transport operates in stateless mode (no session validation). + */ + sessionIdGenerator?: (() => string) | undefined; + + /** + * A callback for session initialization events. + * Called after a session ID is generated during MCP initialization. + * + * @param sessionId The generated session ID + * + * @example + * ```typescript + * const transport = new RPCServerTransport({ + * sessionIdGenerator: () => crypto.randomUUID(), + * onsessioninitialized: async (sessionId) => { + * console.log(`Session ${sessionId} initialized`); + * await database.createSession(sessionId); + * } + * }); + * ``` + */ + onsessioninitialized?: (sessionId: string) => void | Promise; + + /** + * A callback for session close events. + * Called when the session is terminated via terminateSession() or transport.close(). + * + * @param sessionId The session ID that was closed + * + * @example + * ```typescript + * const transport = new RPCServerTransport({ + * sessionIdGenerator: () => crypto.randomUUID(), + * onsessionclosed: async (sessionId) => { + * console.log(`Session ${sessionId} closed`); + * await database.deleteSession(sessionId); + * } + * }); + * ``` + */ + onsessionclosed?: (sessionId: string) => void | Promise; + + /** + * Timeout in milliseconds for waiting for a response from the onmessage handler. + * If the handler doesn't call send() within this time, the request will fail with a timeout error. + * + * @default 60000 (60 seconds) + * + * @example + * ```typescript + * const transport = new RPCServerTransport({ + * timeout: 30000 // 30 seconds + * }); + * ``` + */ + timeout?: number; +} + +export class RPCServerTransport implements Transport { + private _started = false; + private _pendingResponse: JSONRPCMessage | JSONRPCMessage[] | null = null; + private _responseResolver: (() => void) | null = null; + private _currentRequestId: string | number | null = null; + private _protocolVersion?: string; + private _sessionIdGenerator: (() => string) | undefined; + private _initialized = false; + private _onsessioninitialized?: (sessionId: string) => void | Promise; + private _onsessionclosed?: (sessionId: string) => void | Promise; + private _timeout: number; + + sessionId?: string; + onclose?: () => void; + onerror?: (error: Error) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + + constructor(options?: RPCServerTransportOptions) { + this._sessionIdGenerator = options?.sessionIdGenerator; + this._onsessioninitialized = options?.onsessioninitialized; + this._onsessionclosed = options?.onsessionclosed; + this._timeout = options?.timeout ?? 60000; // Default 60 seconds + } + + setProtocolVersion(version: string): void { + this._protocolVersion = version; + } + + getProtocolVersion(): string | undefined { + return this._protocolVersion; + } + + async start(): Promise { + if (this._started) { + throw new Error("Transport already started"); + } + this._started = true; + } + + async close(): Promise { + this._started = false; + + // Terminate session if it exists + await this.terminateSession(); + + this.onclose?.(); + // Resolve any pending response promises + if (this._responseResolver) { + this._responseResolver(); + this._responseResolver = null; + } + this._currentRequestId = null; + } + + async send( + message: JSONRPCMessage, + _options?: TransportSendOptions + ): Promise { + if (!this._started) { + throw new Error("Transport not started"); + } + + validateJSONRPCMessage(message); + + // Validate response IDs match the request ID (JSON-RPC 2.0 spec section 5) + const isResponse = "id" in message && !("method" in message); + if (isResponse && this._currentRequestId !== null) { + const responseId = (message as { id: string | number | null }).id; + if (responseId !== this._currentRequestId) { + throw new Error( + `Response ID ${responseId} does not match request ID ${this._currentRequestId} (JSON-RPC 2.0 spec section 5)` + ); + } + } + + if (!this._pendingResponse) { + this._pendingResponse = message; + } else if (Array.isArray(this._pendingResponse)) { + this._pendingResponse.push(message); + } else { + this._pendingResponse = [this._pendingResponse, message]; + } + + // Resolve the promise on the next tick to allow multiple send() calls to accumulate + if (this._responseResolver) { + const resolver = this._responseResolver; + // Use queueMicrotask to allow additional send() calls to accumulate + // Resolver is reused for concurrent sends within the same tick + queueMicrotask(() => resolver()); + } + } + + /** + * Validates that the session is initialized for non-initialization requests + */ + private _validateSession(message: JSONRPCMessage): void { + // If we're in stateless mode (no session ID generator), skip validation + if (!this._sessionIdGenerator) { + return; + } + + // If this is an initialization request, don't validate session yet + if (isInitializeRequest(message)) { + return; + } + + // For all other requests, ensure the session is initialized + if (!this._initialized) { + throw new Error( + "Session not initialized. An initialize request must be sent first." + ); + } + } + + /** + * Terminates the current session and calls the session closed callback + */ + async terminateSession(): Promise { + if (this.sessionId && this._onsessionclosed) { + await this._onsessionclosed(this.sessionId); + } + this._initialized = false; + this.sessionId = undefined; + } + + async handle( + message: JSONRPCMessage | JSONRPCMessage[] + ): Promise { + if (!this._started) { + throw new Error("Transport not started"); + } + + // Handle batch requests (JSON-RPC 2.0 spec section 6) + if (Array.isArray(message)) { + validateJSONRPCBatch(message); + + const responses: JSONRPCMessage[] = []; + + // Process each message in the batch + for (const msg of message) { + const response = await this.handle(msg); + // Spec: "A Response object SHOULD exist for each Request object, + // except that there SHOULD NOT be any Response objects for notifications" + if (response !== undefined) { + if (Array.isArray(response)) { + responses.push(...response); + } else { + responses.push(response); + } + } + } + + // Spec: "If there are no Response objects contained within the Response array + // as it is to be sent to the client, the server MUST NOT return an empty Array + // and should return nothing at all" + if (responses.length === 0) { + return undefined; + } + + return responses; + } + + // Handle single message + validateJSONRPCMessage(message); + + // Session management: validate session for non-initialization requests + this._validateSession(message); + + // Session management: handle initialization requests + if (isInitializeRequest(message) && this._sessionIdGenerator) { + if (this._initialized) { + throw new Error("Session already initialized"); + } + + // Generate session ID + this.sessionId = this._sessionIdGenerator(); + this._initialized = true; + + // Call session initialized callback + if (this._onsessioninitialized) { + await this._onsessioninitialized(this.sessionId); + } + } + + this._pendingResponse = null; + + const isNotification = !("id" in message); + if (isNotification) { + // notifications do not get responses + this.onmessage?.(message); + return undefined; + } + + // Store the request ID to validate responses (JSON-RPC 2.0 spec section 5) + this._currentRequestId = (message as { id: string | number | null }).id; + + // Set up the promise before calling onmessage to handle race conditions + let timeoutId: ReturnType | null = null; + const responsePromise = new Promise((resolve, reject) => { + // Set up timeout + timeoutId = setTimeout(() => { + this._responseResolver = null; + reject( + new Error( + `Request timeout: No response received within ${this._timeout}ms for request ID ${this._currentRequestId}` + ) + ); + }, this._timeout); + + // Wrap the resolver to clear timeout when response is received + // Note: Don't null out here - send() needs it to remain available for concurrent calls + this._responseResolver = () => { + if (timeoutId) { + clearTimeout(timeoutId); + timeoutId = null; + } + // Null out after resolution to prevent reuse across different requests + this._responseResolver = null; + resolve(); + }; + }); + + this.onmessage?.(message); + + // Wait for a response using a promise that resolves when send() is called + try { + await responsePromise; + } catch (error) { + // Clean up on timeout + this._pendingResponse = null; + this._currentRequestId = null; + this._responseResolver = null; + throw error; + } + + const response = this._pendingResponse; + this._pendingResponse = null; + this._currentRequestId = null; + + return response ?? undefined; + } +} diff --git a/packages/agents/src/mcp/types.ts b/packages/agents/src/mcp/types.ts index 18d3f74e..c1628024 100644 --- a/packages/agents/src/mcp/types.ts +++ b/packages/agents/src/mcp/types.ts @@ -1,7 +1,11 @@ +import type { Client } from "@modelcontextprotocol/sdk/client"; +import type { McpAgent } from "."; + export type MaybePromise = T | Promise; export type MaybeConnectionTag = { role: string } | undefined; -export type BaseTransportType = "sse" | "streamable-http"; +export type HttpTransportType = "sse" | "streamable-http"; +export type BaseTransportType = HttpTransportType | "rpc"; export type TransportType = BaseTransportType | "auto"; export interface CORSOptions { @@ -15,5 +19,92 @@ export interface CORSOptions { export interface ServeOptions { binding?: string; corsOptions?: CORSOptions; - transport?: BaseTransportType; + transport?: HttpTransportType; +} + +/** + * Client options passed to the MCP SDK Client constructor + */ +export type McpClientOptions = ConstructorParameters[1]; + +/** + * Transport configuration for RPC connections + */ +export interface RpcTransportOptions< + T extends McpAgent> = McpAgent +> { + /** The transport type (must be "rpc") */ + type?: "rpc"; + /** Optional custom function name on the Durable Object stub (defaults to "handleMcpMessage") */ + functionName?: string; + /** Props to pass to the McpAgent instance */ + props?: T extends McpAgent ? Props : never; +} + +/** + * Transport configuration for HTTP-based connections (SSE or Streamable HTTP) + */ +export interface HttpTransportOptions { + /** The transport type to use. "auto" will try streamable-http, then fall back to SSE */ + type?: TransportType; + /** Additional headers to include in HTTP requests */ + headers?: HeadersInit; +} + +/** + * Options for RPC connection configuration + */ +export interface RpcConnectionOptions< + T extends McpAgent> = McpAgent +> { + /** Transport-specific options for RPC connections */ + transport?: RpcTransportOptions; + /** Client options passed to the MCP SDK Client */ + client?: McpClientOptions; } + +/** + * Options for HTTP/SSE connection configuration + */ +export interface HttpConnectionOptions { + /** Transport-specific options for HTTP connections */ + transport?: HttpTransportOptions; + /** Client options passed to the MCP SDK Client */ + client?: McpClientOptions; +} + +/** + * Configuration for connecting to an MCP server via RPC transport + */ +export interface McpRpcConnectionConfig< + T extends McpAgent> = McpAgent +> { + type: "rpc"; + url: string; + normalizedName: string; + namespace: DurableObjectNamespace; + options?: RpcConnectionOptions; + reconnect?: { id: string }; +} + +/** + * Configuration for connecting to an MCP server via HTTP/SSE transport + */ +export interface McpHttpConnectionConfig { + type: "http"; + serverName: string; + url: string; + callbackUrl: string; + options?: HttpConnectionOptions; + reconnect?: { + id: string; + oauthClientId?: string; + }; +} + +/** + * Union type for MCP connection configuration + */ +export type McpConnectionConfig = + | McpRpcConnectionConfig + | McpHttpConnectionConfig; diff --git a/packages/agents/src/tests/mcp/mcp-protocol.test.ts b/packages/agents/src/tests/mcp/mcp-protocol.test.ts index 694d1b61..9ace561a 100644 --- a/packages/agents/src/tests/mcp/mcp-protocol.test.ts +++ b/packages/agents/src/tests/mcp/mcp-protocol.test.ts @@ -10,7 +10,8 @@ import { parseSSEData, expectValidToolsList, expectValidGreetResult, - expectValidPropsResult + expectValidPropsResult, + establishRPCConnection } from "../shared/test-utils"; declare module "cloudflare:test" { @@ -65,6 +66,18 @@ describe("MCP Protocol Core Functionality", () => { expectValidToolsList(result); }); + it("should list available tools via RPC", async () => { + const { connection } = await establishRPCConnection(); + + const result = await connection.client.listTools(); + + expectValidToolsList({ + jsonrpc: "2.0", + id: "tools-1", + result + }); + }); + it("should invoke greet tool via streamable HTTP", async () => { const ctx = createExecutionContext(); const sessionId = await initializeStreamableHTTPServer(ctx); @@ -107,6 +120,24 @@ describe("MCP Protocol Core Functionality", () => { expectValidGreetResult(result, "Test User"); }); + + it("should invoke greet tool via RPC", async () => { + const { connection } = await establishRPCConnection(); + + const result = await connection.client.callTool({ + name: "greet", + arguments: { name: "Test User" } + }); + + expectValidGreetResult( + { + jsonrpc: "2.0", + id: "greet-1", + result + }, + "Test User" + ); + }); }); describe("Props Passing", () => { diff --git a/packages/agents/src/tests/mcp/transports/rpc.test.ts b/packages/agents/src/tests/mcp/transports/rpc.test.ts new file mode 100644 index 00000000..ad83f25e --- /dev/null +++ b/packages/agents/src/tests/mcp/transports/rpc.test.ts @@ -0,0 +1,1742 @@ +import { describe, expect, it, beforeEach } from "vitest"; +import { + RPCClientTransport, + RPCServerTransport, + type MCPStub +} from "../../../mcp/rpc"; +import type { + JSONRPCMessage, + JSONRPCRequest, + JSONRPCNotification +} from "@modelcontextprotocol/sdk/types.js"; +import { TEST_MESSAGES } from "../../shared/test-utils"; + +describe("RPC Transport", () => { + describe("RPCClientTransport", () => { + it("should start and close transport", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + + await transport.start(); + + let closeCalled = false; + transport.onclose = () => { + closeCalled = true; + }; + + await transport.close(); + expect(closeCalled).toBe(true); + }); + + it("should throw error when sending before start", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }; + + await expect(transport.send(message)).rejects.toThrow( + "Transport not started" + ); + }); + + it("should send message and receive single response", async () => { + const mockResponse: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + result: { success: true } + }; + + const mockStub: MCPStub = { + handleMcpMessage: async () => mockResponse + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + let receivedMessage: JSONRPCMessage | undefined; + transport.onmessage = (msg) => { + receivedMessage = msg; + }; + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }; + + await transport.send(message); + + expect(receivedMessage).toEqual(mockResponse); + }); + + it("should send message and receive multiple responses", async () => { + const mockResponses: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + result: { success: true } + }, + { + jsonrpc: "2.0", + method: "notification", + params: {} + } + ]; + + const mockStub: MCPStub = { + handleMcpMessage: async () => mockResponses + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const receivedMessages: JSONRPCMessage[] = []; + transport.onmessage = (msg) => { + receivedMessages.push(msg); + }; + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }; + + await transport.send(message); + + expect(receivedMessages).toEqual(mockResponses); + }); + + it("should handle stub returning void", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + let receivedMessage: JSONRPCMessage | undefined; + transport.onmessage = (msg) => { + receivedMessage = msg; + }; + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "notification", + params: {} + }; + + await transport.send(message); + + expect(receivedMessage).toBeUndefined(); + }); + + it("should call onerror on stub error", async () => { + const mockError = new Error("Stub error"); + const mockStub: MCPStub = { + handleMcpMessage: async () => { + throw mockError; + } + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + let errorReceived: Error | undefined; + transport.onerror = (err) => { + errorReceived = err; + }; + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }; + + await expect(transport.send(message)).rejects.toThrow("Stub error"); + expect(errorReceived).toEqual(mockError); + }); + + it("should use custom function name", async () => { + let calledFunction: string | undefined; + + const mockStub = { + handleMcpMessage: async () => undefined, + customHandle: async () => { + calledFunction = "customHandle"; + return undefined; + } + } as MCPStub & { + customHandle: (msg: JSONRPCMessage) => Promise; + }; + + const transport = new RPCClientTransport({ + stub: mockStub, + functionName: "customHandle" + }); + await transport.start(); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + method: "test", + params: {} + }; + + await transport.send(message); + + expect(calledFunction).toBe("customHandle"); + }); + + it("should call onclose when closing", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + let closeCalled = false; + transport.onclose = () => { + closeCalled = true; + }; + + await transport.close(); + + expect(closeCalled).toBe(true); + }); + }); + + describe("RPCServerTransport", () => { + it("should start and close transport", async () => { + const transport = new RPCServerTransport(); + + await transport.start(); + + let closeCalled = false; + transport.onclose = () => { + closeCalled = true; + }; + + await transport.close(); + expect(closeCalled).toBe(true); + }); + + it("should handle request and return response", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const expectedResponse: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + result: { success: true } + }; + + transport.onmessage = (msg) => { + expect(msg).toEqual({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }); + }; + + const handlePromise = transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }); + + await transport.send(expectedResponse); + + const result = await handlePromise; + expect(result).toEqual(expectedResponse); + }); + + it("should handle request and return multiple responses", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const expectedResponses: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + result: { success: true } + }, + { + jsonrpc: "2.0", + method: "notification", + params: {} + } + ]; + + transport.onmessage = (msg) => { + const req = msg as JSONRPCRequest; + expect(req.id).toBe(1); + }; + + const handlePromise = transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }); + + for (const response of expectedResponses) { + await transport.send(response); + } + + const result = await handlePromise; + expect(result).toEqual(expectedResponses); + }); + + it("should handle concurrent sends within the same request", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const expectedResponses: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + result: { first: true } + }, + { + jsonrpc: "2.0", + id: 1, + result: { second: true } + } + ]; + + transport.onmessage = () => { + // Simulate concurrent sends (don't await) + transport.send(expectedResponses[0]); + transport.send(expectedResponses[1]); + }; + + const handlePromise = transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }); + + const result = await handlePromise; + expect(result).toEqual(expectedResponses); + }); + + it("should handle notification without waiting for response", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + let messageReceived = false; + transport.onmessage = () => { + messageReceived = true; + }; + + const result = await transport.handle({ + jsonrpc: "2.0", + method: "notification", + params: {} + }); + + expect(messageReceived).toBe(true); + expect(result).toBeUndefined(); + }); + + it("should throw error when handling before start", async () => { + const transport = new RPCServerTransport(); + + await expect( + transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }) + ).rejects.toThrow("Transport not started"); + }); + + it("should throw error when sending before start", async () => { + const transport = new RPCServerTransport(); + + await expect( + transport.send({ + jsonrpc: "2.0", + id: 1, + result: {} + }) + ).rejects.toThrow("Transport not started"); + }); + + it("should support session ID generation after initialization", async () => { + const transport = new RPCServerTransport({ + sessionIdGenerator: () => "test-session" + }); + await transport.start(); + + // Session ID is undefined until initialization + expect(transport.sessionId).toBeUndefined(); + + transport.onmessage = (msg) => { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test", version: "1.0.0" } + } + }); + }; + + // After initialization, session ID should be set + await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }); + + expect(transport.sessionId).toBe("test-session"); + }); + + it("should call onclose when closing", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + let closeCalled = false; + transport.onclose = () => { + closeCalled = true; + }; + + await transport.close(); + + expect(closeCalled).toBe(true); + }); + }); + + describe("Client-Server Integration", () => { + let clientTransport: RPCClientTransport; + let serverTransport: RPCServerTransport; + + beforeEach(async () => { + serverTransport = new RPCServerTransport(); + await serverTransport.start(); + + const stub: MCPStub = { + handleMcpMessage: async (msg: JSONRPCMessage | JSONRPCMessage[]) => { + return await serverTransport.handle(msg); + } + }; + + clientTransport = new RPCClientTransport({ stub }); + await clientTransport.start(); + }); + + it("should handle complete request-response cycle", async () => { + serverTransport.onmessage = async () => { + await serverTransport.send({ + jsonrpc: "2.0", + id: 1, + result: { data: "response" } + }); + }; + + const receivedMessages: JSONRPCMessage[] = []; + clientTransport.onmessage = (msg) => { + receivedMessages.push(msg); + }; + + await clientTransport.send({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: { data: "request" } + }); + + expect(receivedMessages).toHaveLength(1); + expect(receivedMessages[0]).toMatchObject({ + jsonrpc: "2.0", + id: 1, + result: { data: "response" } + }); + }); + + it("should handle notification (no response expected)", async () => { + let serverReceivedNotification = false; + serverTransport.onmessage = async (msg) => { + const notification = msg as JSONRPCNotification; + if (notification.method === "notification") { + serverReceivedNotification = true; + } + await serverTransport.send({ + jsonrpc: "2.0", + method: "ack", + params: {} + }); + }; + + await clientTransport.send({ + jsonrpc: "2.0", + method: "notification", + params: { data: "notify" } + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + expect(serverReceivedNotification).toBe(true); + }); + + it("should handle multiple messages", async () => { + const responses = [ + { jsonrpc: "2.0" as const, id: 1, result: { data: "first" } }, + { jsonrpc: "2.0" as const, id: 2, result: { data: "second" } } + ]; + + serverTransport.onmessage = async (msg) => { + const req = msg as JSONRPCRequest; + if (req.id === 1) { + await serverTransport.send(responses[0]); + } else if (req.id === 2) { + await serverTransport.send(responses[1]); + } + }; + + const receivedMessages: JSONRPCMessage[] = []; + clientTransport.onmessage = (msg) => { + receivedMessages.push(msg); + }; + + await clientTransport.send({ + jsonrpc: "2.0", + id: 1, + method: "test1", + params: {} + }); + + await clientTransport.send({ + jsonrpc: "2.0", + id: 2, + method: "test2", + params: {} + }); + + expect(receivedMessages).toHaveLength(2); + expect(receivedMessages[0]).toMatchObject(responses[0]); + expect(receivedMessages[1]).toMatchObject(responses[1]); + }); + }); + + describe("JSON-RPC 2.0 Validation", () => { + describe("Request/Notification Validation", () => { + it("should reject request without jsonrpc field", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const invalidMessage = { + id: 1, + method: "test", + params: {} + } as unknown as JSONRPCMessage; + + await expect(transport.send(invalidMessage)).rejects.toThrow( + 'jsonrpc field must be "2.0"' + ); + }); + + it("should reject request with wrong jsonrpc version", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const invalidMessage = { + jsonrpc: "1.0", + id: 1, + method: "test", + params: {} + } as unknown as JSONRPCMessage; + + await expect(transport.send(invalidMessage)).rejects.toThrow( + 'jsonrpc field must be "2.0"' + ); + }); + + it("should reject request with non-string method", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const invalidMessage = { + jsonrpc: "2.0", + id: 1, + method: 123, + params: {} + } as unknown as JSONRPCMessage; + + await expect(transport.send(invalidMessage)).rejects.toThrow( + "method must be a string" + ); + }); + + it("should reject request with reserved rpc.* method name", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const invalidMessage: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "rpc.reserved", + params: {} + }; + + await expect(transport.send(invalidMessage)).rejects.toThrow( + 'method names starting with "rpc." are reserved' + ); + }); + + it("should reject request with invalid id type", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const invalidMessage = { + jsonrpc: "2.0", + id: { invalid: true }, + method: "test", + params: {} + } as unknown as JSONRPCMessage; + + await expect(transport.send(invalidMessage)).rejects.toThrow( + "id must be string, number, or null" + ); + }); + + it("should accept request with null id", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const validMessage = { + jsonrpc: "2.0", + id: null, + method: "test", + params: {} + } as unknown as JSONRPCMessage; + + await expect(transport.send(validMessage)).resolves.not.toThrow(); + }); + + it("should reject request with non-structured params", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const invalidMessage = { + jsonrpc: "2.0", + id: 1, + method: "test", + params: "string params not allowed" + } as unknown as JSONRPCMessage; + + await expect(transport.send(invalidMessage)).rejects.toThrow( + "params must be an array or object" + ); + }); + + it("should accept request with array params", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const validMessage = { + jsonrpc: "2.0", + id: 1, + method: "test", + params: [1, 2, 3] + } as unknown as JSONRPCMessage; + + await expect(transport.send(validMessage)).resolves.not.toThrow(); + }); + + it("should accept request with object params", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + // Use TEST_MESSAGES.toolsList as a valid example + await expect( + transport.send(TEST_MESSAGES.toolsList) + ).resolves.not.toThrow(); + }); + + it("should accept request without params", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const validMessage: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + method: "test" + }; + + await expect(transport.send(validMessage)).resolves.not.toThrow(); + }); + + it("should accept notification without id", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const validMessage: JSONRPCMessage = { + jsonrpc: "2.0", + method: "notification", + params: {} + }; + + await expect(transport.send(validMessage)).resolves.not.toThrow(); + }); + }); + + describe("Response Validation", () => { + it("should reject response without result or error", async () => { + const invalidResponse = { + jsonrpc: "2.0", + id: 1 + } as unknown as JSONRPCMessage; + + const mockStub: MCPStub = { + handleMcpMessage: async () => invalidResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + await expect(transport.send(TEST_MESSAGES.toolsList)).rejects.toThrow( + "must have either result or error" + ); + }); + + it("should reject response with both result and error", async () => { + const invalidResponse = { + jsonrpc: "2.0", + id: 1, + result: { data: "test" }, + error: { code: -32600, message: "Invalid" } + } as unknown as JSONRPCMessage; + + const mockStub: MCPStub = { + handleMcpMessage: async () => invalidResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + await expect(transport.send(TEST_MESSAGES.toolsList)).rejects.toThrow( + "cannot have both result and error" + ); + }); + + it("should accept response with null id (parse error case)", async () => { + const validResponse = { + jsonrpc: "2.0", + id: null, + error: { code: -32700, message: "Parse error" } + } as unknown as JSONRPCMessage; + + const mockStub: MCPStub = { + handleMcpMessage: async () => validResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + let received: JSONRPCMessage | undefined; + transport.onmessage = (msg: JSONRPCMessage) => { + received = msg; + }; + + await transport.send(TEST_MESSAGES.toolsList); + expect(received).toEqual(validResponse); + }); + }); + + describe("Error Object Validation", () => { + it("should reject error with non-number code", async () => { + const invalidResponse = { + jsonrpc: "2.0", + id: 1, + error: { code: "not a number", message: "Error" } + } as unknown as JSONRPCMessage; + + const mockStub: MCPStub = { + handleMcpMessage: async () => invalidResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + await expect(transport.send(TEST_MESSAGES.toolsList)).rejects.toThrow( + "code must be a number" + ); + }); + + it("should reject error with non-integer code", async () => { + const invalidResponse = { + jsonrpc: "2.0", + id: 1, + error: { code: 123.45, message: "Error" } + } as unknown as JSONRPCMessage; + + const mockStub: MCPStub = { + handleMcpMessage: async () => invalidResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + await expect(transport.send(TEST_MESSAGES.toolsList)).rejects.toThrow( + "code must be an integer" + ); + }); + + it("should reject error with non-string message", async () => { + const invalidResponse = { + jsonrpc: "2.0", + id: 1, + error: { code: -32600, message: 123 } + } as unknown as JSONRPCMessage; + + const mockStub: MCPStub = { + handleMcpMessage: async () => invalidResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + await expect(transport.send(TEST_MESSAGES.toolsList)).rejects.toThrow( + "message must be a string" + ); + }); + + it("should accept error with valid structure and optional data", async () => { + const validResponse: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + error: { + code: -32600, + message: "Invalid Request", + data: { details: "Additional error info" } + } + }; + + const mockStub: MCPStub = { + handleMcpMessage: async () => validResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + let received: JSONRPCMessage | undefined; + transport.onmessage = (msg: JSONRPCMessage) => { + received = msg; + }; + + await transport.send(TEST_MESSAGES.toolsList); + expect(received).toEqual(validResponse); + }); + + it("should accept standard JSON-RPC error codes", async () => { + const errorCodes = [ + { code: -32700, message: "Parse error" }, + { code: -32600, message: "Invalid Request" }, + { code: -32601, message: "Method not found" }, + { code: -32602, message: "Invalid params" }, + { code: -32603, message: "Internal error" }, + { code: -32000, message: "Server error" } + ]; + + for (const error of errorCodes) { + const validResponse: JSONRPCMessage = { + jsonrpc: "2.0", + id: 1, + error + }; + + const mockStub: MCPStub = { + handleMcpMessage: async () => validResponse + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + let received: JSONRPCMessage | undefined; + transport.onmessage = (msg: JSONRPCMessage) => { + received = msg; + }; + + await transport.send(TEST_MESSAGES.toolsList); + expect(received).toEqual(validResponse); + } + }); + }); + + describe("Server Transport Validation", () => { + it("should validate incoming requests", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const invalidMessage = { + jsonrpc: "2.0", + id: 1, + method: "rpc.internal" + } as JSONRPCMessage; + + await expect(transport.handle(invalidMessage)).rejects.toThrow( + 'method names starting with "rpc." are reserved' + ); + }); + + it("should validate outgoing responses", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const invalidMessage = { + jsonrpc: "2.0", + id: 1 + } as unknown as JSONRPCMessage; + + await expect(transport.send(invalidMessage)).rejects.toThrow( + "must have either result or error" + ); + }); + + it("should validate response ID matches request ID (JSON-RPC 2.0 spec section 5)", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + // Set up message handler first + transport.onmessage = async () => { + // Try to send response with mismatched id + await expect( + transport.send({ + jsonrpc: "2.0", + id: 2, // Wrong ID! Should be 1 + result: { data: "response" } + }) + ).rejects.toThrow( + "Response ID 2 does not match request ID 1 (JSON-RPC 2.0 spec section 5)" + ); + + // Send correct response to complete the test + await transport.send({ + jsonrpc: "2.0", + id: 1, + result: { data: "response" } + }); + }; + + // Start handling a request with id: 1 + await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }); + }); + + it("should allow notifications alongside responses", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + // Set up message handler first + transport.onmessage = async () => { + // Send a notification (no id) - should be allowed + await transport.send({ + jsonrpc: "2.0", + method: "progress", + params: { percent: 50 } + }); + + // Send the response with matching id + await transport.send({ + jsonrpc: "2.0", + id: 1, + result: { data: "response" } + }); + }; + + // Start handling a request + const result = await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }); + + // Should receive both messages + expect(Array.isArray(result)).toBe(true); + expect(result).toHaveLength(2); + expect(result).toEqual([ + { jsonrpc: "2.0", method: "progress", params: { percent: 50 } }, + { jsonrpc: "2.0", id: 1, result: { data: "response" } } + ]); + }); + }); + }); + + describe("Batch Requests (JSON-RPC 2.0 spec section 6)", () => { + describe("Client Transport Batch Support", () => { + it("should send batch requests", async () => { + const batchMessages: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + method: "sum", + params: [1, 2, 4] as unknown as Record + }, + { + jsonrpc: "2.0", + method: "notify_hello", + params: [7] as unknown as Record + }, + { + jsonrpc: "2.0", + id: 2, + method: "subtract", + params: [42, 23] as unknown as Record + } + ]; + + const mockResponses: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + result: 7 as unknown as Record + }, + { + jsonrpc: "2.0", + id: 2, + result: 19 as unknown as Record + } + ]; + + const mockStub: MCPStub = { + handleMcpMessage: async () => mockResponses + }; + + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const receivedMessages: JSONRPCMessage[] = []; + transport.onmessage = (msg) => { + receivedMessages.push(msg); + }; + + await transport.send(batchMessages); + + expect(receivedMessages).toEqual(mockResponses); + }); + + it("should reject empty batch", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + await expect(transport.send([])).rejects.toThrow( + "array must not be empty" + ); + }); + + it("should reject batch with invalid message", async () => { + const mockStub: MCPStub = { + handleMcpMessage: async () => undefined + }; + const transport = new RPCClientTransport({ stub: mockStub }); + await transport.start(); + + const invalidBatch = [ + { jsonrpc: "2.0", id: 1, method: "test", params: {} }, + { invalid: "message" } + ] as unknown as JSONRPCMessage[]; + + await expect(transport.send(invalidBatch)).rejects.toThrow( + "message at index 1 is invalid" + ); + }); + }); + + describe("Server Transport Batch Support", () => { + it("should handle batch with multiple requests", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + transport.onmessage = async (msg) => { + const req = msg as { method: string; id?: number }; + if (req.method === "sum") { + await transport.send({ + jsonrpc: "2.0", + id: req.id!, + result: 7 as unknown as Record + }); + } else if (req.method === "subtract") { + await transport.send({ + jsonrpc: "2.0", + id: req.id!, + result: 19 as unknown as Record + }); + } + }; + + const batch: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + method: "sum", + params: [1, 2, 4] as unknown as Record + }, + { + jsonrpc: "2.0", + id: 2, + method: "subtract", + params: [42, 23] as unknown as Record + } + ]; + + const result = await transport.handle(batch); + + expect(Array.isArray(result)).toBe(true); + expect(result).toHaveLength(2); + expect(result).toEqual([ + { + jsonrpc: "2.0", + id: 1, + result: 7 as unknown as Record + }, + { + jsonrpc: "2.0", + id: 2, + result: 19 as unknown as Record + } + ]); + }); + + it("should handle batch with notifications only (returns nothing)", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const receivedNotifications: string[] = []; + transport.onmessage = async (msg) => { + const notification = msg as { method: string }; + receivedNotifications.push(notification.method); + }; + + const batch: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + method: "notify_sum", + params: [1, 2, 4] as unknown as Record + }, + { + jsonrpc: "2.0", + method: "notify_hello", + params: [7] as unknown as Record + } + ]; + + const result = await transport.handle(batch); + + // Spec: "should return nothing at all" when all notifications + expect(result).toBeUndefined(); + expect(receivedNotifications).toEqual(["notify_sum", "notify_hello"]); + }); + + it("should handle batch with mixed requests and notifications", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const receivedNotifications: string[] = []; + transport.onmessage = async (msg) => { + const req = msg as { method: string; id?: number }; + if (!("id" in msg)) { + receivedNotifications.push(req.method); + } else if (req.method === "sum") { + await transport.send({ + jsonrpc: "2.0", + id: req.id!, + result: 7 as unknown as Record + }); + } + }; + + const batch: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + method: "sum", + params: [1, 2, 4] as unknown as Record + }, + { + jsonrpc: "2.0", + method: "notify_hello", + params: [7] as unknown as Record + } + ]; + + const result = await transport.handle(batch); + + // Should have response for request, but not for notification + expect(Array.isArray(result)).toBe(true); + expect(result).toHaveLength(1); + expect(result).toEqual([ + { + jsonrpc: "2.0", + id: 1, + result: 7 as unknown as Record + } + ]); + expect(receivedNotifications).toEqual(["notify_hello"]); + }); + + it("should reject empty batch", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + await expect(transport.handle([])).rejects.toThrow( + "array must not be empty" + ); + }); + + it("should reject batch with invalid message", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + const invalidBatch = [ + { jsonrpc: "2.0", id: 1, method: "test", params: {} }, + { invalid: "message" } + ] as unknown as JSONRPCMessage[]; + + await expect(transport.handle(invalidBatch)).rejects.toThrow( + "message at index 1 is invalid" + ); + }); + }); + + describe("End-to-End Batch Processing", () => { + it("should handle complete batch request-response cycle", async () => { + const serverTransport = new RPCServerTransport(); + await serverTransport.start(); + + serverTransport.onmessage = async (msg) => { + const req = msg as { method: string; id?: number; params?: number[] }; + if (req.method === "sum" && req.id) { + const sum = (req.params || []).reduce((a, b) => a + b, 0); + await serverTransport.send({ + jsonrpc: "2.0", + id: req.id, + result: sum as unknown as Record + }); + } else if (req.method === "subtract" && req.id) { + const [a, b] = req.params || [0, 0]; + await serverTransport.send({ + jsonrpc: "2.0", + id: req.id, + result: (a - b) as unknown as Record + }); + } + }; + + const stub: MCPStub = { + handleMcpMessage: async (msg: JSONRPCMessage | JSONRPCMessage[]) => { + return await serverTransport.handle(msg); + } + }; + + const clientTransport = new RPCClientTransport({ stub }); + await clientTransport.start(); + + const receivedMessages: JSONRPCMessage[] = []; + clientTransport.onmessage = (msg) => { + receivedMessages.push(msg); + }; + + const batch: JSONRPCMessage[] = [ + { + jsonrpc: "2.0", + id: 1, + method: "sum", + params: [1, 2, 4] as unknown as Record + }, + { + jsonrpc: "2.0", + method: "notify_hello", + params: [7] as unknown as Record + }, + { + jsonrpc: "2.0", + id: 2, + method: "subtract", + params: [42, 23] as unknown as Record + } + ]; + + await clientTransport.send(batch); + + expect(receivedMessages).toHaveLength(2); + expect(receivedMessages).toEqual([ + { + jsonrpc: "2.0", + id: 1, + result: 7 as unknown as Record + }, + { + jsonrpc: "2.0", + id: 2, + result: 19 as unknown as Record + } + ]); + }); + }); + }); + + describe("Session Management", () => { + describe("Stateless Mode (no session ID generator)", () => { + it("should work without session management when no sessionIdGenerator is provided", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + let receivedMessage: JSONRPCMessage | undefined; + transport.onmessage = (msg) => { + receivedMessage = msg; + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { success: true } + }); + }; + + const response = await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "test", + params: {} + }); + + expect(receivedMessage).toBeDefined(); + expect(response).toEqual({ + jsonrpc: "2.0", + id: 1, + result: { success: true } + }); + expect(transport.sessionId).toBeUndefined(); + }); + + it("should allow any request without initialization in stateless mode", async () => { + const transport = new RPCServerTransport(); + await transport.start(); + + let receivedCount = 0; + transport.onmessage = (msg) => { + receivedCount++; + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { count: receivedCount } + }); + }; + + // First request should work without initialize + const response1 = await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: {} + }); + + expect(response1).toEqual({ + jsonrpc: "2.0", + id: 1, + result: { count: 1 } + }); + + // Second request should also work + const response2 = await transport.handle({ + jsonrpc: "2.0", + id: 2, + method: "tools/call", + params: {} + }); + + expect(response2).toEqual({ + jsonrpc: "2.0", + id: 2, + result: { count: 2 } + }); + }); + }); + + describe("Stateful Mode (with session ID generator)", () => { + it("should generate session ID during initialization", async () => { + let initializedSessionId: string | undefined; + const sessionIdGenerator = () => "test-session-123"; + + const transport = new RPCServerTransport({ + sessionIdGenerator, + onsessioninitialized: (sessionId) => { + initializedSessionId = sessionId; + } + }); + await transport.start(); + + transport.onmessage = (msg) => { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test", version: "1.0.0" } + } + }); + }; + + const response = await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }); + + expect(response).toBeDefined(); + expect(transport.sessionId).toBe("test-session-123"); + expect(initializedSessionId).toBe("test-session-123"); + }); + + it("should reject non-initialization requests before session is initialized", async () => { + const sessionIdGenerator = () => "test-session-456"; + const transport = new RPCServerTransport({ + sessionIdGenerator + }); + await transport.start(); + + transport.onmessage = (msg) => { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { success: true } + }); + }; + + // Try to send a non-initialization request before initialize + await expect( + transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "tools/list", + params: {} + }) + ).rejects.toThrow( + "Session not initialized. An initialize request must be sent first." + ); + + expect(transport.sessionId).toBeUndefined(); + }); + + it("should reject duplicate initialization requests", async () => { + const sessionIdGenerator = () => "test-session-789"; + const transport = new RPCServerTransport({ + sessionIdGenerator + }); + await transport.start(); + + transport.onmessage = (msg) => { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test", version: "1.0.0" } + } + }); + }; + + // First initialization should succeed + await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }); + + expect(transport.sessionId).toBe("test-session-789"); + + // Second initialization should fail + await expect( + transport.handle({ + jsonrpc: "2.0", + id: 2, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }) + ).rejects.toThrow("Session already initialized"); + }); + + it("should allow requests after successful initialization", async () => { + const sessionIdGenerator = () => "test-session-abc"; + const transport = new RPCServerTransport({ + sessionIdGenerator + }); + await transport.start(); + + let requestCount = 0; + transport.onmessage = (msg) => { + requestCount++; + if ((msg as { method: string }).method === "initialize") { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test", version: "1.0.0" } + } + }); + } else { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { requestCount } + }); + } + }; + + // Initialize first + await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }); + + // Now other requests should work + const response1 = await transport.handle({ + jsonrpc: "2.0", + id: 2, + method: "tools/list", + params: {} + }); + + expect(response1).toEqual({ + jsonrpc: "2.0", + id: 2, + result: { requestCount: 2 } + }); + + const response2 = await transport.handle({ + jsonrpc: "2.0", + id: 3, + method: "tools/call", + params: {} + }); + + expect(response2).toEqual({ + jsonrpc: "2.0", + id: 3, + result: { requestCount: 3 } + }); + }); + + it("should call onsessionclosed when terminateSession is called", async () => { + let closedSessionId: string | undefined; + const sessionIdGenerator = () => "test-session-xyz"; + + const transport = new RPCServerTransport({ + sessionIdGenerator, + onsessionclosed: (sessionId) => { + closedSessionId = sessionId; + } + }); + await transport.start(); + + transport.onmessage = (msg) => { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test", version: "1.0.0" } + } + }); + }; + + // Initialize session + await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }); + + expect(transport.sessionId).toBe("test-session-xyz"); + + // Terminate session + await transport.terminateSession(); + + expect(closedSessionId).toBe("test-session-xyz"); + expect(transport.sessionId).toBeUndefined(); + }); + + it("should call onsessionclosed when transport is closed", async () => { + let closedSessionId: string | undefined; + const sessionIdGenerator = () => "test-session-close"; + + const transport = new RPCServerTransport({ + sessionIdGenerator, + onsessionclosed: (sessionId) => { + closedSessionId = sessionId; + } + }); + await transport.start(); + + transport.onmessage = (msg) => { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test", version: "1.0.0" } + } + }); + }; + + // Initialize session + await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }); + + expect(transport.sessionId).toBe("test-session-close"); + + // Close transport + await transport.close(); + + expect(closedSessionId).toBe("test-session-close"); + expect(transport.sessionId).toBeUndefined(); + }); + + it("should support async session lifecycle hooks", async () => { + const hookCalls: string[] = []; + const sessionIdGenerator = () => "test-session-async"; + + const transport = new RPCServerTransport({ + sessionIdGenerator, + onsessioninitialized: async (sessionId) => { + await new Promise((resolve) => setTimeout(resolve, 10)); + hookCalls.push(`initialized:${sessionId}`); + }, + onsessionclosed: async (sessionId) => { + await new Promise((resolve) => setTimeout(resolve, 10)); + hookCalls.push(`closed:${sessionId}`); + } + }); + await transport.start(); + + transport.onmessage = (msg) => { + transport.send({ + jsonrpc: "2.0", + id: (msg as { id: number }).id, + result: { + protocolVersion: "2025-06-18", + capabilities: {}, + serverInfo: { name: "test", version: "1.0.0" } + } + }); + }; + + // Initialize session + await transport.handle({ + jsonrpc: "2.0", + id: 1, + method: "initialize", + params: { + protocolVersion: "2025-06-18", + capabilities: {}, + clientInfo: { name: "test-client", version: "1.0.0" } + } + }); + + expect(hookCalls).toContain("initialized:test-session-async"); + + // Close transport + await transport.close(); + + expect(hookCalls).toContain("closed:test-session-async"); + expect(hookCalls).toHaveLength(2); + }); + }); + }); +}); diff --git a/packages/agents/src/tests/shared/test-utils.ts b/packages/agents/src/tests/shared/test-utils.ts index c7147594..46005bde 100644 --- a/packages/agents/src/tests/shared/test-utils.ts +++ b/packages/agents/src/tests/shared/test-utils.ts @@ -139,15 +139,43 @@ export async function initializeStreamableHTTPServer( export async function initializeMCPClientConnection( baseUrl = "http://example.com/mcp", - transportType: "auto" | "streamable-http" | "sse" = "auto" + transportType: "auto" | "streamable-http" | "sse" | "rpc" = "auto", + transportOptions?: Record ) { return new MCPClientConnection( new URL(baseUrl), { name: "test-client", version: "1.0.0" }, - { transport: { type: transportType }, client: {} } + { transport: { type: transportType, ...transportOptions }, client: {} } ); } +/** + * Helper to create RPC connection to TestMcpAgent + */ +export async function establishRPCConnection(): Promise<{ + connection: MCPClientConnection; + sessionId: string; +}> { + const sessionId = `rpc:${crypto.randomUUID()}`; + const id = env.MCP_OBJECT.idFromName(sessionId); + const agentStub = env.MCP_OBJECT.get(id); + + // Set the name on the stub to avoid "name not set" error + agentStub.setName(sessionId); + + // Create MCPClientConnection with RPC transport + const connection = await initializeMCPClientConnection( + "http://example.com/mcp", + "rpc", + { stub: agentStub } + ); + + // Initialize the connection + await connection.init(); + + return { connection, sessionId }; +} + /** * Helper to establish SSE connection and get session ID */