diff --git a/src/adapter/bun/handler.ts b/src/adapter/bun/handler.ts index aa26d2f7..b6a32a05 100644 --- a/src/adapter/bun/handler.ts +++ b/src/adapter/bun/handler.ts @@ -567,5 +567,6 @@ const handleResponse = createResponseHandler({ const handleStream = createStreamHandler({ mapResponse, - mapCompactResponse + mapCompactResponse, + streamOptions: {} }) diff --git a/src/adapter/utils.ts b/src/adapter/utils.ts index e9eb1e73..e1526b75 100644 --- a/src/adapter/utils.ts +++ b/src/adapter/utils.ts @@ -143,17 +143,22 @@ type CreateHandlerParameter = { request?: Request ): Response mapCompactResponse(response: unknown, request?: Request): Response + streamOptions?: { + autoCancellation?: boolean + } } const allowRapidStream = env.ELYSIA_RAPID_STREAM === 'true' +const defaultAutoCancellation = env.ELYSIA_STREAM_AUTO_CANCELLATION !== 'false' export const createStreamHandler = - ({ mapResponse, mapCompactResponse }: CreateHandlerParameter) => + ({ mapResponse, mapCompactResponse, streamOptions }: CreateHandlerParameter) => async ( generator: Generator | AsyncGenerator | ReadableStream, set?: Context['set'], request?: Request ) => { + const autoCancellation = streamOptions?.autoCancellation ?? defaultAutoCancellation // Since ReadableStream doesn't have next, init might be undefined let init = (generator as Generator).next?.() as | IteratorResult @@ -214,13 +219,15 @@ export const createStreamHandler = async start(controller) { let end = false - request?.signal?.addEventListener('abort', () => { - end = true + if (autoCancellation) { + request?.signal?.addEventListener('abort', () => { + end = true - try { - controller.close() - } catch {} - }) + try { + controller.close() + } catch {} + }) + } if (!init || init.value instanceof ReadableStream) { } else if ( diff --git a/src/adapter/web-standard/handler.ts b/src/adapter/web-standard/handler.ts index 18121d3b..57fcfbcd 100644 --- a/src/adapter/web-standard/handler.ts +++ b/src/adapter/web-standard/handler.ts @@ -600,5 +600,6 @@ const handleResponse = createResponseHandler({ const handleStream = createStreamHandler({ mapResponse, - mapCompactResponse + mapCompactResponse, + streamOptions: {} }) diff --git a/src/types.ts b/src/types.ts index 3ccd25d5..a1705ce8 100644 --- a/src/types.ts +++ b/src/types.ts @@ -168,6 +168,20 @@ export interface ElysiaConfig { WebSocketHandler, 'open' | 'close' | 'message' | 'drain' > + /** + * Stream response configuration + */ + stream?: { + /** + * Enable automatic cancellation of streams when the client disconnects + * + * When enabled, Elysia will automatically stop generator functions + * if the client cancels the request before streaming is completed + * + * @default true + */ + autoCancellation?: boolean + } cookie?: CookieOptions & { /** * Specified cookie name to be signed globally diff --git a/test/response/stream.test.ts b/test/response/stream.test.ts index 5c86f883..d69fb381 100644 --- a/test/response/stream.test.ts +++ b/test/response/stream.test.ts @@ -560,4 +560,122 @@ describe('Stream', () => { expect(result).toEqual(['Elysia', 'Eden'].map((x) => `data: ${x}\n\n`)) expect(response.headers.get('content-type')).toBe('text/event-stream') }) + + + it('continue stream when autoCancellation is disabled', async () => { + const { createStreamHandler } = await import('../../src/adapter/utils') + const sideEffects: string[] = [] + + const mockMapResponse = (value: any) => + new Response(value) + const mockMapCompactResponse = (value: any) => + new Response(value) + + const streamHandler = createStreamHandler({ + mapResponse: mockMapResponse, + mapCompactResponse: mockMapCompactResponse, + streamOptions: { autoCancellation: false } + }) + + async function* testGenerator() { + sideEffects.push('a') + yield 'a' + await Bun.sleep(20) + + sideEffects.push('b') + yield 'b' + await Bun.sleep(20) + + sideEffects.push('c') + yield 'c' + } + + const controller = new AbortController() + const request = new Request('http://e.ly', { + signal: controller.signal + }) + + setTimeout(() => { + controller.abort() + }, 15) + + const response = await streamHandler(testGenerator(), undefined, request) + const reader = response.body?.getReader() + + if (!reader) throw new Error('No reader') + + const results: string[] = [] + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + results.push(new TextDecoder().decode(value)) + } + } catch (e) { + // Expected to error when abort happens + } + + await Bun.sleep(100) + + expect(sideEffects).toEqual(['a', 'b', 'c']) + }) + + it('stop stream when autoCancellation is enabled (default)', async () => { + const { createStreamHandler } = await import('../../src/adapter/utils') + const sideEffects: string[] = [] + + const mockMapResponse = (value: any) => + new Response(value) + const mockMapCompactResponse = (value: any) => + new Response(value) + + const streamHandler = createStreamHandler({ + mapResponse: mockMapResponse, + mapCompactResponse: mockMapCompactResponse, + streamOptions: { autoCancellation: true } + }) + + async function* testGenerator() { + sideEffects.push('a') + yield 'a' + await Bun.sleep(20) + + sideEffects.push('b') + yield 'b' + await Bun.sleep(20) + + sideEffects.push('c') + yield 'c' + } + + const controller = new AbortController() + const request = new Request('http://e.ly', { + signal: controller.signal + }) + + setTimeout(() => { + controller.abort() + }, 15) + + const response = await streamHandler(testGenerator(), undefined, request) + const reader = response.body?.getReader() + + if (!reader) throw new Error('No reader') + + const results: string[] = [] + try { + while (true) { + const { done, value } = await reader.read() + if (done) break + results.push(new TextDecoder().decode(value)) + } + } catch (e) { + // Expected to error when abort happens + } + + await Bun.sleep(100) + + expect(sideEffects).toHaveLength(2) + expect(sideEffects).toEqual(['a', 'b']) + }) })