diff --git a/package-lock.json b/package-lock.json index a4e2e7eca..687a7c0c4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,15 +1,16 @@ { "name": "@modelcontextprotocol/sdk", - "version": "1.3.1", + "version": "1.3.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@modelcontextprotocol/sdk", - "version": "1.3.1", + "version": "1.3.2", "license": "MIT", "dependencies": { "content-type": "^1.0.5", + "eventsource": "^3.0.2", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" @@ -24,7 +25,6 @@ "@types/node": "^22.0.2", "@types/ws": "^8.5.12", "eslint": "^9.8.0", - "eventsource": "^2.0.2", "express": "^4.19.2", "jest": "^29.7.0", "ts-jest": "^29.2.4", @@ -3066,12 +3066,24 @@ } }, "node_modules/eventsource": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/eventsource/-/eventsource-2.0.2.tgz", - "integrity": "sha512-IzUmBGPR3+oUG9dUeXynyNmf91/3zUSJg1lCktzKw47OXuhco54U3r9B7O4XX+Rb1Itm9OZ2b0RkTs10bICOxA==", - "dev": true, + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/eventsource/-/eventsource-3.0.2.tgz", + "integrity": "sha512-YolzkJNxsTL3tCJMWFxpxtG2sCjbZ4LQUBUrkdaJK0ub0p6lmJt+2+1SwhKjLc652lpH9L/79Ptez972H9tphw==", + "license": "MIT", + "dependencies": { + "eventsource-parser": "^3.0.0" + }, "engines": { - "node": ">=12.0.0" + "node": ">=18.0.0" + } + }, + "node_modules/eventsource-parser": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.0.tgz", + "integrity": "sha512-T1C0XCUimhxVQzW4zFipdx0SficT651NnkR0ZSH3yQwh+mFMdLfgjABVi4YtMTtaL4s168593DaoaRLMqryavA==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" } }, "node_modules/execa": { diff --git a/package.json b/package.json index 6f06fac0a..7e21c07fe 100644 --- a/package.json +++ b/package.json @@ -47,6 +47,7 @@ }, "dependencies": { "content-type": "^1.0.5", + "eventsource": "^3.0.2", "raw-body": "^3.0.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.24.1" @@ -61,7 +62,6 @@ "@types/node": "^22.0.2", "@types/ws": "^8.5.12", "eslint": "^9.8.0", - "eventsource": "^2.0.2", "express": "^4.19.2", "jest": "^29.7.0", "ts-jest": "^29.2.4", diff --git a/src/cli.ts b/src/cli.ts index d5444972d..b5000896d 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -1,8 +1,5 @@ -import EventSource from "eventsource"; import WebSocket from "ws"; -// eslint-disable-next-line @typescript-eslint/no-explicit-any -(global as any).EventSource = EventSource; // eslint-disable-next-line @typescript-eslint/no-explicit-any (global as any).WebSocket = WebSocket; diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts new file mode 100644 index 000000000..f59c45fef --- /dev/null +++ b/src/client/sse.test.ts @@ -0,0 +1,287 @@ +import { createServer, type IncomingMessage, type Server } from "http"; +import { AddressInfo } from "net"; +import { JSONRPCMessage } from "../types.js"; +import { SSEClientTransport } from "./sse.js"; + +describe("SSEClientTransport", () => { + let server: Server; + let transport: SSEClientTransport; + let baseUrl: URL; + let lastServerRequest: IncomingMessage; + let sendServerMessage: ((message: string) => void) | null = null; + + beforeEach((done) => { + // Reset state + lastServerRequest = null as unknown as IncomingMessage; + sendServerMessage = null; + + // Create a test server that will receive the EventSource connection + server = createServer((req, res) => { + lastServerRequest = req; + + // Send SSE headers + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + + // Send the endpoint event + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + + // Store reference to send function for tests + sendServerMessage = (message: string) => { + res.write(`data: ${message}\n\n`); + }; + + // Handle request body for POST endpoints + if (req.method === "POST") { + let body = ""; + req.on("data", (chunk) => { + body += chunk; + }); + req.on("end", () => { + (req as IncomingMessage & { body: string }).body = body; + res.end(); + }); + } + }); + + // Start server on random port + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + done(); + }); + }); + + afterEach(async () => { + await transport.close(); + await server.close(); + }); + + describe("connection handling", () => { + it("establishes SSE connection and receives endpoint", async () => { + transport = new SSEClientTransport(baseUrl); + await transport.start(); + + expect(lastServerRequest.headers.accept).toBe("text/event-stream"); + expect(lastServerRequest.method).toBe("GET"); + }); + + it("rejects if server returns non-200 status", async () => { + // Create a server that returns 403 + server.close(); + await new Promise((resolve) => server.on("close", resolve)); + + server = createServer((req, res) => { + res.writeHead(403); + res.end(); + }); + + await new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl); + await expect(transport.start()).rejects.toThrow(); + }); + + it("closes EventSource connection on close()", async () => { + transport = new SSEClientTransport(baseUrl); + await transport.start(); + + const closePromise = new Promise((resolve) => { + lastServerRequest.on("close", resolve); + }); + + await transport.close(); + await closePromise; + }); + }); + + describe("message handling", () => { + it("receives and parses JSON-RPC messages", async () => { + const receivedMessages: JSONRPCMessage[] = []; + transport = new SSEClientTransport(baseUrl); + transport.onmessage = (msg) => receivedMessages.push(msg); + + await transport.start(); + + const testMessage: JSONRPCMessage = { + jsonrpc: "2.0", + id: "test-1", + method: "test", + params: { foo: "bar" }, + }; + + sendServerMessage!(JSON.stringify(testMessage)); + + // Wait for message processing + await new Promise((resolve) => setTimeout(resolve, 50)); + + expect(receivedMessages).toHaveLength(1); + expect(receivedMessages[0]).toEqual(testMessage); + }); + + it("handles malformed JSON messages", async () => { + const errors: Error[] = []; + transport = new SSEClientTransport(baseUrl); + transport.onerror = (err) => errors.push(err); + + await transport.start(); + + sendServerMessage!("invalid json"); + + // Wait for message processing + await new Promise((resolve) => setTimeout(resolve, 50)); + + expect(errors).toHaveLength(1); + expect(errors[0].message).toMatch(/JSON/); + }); + + it("handles messages via POST requests", async () => { + transport = new SSEClientTransport(baseUrl); + await transport.start(); + + const testMessage: JSONRPCMessage = { + jsonrpc: "2.0", + id: "test-1", + method: "test", + params: { foo: "bar" }, + }; + + await transport.send(testMessage); + + // Wait for request processing + await new Promise((resolve) => setTimeout(resolve, 50)); + + expect(lastServerRequest.method).toBe("POST"); + expect(lastServerRequest.headers["content-type"]).toBe( + "application/json", + ); + expect( + JSON.parse( + (lastServerRequest as IncomingMessage & { body: string }).body, + ), + ).toEqual(testMessage); + }); + + it("handles POST request failures", async () => { + // Create a server that returns 500 for POST + server.close(); + await new Promise((resolve) => server.on("close", resolve)); + + server = createServer((req, res) => { + if (req.method === "GET") { + res.writeHead(200, { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }); + res.write("event: endpoint\n"); + res.write(`data: ${baseUrl.href}\n\n`); + } else { + res.writeHead(500); + res.end("Internal error"); + } + }); + + await new Promise((resolve) => { + server.listen(0, "127.0.0.1", () => { + const addr = server.address() as AddressInfo; + baseUrl = new URL(`http://127.0.0.1:${addr.port}`); + resolve(); + }); + }); + + transport = new SSEClientTransport(baseUrl); + await transport.start(); + + const testMessage: JSONRPCMessage = { + jsonrpc: "2.0", + id: "test-1", + method: "test", + params: {}, + }; + + await expect(transport.send(testMessage)).rejects.toThrow(/500/); + }); + }); + + describe("header handling", () => { + it("uses custom fetch implementation from EventSourceInit to add auth headers", async () => { + const authToken = "Bearer test-token"; + + // Create a fetch wrapper that adds auth header + const fetchWithAuth = (url: string | URL, init?: RequestInit) => { + const headers = new Headers(init?.headers); + headers.set("Authorization", authToken); + return fetch(url.toString(), { ...init, headers }); + }; + + transport = new SSEClientTransport(baseUrl, { + eventSourceInit: { + fetch: fetchWithAuth, + }, + }); + + await transport.start(); + + // Verify the auth header was received by the server + expect(lastServerRequest.headers.authorization).toBe(authToken); + }); + + it("passes custom headers to fetch requests", async () => { + const customHeaders = { + Authorization: "Bearer test-token", + "X-Custom-Header": "custom-value", + }; + + transport = new SSEClientTransport(baseUrl, { + requestInit: { + headers: customHeaders, + }, + }); + + await transport.start(); + + // Mock fetch for the message sending test + global.fetch = jest.fn().mockResolvedValue({ + ok: true, + }); + + const message: JSONRPCMessage = { + jsonrpc: "2.0", + id: "1", + method: "test", + params: {}, + }; + + await transport.send(message); + + // Verify fetch was called with correct headers + expect(global.fetch).toHaveBeenCalledWith( + expect.any(URL), + expect.objectContaining({ + headers: expect.any(Headers), + }), + ); + + const calledHeaders = (global.fetch as jest.Mock).mock.calls[0][1] + .headers; + expect(calledHeaders.get("Authorization")).toBe( + customHeaders.Authorization, + ); + expect(calledHeaders.get("X-Custom-Header")).toBe( + customHeaders["X-Custom-Header"], + ); + expect(calledHeaders.get("content-type")).toBe("application/json"); + }); + }); +}); diff --git a/src/client/sse.ts b/src/client/sse.ts index 0e6e7eb98..932e7e206 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -1,11 +1,10 @@ import { Transport } from "../shared/transport.js"; import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js"; +import { EventSource, type EventSourceInit } from "eventsource"; /** * Client transport for SSE: this will connect to a server using Server-Sent Events for receiving * messages and make separate POST requests for sending messages. - * - * This uses the EventSource API in browsers. You can install the `eventsource` package for Node.js. */ export class SSEClientTransport implements Transport { private _eventSource?: EventSource; @@ -19,7 +18,10 @@ export class SSEClientTransport implements Transport { onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; - constructor(url: URL, opts?: { eventSourceInit?: EventSourceInit, requestInit?: RequestInit }) { + constructor( + url: URL, + opts?: { eventSourceInit?: EventSourceInit; requestInit?: RequestInit }, + ) { this._url = url; this._eventSourceInit = opts?.eventSourceInit; this._requestInit = opts?.requestInit; @@ -33,7 +35,10 @@ export class SSEClientTransport implements Transport { } return new Promise((resolve, reject) => { - this._eventSource = new EventSource(this._url.href, this._eventSourceInit); + this._eventSource = new EventSource( + this._url.href, + this._eventSourceInit, + ); this._abortController = new AbortController(); this._eventSource.onerror = (event) => { @@ -101,7 +106,7 @@ export class SSEClientTransport implements Transport { method: "POST", headers, body: JSON.stringify(message), - signal: this._abortController?.signal + signal: this._abortController?.signal, }; const response = await fetch(this._endpoint, init);