diff --git a/src/generator/schema.ts b/src/generator/schema.ts index 1295371b..9afeae6d 100644 --- a/src/generator/schema.ts +++ b/src/generator/schema.ts @@ -1,4 +1,5 @@ import { TRPCError } from '@trpc/server'; +import e from 'express'; import { OpenAPIV3 } from 'openapi-types'; import { z } from 'zod'; import zodToJsonSchema from 'zod-to-json-schema'; @@ -7,28 +8,28 @@ const zodSchemaToOpenApiSchemaObject = (zodSchema: z.ZodType): OpenAPIV3.SchemaO return zodToJsonSchema(zodSchema, { target: 'openApi3' }); }; -const instanceofZod = (schema: any): schema is z.ZodType => { - return !!schema?._def?.typeName; +const instanceofZod = (type: any): type is z.ZodType => { + return !!type?._def?.typeName; }; const instanceofZodTypeKind = ( - schema: any, + type: any, zodTypeKind: Z, -): schema is InstanceType => { - return schema?._def?.typeName === zodTypeKind; +): type is InstanceType => { + return type?._def?.typeName === zodTypeKind; }; -const getBaseZodType = (schema: z.ZodType): z.ZodType => { - if (instanceofZodTypeKind(schema, z.ZodFirstPartyTypeKind.ZodOptional)) { - return getBaseZodType(schema.unwrap()); +const unwrapZodType = (type: z.ZodType): z.ZodType => { + if (instanceofZodTypeKind(type, z.ZodFirstPartyTypeKind.ZodOptional)) { + return unwrapZodType(type.unwrap()); } - if (instanceofZodTypeKind(schema, z.ZodFirstPartyTypeKind.ZodDefault)) { - return getBaseZodType(schema.removeDefault()); + if (instanceofZodTypeKind(type, z.ZodFirstPartyTypeKind.ZodDefault)) { + return unwrapZodType(type.removeDefault()); } - if (instanceofZodTypeKind(schema, z.ZodFirstPartyTypeKind.ZodEffects)) { - return getBaseZodType(schema.innerType()); + if (instanceofZodTypeKind(type, z.ZodFirstPartyTypeKind.ZodEffects)) { + return unwrapZodType(type.innerType()); } - return schema; + return type; }; export const getParameterObjects = ( @@ -75,7 +76,16 @@ export const getParameterObjects = ( .map((key) => { const value = shape[key]!; - if (!instanceofZodTypeKind(getBaseZodType(value), z.ZodFirstPartyTypeKind.ZodString)) { + const unwrappedZodType = unwrapZodType(value); + if ( + !instanceofZodTypeKind(unwrappedZodType, z.ZodFirstPartyTypeKind.ZodString) && + !instanceofZodTypeKind(unwrappedZodType, z.ZodFirstPartyTypeKind.ZodEnum) && + !instanceofZodTypeKind(unwrappedZodType, z.ZodFirstPartyTypeKind.ZodNativeEnum) && + !( + instanceofZodTypeKind(unwrappedZodType, z.ZodFirstPartyTypeKind.ZodLiteral) && + typeof unwrappedZodType._def.value === 'string' + ) + ) { throw new TRPCError({ message: `Input parser key: "${key}" must be a ZodString`, code: 'INTERNAL_SERVER_ERROR', diff --git a/test/generator.test.ts b/test/generator.test.ts index 4786b4c0..604f2baa 100644 --- a/test/generator.test.ts +++ b/test/generator.test.ts @@ -1,5 +1,6 @@ import * as trpc from '@trpc/server'; import { Subscription } from '@trpc/server'; +import e from 'express'; import openAPISchemaValidator from 'openapi-schema-validator'; import { z } from 'zod'; @@ -212,7 +213,7 @@ describe('generator', () => { } }); - test('with non-object-string-value input', () => { + test('with object-non-string-value input', () => { { const appRouter = trpc.router().query('badInput', { meta: { openapi: { enabled: true, path: '/bad-input', method: 'GET' } }, @@ -270,6 +271,121 @@ describe('generator', () => { } }); + test('with object-enum-value input', () => { + enum NativeNameEnum { + James = 'James', + jlalmes = 'jlalmes', + } + + const appRouter = trpc + .router() + .query('enum', { + meta: { openapi: { enabled: true, path: '/enum', method: 'GET' } }, + input: z.object({ name: z.enum(['James', 'jlalmes']) }), + output: z.object({ name: z.enum(['James', 'jlalmes']) }), + resolve: () => ({ name: 'jlalmes' as const }), + }) + .query('nativeEnum', { + meta: { openapi: { enabled: true, path: '/native-enum', method: 'GET' } }, + input: z.object({ age: z.nativeEnum(NativeNameEnum) }), + output: z.object({ name: z.nativeEnum(NativeNameEnum) }), + resolve: () => ({ name: NativeNameEnum.James }), + }); + + const openApiDocument = generateOpenApiDocument(appRouter, { + title: 'tRPC OpenAPI', + version: '1.0.0', + baseUrl: 'http://localhost:3000/api', + }); + + expect(openApiSchemaValidator.validate(openApiDocument).errors).toEqual([]); + expect(openApiDocument.paths['/enum']!.get!.parameters).toMatchInlineSnapshot(` + Array [ + Object { + "description": undefined, + "in": "query", + "name": "name", + "required": true, + "schema": Object { + "enum": Array [ + "James", + "jlalmes", + ], + "type": "string", + }, + }, + ] + `); + expect(openApiDocument.paths['/native-enum']!.get!.parameters).toMatchInlineSnapshot(` + Array [ + Object { + "description": undefined, + "in": "query", + "name": "age", + "required": true, + "schema": Object { + "enum": Array [ + "James", + "jlalmes", + ], + "type": "string", + }, + }, + ] + `); + }); + + test('with object-literal-value input', () => { + { + const appRouter = trpc.router().query('numberLiteral', { + meta: { openapi: { enabled: true, path: '/number-literal', method: 'GET' } }, + input: z.object({ num: z.literal(123) }), + output: z.object({ num: z.literal(123) }), + resolve: () => ({ num: 123 as const }), + }); + + expect(() => { + generateOpenApiDocument(appRouter, { + title: 'tRPC OpenAPI', + version: '1.0.0', + baseUrl: 'http://localhost:3000/api', + }); + }).toThrowError('[query.numberLiteral] - Input parser key: "num" must be a ZodString'); + } + { + const appRouter = trpc.router().query('stringLiteral', { + meta: { openapi: { enabled: true, path: '/string-literal', method: 'GET' } }, + input: z.object({ str: z.literal('strlitval') }), + output: z.object({ str: z.literal('strlitval') }), + resolve: () => ({ str: 'strlitval' as const }), + }); + + const openApiDocument = generateOpenApiDocument(appRouter, { + title: 'tRPC OpenAPI', + version: '1.0.0', + baseUrl: 'http://localhost:3000/api', + }); + + expect(openApiSchemaValidator.validate(openApiDocument).errors).toEqual([]); + expect(openApiDocument.paths['/string-literal']!.get!.parameters).toMatchInlineSnapshot(` + Array [ + Object { + "description": undefined, + "in": "query", + "name": "str", + "required": true, + "schema": Object { + "enum": Array [ + "strlitval", + ], + "type": "string", + }, + }, + ] + `); + } + }); + test('with bad method', () => { { const appRouter = trpc.router().query('postQuery', {