diff --git a/.changeset/spotty-crabs-allow.md b/.changeset/spotty-crabs-allow.md new file mode 100644 index 00000000..e05a462d --- /dev/null +++ b/.changeset/spotty-crabs-allow.md @@ -0,0 +1,5 @@ +--- +"agents": patch +--- + +When handling MCP server requests use relatedRequestId in TransportOptions to send the response down a POST stream if supported (streamable-http) diff --git a/examples/mcp-elicitation/src/index.ts b/examples/mcp-elicitation/src/index.ts index 689e5efa..bc895420 100644 --- a/examples/mcp-elicitation/src/index.ts +++ b/examples/mcp-elicitation/src/index.ts @@ -54,28 +54,31 @@ export class MyAgent extends Agent { confirm: z.boolean().describe("Do you want to increase the counter?") } }, - async ({ confirm }) => { + async ({ confirm }, extra) => { if (!confirm) { return { content: [{ type: "text", text: "Counter increase cancelled." }] }; } try { - const basicInfo = await this.server.server.elicitInput({ - message: "By how much do you want to increase the counter?", - requestedSchema: { - type: "object", - properties: { - amount: { - type: "number", - title: "Amount", - description: "The amount to increase the counter by", - minLength: 1 - } - }, - required: ["amount"] - } - }); + const basicInfo = await this.server.server.elicitInput( + { + message: "By how much do you want to increase the counter?", + requestedSchema: { + type: "object", + properties: { + amount: { + type: "number", + title: "Amount", + description: "The amount to increase the counter by", + minLength: 1 + } + }, + required: ["amount"] + } + }, + { relatedRequestId: extra.requestId } + ); if (basicInfo.action !== "accept" || !basicInfo.content) { return { diff --git a/packages/agents/src/mcp/worker-transport.ts b/packages/agents/src/mcp/worker-transport.ts index 8bf60fde..ab75b5c1 100644 --- a/packages/agents/src/mcp/worker-transport.ts +++ b/packages/agents/src/mcp/worker-transport.ts @@ -2,7 +2,10 @@ * Based on @hono/mcp transport implementation (https://github.com/honojs/middleware/tree/main/packages/mcp) */ -import type { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; +import type { + Transport, + TransportSendOptions +} from "@modelcontextprotocol/sdk/shared/transport.js"; import type { JSONRPCMessage, RequestId, @@ -357,7 +360,7 @@ export class WorkerTransport implements Transport { const acceptHeader = request.headers.get("Accept"); if ( !acceptHeader?.includes("application/json") || - !acceptHeader.includes("text/event-stream") + !acceptHeader?.includes("text/event-stream") ) { return new Response( JSON.stringify({ @@ -738,9 +741,14 @@ export class WorkerTransport implements Transport { this.onclose?.(); } - async send(message: JSONRPCMessage): Promise { - let requestId: RequestId | undefined; + async send( + message: JSONRPCMessage, + options?: TransportSendOptions + ): Promise { + // Check relatedRequestId FIRST to route server-to-client requests through the same stream as the originating client request + let requestId: RequestId | undefined = options?.relatedRequestId; + // Then override with message.id for responses/errors if (isJSONRPCResponse(message) || isJSONRPCError(message)) { requestId = message.id; } diff --git a/packages/agents/src/tests/mcp/worker-transport.test.ts b/packages/agents/src/tests/mcp/transports/worker-transport.test.ts similarity index 72% rename from packages/agents/src/tests/mcp/worker-transport.test.ts rename to packages/agents/src/tests/mcp/transports/worker-transport.test.ts index 0a580959..5752bc9e 100644 --- a/packages/agents/src/tests/mcp/worker-transport.test.ts +++ b/packages/agents/src/tests/mcp/transports/worker-transport.test.ts @@ -1,14 +1,14 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { CallToolResult, - JSONRPCMessage + JSONRPCRequest } from "@modelcontextprotocol/sdk/types.js"; -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi, beforeEach } from "vitest"; import { WorkerTransport, type TransportState, type WorkerTransportOptions -} from "../../mcp/worker-transport"; +} from "../../../mcp/worker-transport"; import { z } from "zod"; /** @@ -973,7 +973,7 @@ describe("WorkerTransport", () => { const body = await response.json(); expect(body).toBeDefined(); - expect((body as JSONRPCMessage).jsonrpc).toBe("2.0"); + expect((body as JSONRPCRequest).jsonrpc).toBe("2.0"); }); it("should return SSE stream when enableJsonResponse is false", async () => { @@ -1036,4 +1036,340 @@ describe("WorkerTransport", () => { expect(response.headers.get("Content-Type")).toBe("application/json"); }); }); + + describe("relatedRequestId Routing", () => { + let transport: WorkerTransport; + let postStreamWriter: WritableStreamDefaultWriter; + let getStreamWriter: WritableStreamDefaultWriter; + let postStreamData: string[] = []; + let getStreamData: string[] = []; + + // Type for accessing private properties during testing (whitebox testing) + type TransportInternal = { + streamMapping: Map; + requestToStreamMapping: Map; + }; + + /** + * Helper to set up mock streams on the transport for testing. + * This is whitebox testing that accesses private fields via type assertion. + */ + const setupMockStream = ( + transport: WorkerTransport, + streamId: string, + writer: WritableStreamDefaultWriter, + encoder: TextEncoder + ) => { + const transportInternal = transport as unknown as TransportInternal; + transportInternal.streamMapping.set(streamId, { + writer, + encoder, + cleanup: vi.fn() + }); + }; + + /** + * Helper to map a request ID to a stream ID for testing. + */ + const mapRequestToStream = ( + transport: WorkerTransport, + requestId: string | number, + streamId: string + ) => { + const transportInternal = transport as unknown as TransportInternal; + transportInternal.requestToStreamMapping.set(requestId, streamId); + }; + + /** + * Helper to delete a stream from the transport for testing. + */ + const deleteStream = (transport: WorkerTransport, streamId: string) => { + const transportInternal = transport as unknown as TransportInternal; + transportInternal.streamMapping.delete(streamId); + }; + + beforeEach(() => { + // Reset data arrays + postStreamData = []; + getStreamData = []; + + // Create transport + transport = new WorkerTransport({ + sessionIdGenerator: () => "test-session" + }); + + // Mock the stream mappings manually + const postEncoder = new TextEncoder(); + const getEncoder = new TextEncoder(); + + // Create mock writers that capture data + postStreamWriter = { + write: vi.fn(async (chunk: Uint8Array) => { + postStreamData.push(new TextDecoder().decode(chunk)); + }), + close: vi.fn(), + abort: vi.fn(), + releaseLock: vi.fn() + } as unknown as WritableStreamDefaultWriter; + + getStreamWriter = { + write: vi.fn(async (chunk: Uint8Array) => { + getStreamData.push(new TextDecoder().decode(chunk)); + }), + close: vi.fn(), + abort: vi.fn(), + releaseLock: vi.fn() + } as unknown as WritableStreamDefaultWriter; + + // Set up the stream mappings using helpers + setupMockStream( + transport, + "post-stream-1", + postStreamWriter, + postEncoder + ); + setupMockStream(transport, "_GET_stream", getStreamWriter, getEncoder); + mapRequestToStream(transport, "req-1", "post-stream-1"); + }); + + describe("Server-to-client requests with relatedRequestId", () => { + it("should route messages with relatedRequestId through the POST stream", async () => { + const elicitationRequest: JSONRPCRequest = { + jsonrpc: "2.0", + id: "elicit-1", + method: "elicitation/create", + params: { + message: "What is your name?", + mode: "form", + requestedSchema: { + type: "object", + properties: { + name: { type: "string" } + } + } + } + }; + + // Send with relatedRequestId pointing to req-1 (which maps to post-stream-1) + await transport.send(elicitationRequest, { relatedRequestId: "req-1" }); + + // Should go through POST stream + expect(postStreamWriter.write).toHaveBeenCalled(); + expect(postStreamData.length).toBe(1); + expect(postStreamData[0]).toContain("elicitation/create"); + expect(postStreamData[0]).toContain("What is your name?"); + + // Should NOT go through GET stream + expect(getStreamWriter.write).not.toHaveBeenCalled(); + expect(getStreamData.length).toBe(0); + }); + + it("should route multiple messages to their respective streams based on relatedRequestId", async () => { + // Add another POST stream + const postStreamWriter2: WritableStreamDefaultWriter = { + write: vi.fn(async (_chunk: Uint8Array) => {}), + close: vi.fn(), + abort: vi.fn(), + releaseLock: vi.fn() + } as unknown as WritableStreamDefaultWriter; + + const postEncoder2 = new TextEncoder(); + + setupMockStream( + transport, + "post-stream-2", + postStreamWriter2, + postEncoder2 + ); + mapRequestToStream(transport, "req-2", "post-stream-2"); + + const message1: JSONRPCRequest = { + jsonrpc: "2.0", + id: "msg-1", + method: "elicitation/create", + params: { message: "Message for stream 1" } + }; + + const message2: JSONRPCRequest = { + jsonrpc: "2.0", + id: "msg-2", + method: "elicitation/create", + params: { message: "Message for stream 2" } + }; + + // Send to different streams + await transport.send(message1, { relatedRequestId: "req-1" }); + await transport.send(message2, { relatedRequestId: "req-2" }); + + // Each stream should receive its own message + expect(postStreamWriter.write).toHaveBeenCalledTimes(1); + expect(postStreamWriter2.write).toHaveBeenCalledTimes(1); + expect(getStreamWriter.write).not.toHaveBeenCalled(); + }); + }); + + describe("Server-to-client requests without relatedRequestId", () => { + it("should route messages without relatedRequestId through the standalone GET stream", async () => { + const notification: JSONRPCRequest = { + jsonrpc: "2.0", + id: "notif-1", + method: "notifications/message", + params: { + level: "info", + data: "Server notification" + } + }; + + // Send without relatedRequestId + await transport.send(notification); + + // Should go through GET stream + expect(getStreamWriter.write).toHaveBeenCalled(); + expect(getStreamData.length).toBe(1); + expect(getStreamData[0]).toContain("notifications/message"); + + // Should NOT go through POST stream + expect(postStreamWriter.write).not.toHaveBeenCalled(); + expect(postStreamData.length).toBe(0); + }); + + it("should not fail when standalone GET stream is not available", async () => { + // Remove the GET stream + deleteStream(transport, "_GET_stream"); + + const notification: JSONRPCRequest = { + jsonrpc: "2.0", + id: "notif-2", + method: "notifications/message", + params: { level: "info", data: "Test" } + }; + + // Should not throw + await expect(transport.send(notification)).resolves.toBeUndefined(); + }); + }); + + describe("Response routing", () => { + it("should route responses based on their message.id (overriding relatedRequestId)", async () => { + const response = { + jsonrpc: "2.0" as const, + id: "req-1", + result: { content: [{ type: "text" as const, text: "Response" }] } + }; + + // Even if we provide a different relatedRequestId, response should use message.id + await transport.send(response, { relatedRequestId: "something-else" }); + + // Should go through POST stream (because message.id="req-1" maps to post-stream-1) + expect(postStreamWriter.write).toHaveBeenCalled(); + expect(postStreamData.length).toBeGreaterThan(0); + }); + }); + + describe("Error handling", () => { + it("should throw error when relatedRequestId has no mapped stream", async () => { + const message: JSONRPCRequest = { + jsonrpc: "2.0", + id: "msg-1", + method: "elicitation/create", + params: {} + }; + + await expect( + transport.send(message, { relatedRequestId: "non-existent-id" }) + ).rejects.toThrow(/No connection established/); + }); + + it("should not send responses to standalone stream when requestId is not mapped", async () => { + const response = { + jsonrpc: "2.0" as const, + id: "unknown-request", + result: { content: [] } + }; + + // Should throw because the requestId is not mapped to any stream + await expect(transport.send(response)).rejects.toThrow( + /No connection established for request ID/ + ); + }); + }); + + describe("Edge cases", () => { + it("should use message.id for responses even when relatedRequestId matches a different mapped request", async () => { + // Set up: req-1 -> post-stream-1, req-2 -> post-stream-2 + const postStreamWriter2: WritableStreamDefaultWriter = { + write: vi.fn(async (_chunk: Uint8Array) => {}), + close: vi.fn(), + abort: vi.fn(), + releaseLock: vi.fn() + } as unknown as WritableStreamDefaultWriter; + + setupMockStream( + transport, + "post-stream-2", + postStreamWriter2, + new TextEncoder() + ); + mapRequestToStream(transport, "req-2", "post-stream-2"); + + // Send a response with id="req-2" but relatedRequestId="req-1" + const response = { + jsonrpc: "2.0" as const, + id: "req-2", + result: { content: [{ type: "text" as const, text: "Response" }] } + }; + + await transport.send(response, { relatedRequestId: "req-1" }); + + // Should go through post-stream-2 (based on message.id="req-2") + // NOT post-stream-1 (based on relatedRequestId="req-1") + expect(postStreamWriter2.write).toHaveBeenCalled(); + expect(postStreamWriter.write).not.toHaveBeenCalled(); + }); + + it("should handle multiple concurrent server-to-client requests with the same relatedRequestId", async () => { + // Both elicitations reference the same originating request + const elicitation1: JSONRPCRequest = { + jsonrpc: "2.0", + id: "elicit-1", + method: "elicitation/create", + params: { message: "First elicitation" } + }; + + const elicitation2: JSONRPCRequest = { + jsonrpc: "2.0", + id: "elicit-2", + method: "elicitation/create", + params: { message: "Second elicitation" } + }; + + // Both use the same relatedRequestId + await transport.send(elicitation1, { relatedRequestId: "req-1" }); + await transport.send(elicitation2, { relatedRequestId: "req-1" }); + + // Both should go through the same POST stream + expect(postStreamWriter.write).toHaveBeenCalledTimes(2); + expect(postStreamData.length).toBe(2); + expect(postStreamData[0]).toContain("First elicitation"); + expect(postStreamData[1]).toContain("Second elicitation"); + }); + + it("should handle relatedRequestId that points to a closed stream differently than missing stream", async () => { + // Map req-2 to a stream, then delete the stream (simulating closure) + mapRequestToStream(transport, "req-2", "closed-stream"); + + const message: JSONRPCRequest = { + jsonrpc: "2.0", + id: "msg-1", + method: "elicitation/create", + params: {} + }; + + // Should throw because stream is mapped but doesn't exist + await expect( + transport.send(message, { relatedRequestId: "req-2" }) + ).rejects.toThrow(/No connection established/); + }); + }); + }); }); diff --git a/packages/agents/src/tests/shared/test-utils.ts b/packages/agents/src/tests/shared/test-utils.ts index c7147594..49634820 100644 --- a/packages/agents/src/tests/shared/test-utils.ts +++ b/packages/agents/src/tests/shared/test-utils.ts @@ -17,7 +17,9 @@ export const TEST_MESSAGES = { jsonrpc: "2.0", method: "initialize", params: { - capabilities: {}, + capabilities: { + elicitation: { form: {} } + }, clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26" } diff --git a/packages/agents/src/tests/worker.ts b/packages/agents/src/tests/worker.ts index 34bd60fa..c6eda68e 100644 --- a/packages/agents/src/tests/worker.ts +++ b/packages/agents/src/tests/worker.ts @@ -45,7 +45,13 @@ export class TestMcpAgent extends McpAgent { server = new McpServer( { name: "test-server", version: "1.0.0" }, - { capabilities: { logging: {}, tools: { listChanged: true } } } + { + capabilities: { + logging: {}, + tools: { listChanged: true }, + elicitation: { form: {}, url: {} } + } + } ); async init() { @@ -87,6 +93,42 @@ export class TestMcpAgent extends McpAgent { } ); + this.server.tool( + "elicitName", + "Test tool that elicits user input for a name", + {}, + async (): Promise => { + const result = await this.server.server.elicitInput({ + message: "What is your name?", + requestedSchema: { + type: "object", + properties: { + name: { + type: "string", + description: "Your name" + } + }, + required: ["name"] + } + }); + + if (result.action === "accept" && result.content?.name) { + return { + content: [ + { + type: "text", + text: `You said your name is: ${result.content.name}` + } + ] + }; + } + + return { + content: [{ type: "text", text: "Elicitation cancelled" }] + }; + } + ); + // Use `registerTool` so we can later remove it. // Triggers notifications/tools/list_changed this.server.tool(