From 77b79eff5e51d35d728f4d956ce6fc8f4a034d9a Mon Sep 17 00:00:00 2001 From: Robin Labat Date: Tue, 26 Jul 2022 17:30:11 +0200 Subject: [PATCH] #1171 --- .../lib/__tests__/discriminatedUnions.test.ts | 9 +++++ deno/lib/types.ts | 35 +++++++++++++++++-- src/__tests__/discriminatedUnions.test.ts | 9 +++++ src/types.ts | 35 +++++++++++++++++-- 4 files changed, 82 insertions(+), 6 deletions(-) diff --git a/deno/lib/__tests__/discriminatedUnions.test.ts b/deno/lib/__tests__/discriminatedUnions.test.ts index a5f8f8996..c61cb54a4 100644 --- a/deno/lib/__tests__/discriminatedUnions.test.ts +++ b/deno/lib/__tests__/discriminatedUnions.test.ts @@ -25,6 +25,11 @@ test("valid - discriminator value of various primitive types", () => { z.object({ type: z.literal(null), val: z.literal(7) }), z.object({ type: z.literal("undefined"), val: z.literal(8) }), z.object({ type: z.literal(undefined), val: z.literal(9) }), + z + .object({ type: z.literal("transform"), val: z.literal(10) }) + .transform((val) => ({ + val, + })), ]); expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 }); @@ -57,6 +62,10 @@ test("valid - discriminator value of various primitive types", () => { type: undefined, val: 9, }); + // console.log("test"); + // expect(schema.parse({ type: "transform", val: 10 })).toEqual({ + // val: 10, + // }); }); test("invalid - null", () => { diff --git a/deno/lib/types.ts b/deno/lib/types.ts index 13f6eaaa9..3024d6a97 100644 --- a/deno/lib/types.ts +++ b/deno/lib/types.ts @@ -2137,7 +2137,7 @@ export class ZodUnion extends ZodType< ///////////////////////////////////////////////////// ///////////////////////////////////////////////////// -export type ZodDiscriminatedUnionOption< +export type ZodDiscriminatedUnionOptionBase< Discriminator extends string, DiscriminatorValue extends Primitive > = ZodObject< @@ -2146,6 +2146,15 @@ export type ZodDiscriminatedUnionOption< any >; +export type ZodDiscriminatedUnionOption< + Discriminator extends string, + DiscriminatorValue extends Primitive +> = + | ZodDiscriminatedUnionOptionBase + | ZodEffects< + ZodDiscriminatedUnionOptionBase + >; + export interface ZodDiscriminatedUnionDef< Discriminator extends string, DiscriminatorValue extends Primitive, @@ -2243,8 +2252,22 @@ export class ZodDiscriminatedUnion< try { types.forEach((type) => { - const discriminatorValue = type.shape[discriminator].value; - options.set(discriminatorValue, type); + if (type._def.typeName === ZodFirstPartyTypeKind.ZodObject) { + const discriminatorValue = ( + type as ZodDiscriminatedUnionOptionBase< + Discriminator, + DiscriminatorValue + > + ).shape[discriminator].value; + options.set(discriminatorValue, type); + } else if (type._def.typeName === ZodFirstPartyTypeKind.ZodEffects) { + const discriminatorValue = ( + type as ZodEffects< + ZodDiscriminatedUnionOptionBase + > + ).sourceType().shape[discriminator].value; + options.set(discriminatorValue, type); + } }); } catch (e) { throw new Error( @@ -3417,6 +3440,12 @@ export class ZodEffects< return this._def.schema; } + sourceType(): T { + return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects + ? (this._def.schema as unknown as ZodEffects).sourceType() + : (this._def.schema as T); + } + _parse(input: ParseInput): ParseReturnType { const { status, ctx } = this._processInputParams(input); diff --git a/src/__tests__/discriminatedUnions.test.ts b/src/__tests__/discriminatedUnions.test.ts index 3d26a14d9..dc6e70814 100644 --- a/src/__tests__/discriminatedUnions.test.ts +++ b/src/__tests__/discriminatedUnions.test.ts @@ -24,6 +24,11 @@ test("valid - discriminator value of various primitive types", () => { z.object({ type: z.literal(null), val: z.literal(7) }), z.object({ type: z.literal("undefined"), val: z.literal(8) }), z.object({ type: z.literal(undefined), val: z.literal(9) }), + z + .object({ type: z.literal("transform"), val: z.literal(10) }) + .transform((val) => ({ + val, + })), ]); expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 }); @@ -56,6 +61,10 @@ test("valid - discriminator value of various primitive types", () => { type: undefined, val: 9, }); + // console.log("test"); + // expect(schema.parse({ type: "transform", val: 10 })).toEqual({ + // val: 10, + // }); }); test("invalid - null", () => { diff --git a/src/types.ts b/src/types.ts index 1d4badcdd..2518f083f 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2137,7 +2137,7 @@ export class ZodUnion extends ZodType< ///////////////////////////////////////////////////// ///////////////////////////////////////////////////// -export type ZodDiscriminatedUnionOption< +export type ZodDiscriminatedUnionOptionBase< Discriminator extends string, DiscriminatorValue extends Primitive > = ZodObject< @@ -2146,6 +2146,15 @@ export type ZodDiscriminatedUnionOption< any >; +export type ZodDiscriminatedUnionOption< + Discriminator extends string, + DiscriminatorValue extends Primitive +> = + | ZodDiscriminatedUnionOptionBase + | ZodEffects< + ZodDiscriminatedUnionOptionBase + >; + export interface ZodDiscriminatedUnionDef< Discriminator extends string, DiscriminatorValue extends Primitive, @@ -2243,8 +2252,22 @@ export class ZodDiscriminatedUnion< try { types.forEach((type) => { - const discriminatorValue = type.shape[discriminator].value; - options.set(discriminatorValue, type); + if (type._def.typeName === ZodFirstPartyTypeKind.ZodObject) { + const discriminatorValue = ( + type as ZodDiscriminatedUnionOptionBase< + Discriminator, + DiscriminatorValue + > + ).shape[discriminator].value; + options.set(discriminatorValue, type); + } else if (type._def.typeName === ZodFirstPartyTypeKind.ZodEffects) { + const discriminatorValue = ( + type as ZodEffects< + ZodDiscriminatedUnionOptionBase + > + ).sourceType().shape[discriminator].value; + options.set(discriminatorValue, type); + } }); } catch (e) { throw new Error( @@ -3417,6 +3440,12 @@ export class ZodEffects< return this._def.schema; } + sourceType(): T { + return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects + ? (this._def.schema as unknown as ZodEffects).sourceType() + : (this._def.schema as T); + } + _parse(input: ParseInput): ParseReturnType { const { status, ctx } = this._processInputParams(input);