diff --git a/packages/core/src/create.ts b/packages/core/src/create.ts index 2d773f93df..d0d6325977 100644 --- a/packages/core/src/create.ts +++ b/packages/core/src/create.ts @@ -18,12 +18,17 @@ import { AfterCallback, AfterResolverPayload, Envelop, + ExecuteDoneOptions, ExecuteFunction, + OnExecuteHookResult, + OnExecutionDoneHookResult, OnResolverCalledHooks, + OnSubscribeHookResult, Plugin, SubscribeFunction, } from '@envelop/types'; -import { makeSubscribe, makeExecute } from './util'; +import { makeSubscribe, makeExecute, mapAsyncIterator, finalAsyncIterator } from './util'; +import isAsyncIterable from 'graphql/jsutils/isAsyncIterable'; const trackedSchemaSymbol = Symbol('TRACKED_SCHEMA'); export const resolversHooksSymbol = Symbol('RESOLVERS_HOOKS'); @@ -215,10 +220,7 @@ export function envelop({ plugins }: { plugins: Plugin[] }): Envelop { const onResolversHandlers: OnResolverCalledHooks[] = []; let subscribeFn = subscribe as SubscribeFunction; - const afterCalls: ((options: { - result: AsyncIterableIterator | ExecutionResult; - setResult: (newResult: AsyncIterableIterator | ExecutionResult) => void; - }) => void)[] = []; + const afterCalls: Exclude[] = []; let context = args.contextValue; for (const onSubscribe of onSubscribeCbs) { @@ -261,13 +263,44 @@ export function envelop({ plugins }: { plugins: Plugin[] }): Envelop { contextValue: context, }); + const onNextHandler: Exclude[] = []; + const onEndHandler: Exclude[] = []; + for (const afterCb of afterCalls) { - afterCb({ + const streamHandler = afterCb({ result, setResult: newResult => { result = newResult; }, - }); + isStream: isAsyncIterable(result), + } as ExecuteDoneOptions); + + if (streamHandler) { + if (streamHandler.onNext) { + onNextHandler.push(streamHandler.onNext); + } + if (streamHandler.onEnd) { + onEndHandler.push(streamHandler.onEnd); + } + } + } + + if (isAsyncIterable(result)) { + if (onNextHandler.length) { + result = mapAsyncIterator(result, result => { + for (const onNext of onNextHandler) { + onNext({ result, setResult: newResult => (result = newResult) }); + } + return result; + }); + } + if (onEndHandler.length) { + result = finalAsyncIterator(result, () => { + for (const onEnd of onEndHandler) { + onEnd(); + } + }); + } } return result; @@ -278,10 +311,9 @@ export function envelop({ plugins }: { plugins: Plugin[] }): Envelop { ? makeExecute(async args => { const onResolversHandlers: OnResolverCalledHooks[] = []; let executeFn: ExecuteFunction = execute as ExecuteFunction; - let result: ExecutionResult; + let result: ExecutionResult | AsyncIterableIterator; - const afterCalls: ((options: { result: ExecutionResult; setResult: (newResult: ExecutionResult) => void }) => void)[] = - []; + const afterCalls: Exclude[] = []; let context = args.contextValue; for (const onExecute of onExecuteCbs) { @@ -336,13 +368,44 @@ export function envelop({ plugins }: { plugins: Plugin[] }): Envelop { contextValue: context, }); + const onNextHandler: Exclude[] = []; + const onEndHandler: Exclude[] = []; + for (const afterCb of afterCalls) { - afterCb({ + const streamHandler = afterCb({ result, setResult: newResult => { result = newResult; }, - }); + isStream: isAsyncIterable(result), + } as ExecuteDoneOptions); + + if (streamHandler) { + if (streamHandler.onNext) { + onNextHandler.push(streamHandler.onNext); + } + if (streamHandler.onEnd) { + onEndHandler.push(streamHandler.onEnd); + } + } + } + + if (isAsyncIterable(result)) { + if (onNextHandler.length) { + result = mapAsyncIterator(result, result => { + for (const onNext of onNextHandler) { + onNext({ result, setResult: newResult => (result = newResult) }); + } + return result; + }); + } + if (onEndHandler.length) { + result = finalAsyncIterator(result, () => { + for (const onEnd of onEndHandler) { + onEnd(); + } + }); + } } return result; diff --git a/packages/core/src/graphql-typings.d.ts b/packages/core/src/graphql-typings.d.ts index 474528ad13..2d70480c56 100644 --- a/packages/core/src/graphql-typings.d.ts +++ b/packages/core/src/graphql-typings.d.ts @@ -1,4 +1,4 @@ declare module 'graphql/jsutils/isAsyncIterable' { - function isAsyncIterable(input: unknown): input is AsyncIterable; + function isAsyncIterable(input: unknown): input is AsyncIterableIterator; export default isAsyncIterable; } diff --git a/packages/core/src/util.ts b/packages/core/src/util.ts index a582a7137e..098b37ec73 100644 --- a/packages/core/src/util.ts +++ b/packages/core/src/util.ts @@ -21,8 +21,8 @@ export function getExecuteArgs(args: PolymorphicExecuteArguments): ExecutionArgs * Utility function for making a execute function that handles polymorphic arguments. */ export const makeExecute = - (executeFn: (args: ExecutionArgs) => PromiseOrValue) => - (...polyArgs: PolymorphicExecuteArguments): PromiseOrValue => + (executeFn: (args: ExecutionArgs) => PromiseOrValue>) => + (...polyArgs: PolymorphicExecuteArguments): PromiseOrValue> => executeFn(getExecuteArgs(polyArgs)); export function getSubscribeArgs(args: PolymorphicSubscribeArguments): SubscriptionArgs { @@ -47,3 +47,23 @@ export const makeSubscribe = (subscribeFn: (args: SubscriptionArgs) => PromiseOrValue | ExecutionResult>) => (...polyArgs: PolymorphicSubscribeArguments): PromiseOrValue | ExecutionResult> => subscribeFn(getSubscribeArgs(polyArgs)); + +export async function* mapAsyncIterator( + asyncIterable: AsyncIterableIterator, + map: (input: TInput) => Promise | TOutput +): AsyncIterableIterator { + for await (const value of asyncIterable) { + yield map(value); + } +} + +export async function* finalAsyncIterator( + asyncIterable: AsyncIterableIterator, + onFinal: () => void +): AsyncIterableIterator { + try { + yield* asyncIterable; + } finally { + onFinal(); + } +} diff --git a/packages/core/test/common.ts b/packages/core/test/common.ts index 221de78b8d..ed849788fd 100644 --- a/packages/core/test/common.ts +++ b/packages/core/test/common.ts @@ -1,12 +1,13 @@ import { EventEmitter, on } from 'events'; import { GraphQLID, GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLString } from 'graphql'; +import isAsyncIterable from 'graphql/jsutils/isAsyncIterable'; const createPubSub = (emitter: EventEmitter) => { return { publish: >(topic: TTopic, payload: TTopicPayload[TTopic]) => { emitter.emit(topic as string, payload); }, - subscribe: async function*>( + subscribe: async function* >( topic: TTopic ): AsyncIterableIterator { const asyncIterator = on(emitter, topic); @@ -52,7 +53,7 @@ const GraphQLSubscription = new GraphQLObjectType({ fields: { ping: { type: GraphQLString, - subscribe: async function*() { + subscribe: async function* () { const stream = pubSub.subscribe('ping'); return yield* stream; }, @@ -79,3 +80,17 @@ export const subscription = /* GraphQL */ ` ping } `; + +export const collectAsyncIteratorValues = async (asyncIterable: AsyncIterableIterator): Promise> => { + const values: Array = []; + for await (const value of asyncIterable) { + values.push(value); + } + return values; +}; + +export function assertAsyncIterator(input: unknown): asserts input is AsyncIterableIterator { + if (!isAsyncIterable(input)) { + throw new Error('Expected AsyncIterable iterator.'); + } +} diff --git a/packages/core/test/execute.spec.ts b/packages/core/test/execute.spec.ts index f6b78822d4..4ccd2197ff 100644 --- a/packages/core/test/execute.spec.ts +++ b/packages/core/test/execute.spec.ts @@ -1,6 +1,7 @@ import { createSpiedPlugin, createTestkit } from '@envelop/testing'; -import { execute, GraphQLSchema } from 'graphql'; -import { schema, query } from './common'; +import { execute, ExecutionResult } from 'graphql'; +import { ExecuteFunction } from 'packages/types/src'; +import { schema, query, assertAsyncIterator, collectAsyncIteratorValues } from './common'; describe('execute', () => { it('Should wrap and trigger events correctly', async () => { @@ -31,6 +32,7 @@ describe('execute', () => { expect(spiedPlugin.spies.afterResolver).toHaveBeenCalledTimes(3); expect(spiedPlugin.spies.afterExecute).toHaveBeenCalledTimes(1); expect(spiedPlugin.spies.afterExecute).toHaveBeenCalledWith({ + isStream: false, setResult: expect.any(Function), result: { data: { @@ -132,4 +134,122 @@ describe('execute', () => { setResult: expect.any(Function), }); }); + + it('Should be able to manipulate streams', async () => { + const streamExecuteFn = async function* () { + for (const value of ['a', 'b', 'c', 'd']) { + yield { data: { alphabet: value } }; + } + }; + + const teskit = createTestkit( + [ + { + onExecute({ setExecuteFn }) { + setExecuteFn(streamExecuteFn); + + return { + onExecuteDone: () => { + return { + onNext: ({ setResult }) => { + setResult({ data: { alphabet: 'x' } }); + }, + }; + }, + }; + }, + }, + ], + schema + ); + + const result: ReturnType = await teskit.executeRaw({} as any); + assertAsyncIterator(result); + const values = await collectAsyncIteratorValues(result); + expect(values).toEqual([ + { data: { alphabet: 'x' } }, + { data: { alphabet: 'x' } }, + { data: { alphabet: 'x' } }, + { data: { alphabet: 'x' } }, + ]); + }); + + it('Should be able to invoke something after the stream has ended.', async () => { + expect.assertions(1); + const streamExecuteFn = async function* () { + for (const value of ['a', 'b', 'c', 'd']) { + yield { data: { alphabet: value } }; + } + }; + + const teskit = createTestkit( + [ + { + onExecute({ setExecuteFn }) { + setExecuteFn(streamExecuteFn); + + return { + onExecuteDone: () => { + let latestResult: ExecutionResult; + return { + onNext: ({ result }) => { + latestResult = result; + }, + onEnd: () => { + expect(latestResult).toEqual({ data: { alphabet: 'd' } }); + }, + }; + }, + }; + }, + }, + ], + schema + ); + + const result: ReturnType = await teskit.executeRaw({} as any); + assertAsyncIterator(result); + // run AsyncGenerator + await collectAsyncIteratorValues(result); + }); + + it('Should be able to invoke something after the stream has ended (manual return).', async () => { + expect.assertions(1); + const streamExecuteFn = async function* () { + for (const value of ['a', 'b', 'c', 'd']) { + yield { data: { alphabet: value } }; + } + }; + + const teskit = createTestkit( + [ + { + onExecute({ setExecuteFn }) { + setExecuteFn(streamExecuteFn); + + return { + onExecuteDone: () => { + let latestResult: ExecutionResult; + return { + onNext: ({ result }) => { + latestResult = result; + }, + onEnd: () => { + expect(latestResult).toEqual({ data: { alphabet: 'a' } }); + }, + }; + }, + }; + }, + }, + ], + schema + ); + + const result: ReturnType = await teskit.executeRaw({} as any); + assertAsyncIterator(result); + const instance = result[Symbol.asyncIterator](); + await instance.next(); + await instance.return!(); + }); }); diff --git a/packages/testing/src/index.ts b/packages/testing/src/index.ts index 6603a1313c..3d8798f641 100644 --- a/packages/testing/src/index.ts +++ b/packages/testing/src/index.ts @@ -1,7 +1,7 @@ -import { DocumentNode, ExecutionResult, GraphQLSchema, print } from 'graphql'; +import { DocumentNode, ExecutionArgs, ExecutionResult, GraphQLSchema, print } from 'graphql'; import { getGraphQLParameters, processRequest, Push } from 'graphql-helix'; -import { envelop, useSchema } from '@envelop/core'; -import { Envelop, Plugin } from '@envelop/types'; +import { envelop, makeExecute, useSchema } from '@envelop/core'; +import { Envelop, ExecuteFunction, Plugin } from '@envelop/types'; // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types export function createSpiedPlugin() { @@ -57,6 +57,7 @@ export function createTestkit( subscribe: (operation: DocumentNode | string, variables?: Record, initialContext?: any) => Promise>; replaceSchema: (schema: GraphQLSchema) => void; wait: (ms: number) => Promise; + executeRaw: ExecuteFunction; } { let replaceSchema: (s: GraphQLSchema) => void = () => {}; @@ -99,7 +100,6 @@ export function createTestkit( contextFactory: initialContext ? () => proxy.contextFactory(initialContext) : proxy.contextFactory, schema: proxy.schema, }); - return (r as any).payload as ExecutionResult; }, subscribe: async (operation, rawVariables = {}, initialContext = null) => { @@ -134,6 +134,13 @@ export function createTestkit( return r; }, + executeRaw: makeExecute(async (args: ExecutionArgs) => { + const proxy = initRequest(); + return await proxy.execute({ + ...args, + contextValue: await proxy.contextFactory(args.contextValue), + }); + }), }; } diff --git a/packages/types/src/index.ts b/packages/types/src/index.ts index 308576b466..74670ad2f5 100644 --- a/packages/types/src/index.ts +++ b/packages/types/src/index.ts @@ -32,7 +32,9 @@ export type PolymorphicExecuteArguments = Maybe> ]; -export type ExecuteFunction = (...args: PolymorphicExecuteArguments) => PromiseOrValue; +export type ExecuteFunction = ( + ...args: PolymorphicExecuteArguments +) => PromiseOrValue>; export type PolymorphicSubscribeArguments = | [SubscriptionArgs] @@ -74,16 +76,34 @@ export type OnResolverCalledHooks; +export type OnExecutionDoneHookResult = { + onNext?: (options: { result: ExecutionResult; setResult: (newResult: ExecutionResult) => void }) => void; + onEnd?: () => void; +}; + +export type onSubscriptionDoneResult = OnExecutionDoneHookResult; + +type ExecuteDoneNonStreamOptions = { + result: ExecutionResult; + setResult: (newResult: ExecutionResult | AsyncIterableIterator) => void; + isStream: false; +}; + +type ExecuteDoneStreamOptions = { + result: AsyncIterableIterator; + setResult: (newResult: ExecutionResult | AsyncIterableIterator) => void; + isStream: true; +}; + +export type ExecuteDoneOptions = ExecuteDoneNonStreamOptions | ExecuteDoneStreamOptions; + export type OnExecuteHookResult = { - onExecuteDone?: (options: { result: ExecutionResult; setResult: (newResult: ExecutionResult) => void }) => void; + onExecuteDone?: (options: ExecuteDoneOptions) => OnExecutionDoneHookResult | void; onResolverCalled?: OnResolverCalledHooks; }; export type OnSubscribeHookResult = { - onSubscribeResult?: (options: { - result: AsyncIterableIterator | ExecutionResult; - setResult: (newResult: AsyncIterableIterator | ExecutionResult) => void; - }) => void; + onSubscribeResult?: (options: ExecuteDoneOptions) => void | onSubscriptionDoneResult; onResolverCalled?: OnResolverCalledHooks; };