Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1171 support for refine, superRefine, transform and lazy in discriminatedUnion #1290

Merged
merged 10 commits into from Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion deno/lib/README.md
Expand Up @@ -604,7 +604,7 @@ dateSchema.safeParse(new Date("1/12/22")); // success: true
dateSchema.safeParse("2022-01-12T00:00:00.000Z"); // success: true
```

## Zod enums-
## Zod enums

```ts
const FishEnum = z.enum(["Salmon", "Tuna", "Trout"]);
Expand Down
11 changes: 5 additions & 6 deletions deno/lib/__tests__/discriminatedUnions.test.ts
Expand Up @@ -25,6 +25,9 @@ test("valid - discriminator value of various primitive types", () => {
z.object({ type: z.literal(null), val: z.literal(7) }),
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
z.object({ type: z.literal(undefined), val: z.literal(9) }),
z.object({ type: z.literal("transform"), val: z.literal(10) }),
z.object({ type: z.literal("refine"), val: z.literal(11) }),
z.object({ type: z.literal("superRefine"), val: z.literal(12) }),
]);

expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
Expand Down Expand Up @@ -126,9 +129,7 @@ test("wrong schema - missing discriminator", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"The discriminator value could not be extracted from all the provided schemas"
);
expect(e.message.includes("could not be extracted")).toBe(true);
}
});

Expand All @@ -140,9 +141,7 @@ test("wrong schema - duplicate discriminator values", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"Some of the discriminator values are not unique"
);
expect(e.message.includes("has duplicate value")).toEqual(true);
}
});

Expand Down
124 changes: 72 additions & 52 deletions deno/lib/types.ts
Expand Up @@ -2137,33 +2137,46 @@ export class ZodUnion<T extends ZodUnionOptions> extends ZodType<
/////////////////////////////////////////////////////
/////////////////////////////////////////////////////

export type ZodDiscriminatedUnionOption<
Discriminator extends string,
DiscriminatorValue extends Primitive
> = ZodObject<
{ [key in Discriminator]: ZodLiteral<DiscriminatorValue> } & ZodRawShape,
any,
any
>;
const getDiscriminator = <T extends ZodTypeAny>(
type: T
): Primitive[] | null => {
if (type instanceof ZodLazy) {
return getDiscriminator(type.schema);
} else if (type instanceof ZodEffects) {
return getDiscriminator(type.innerType());
} else if (type instanceof ZodLiteral) {
return [type.value];
} else if (type instanceof ZodEnum) {
return type.options;
} else if (type instanceof ZodUndefined) {
return [undefined];
} else if (type instanceof ZodNull) {
return [null];
} else {
return null;
}
};

export type ZodDiscriminatedUnionOption<Discriminator extends string> =
ZodObject<{ [key in Discriminator]: ZodTypeAny } & ZodRawShape, any, any>;

export interface ZodDiscriminatedUnionDef<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
Options extends ZodDiscriminatedUnionOption<any>[]
> extends ZodTypeDef {
discriminator: Discriminator;
options: Map<DiscriminatorValue, Option>;
options: Options;
optionsMap: Map<Primitive, ZodDiscriminatedUnionOption<any>>;
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion;
}

export class ZodDiscriminatedUnion<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Option extends ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>
Options extends ZodDiscriminatedUnionOption<Discriminator>[]
> extends ZodType<
Option["_output"],
ZodDiscriminatedUnionDef<Discriminator, DiscriminatorValue, Option>,
Option["_input"]
output<Options[number]>,
ZodDiscriminatedUnionDef<Discriminator, Options>,
input<Options[number]>
> {
_parse(input: ParseInput): ParseReturnType<this["_output"]> {
const { ctx } = this._processInputParams(input);
Expand All @@ -2178,13 +2191,13 @@ export class ZodDiscriminatedUnion<
}

const discriminator = this.discriminator;
const discriminatorValue: DiscriminatorValue = ctx.data[discriminator];
const option = this.options.get(discriminatorValue);
const discriminatorValue: string = ctx.data[discriminator];
const option = this.optionsMap.get(discriminatorValue);

if (!option) {
addIssueToContext(ctx, {
code: ZodIssueCode.invalid_union_discriminator,
options: this.validDiscriminatorValues,
options: Array.from(this.optionsMap.keys()),
path: [discriminator],
});
return INVALID;
Expand All @@ -2195,28 +2208,28 @@ export class ZodDiscriminatedUnion<
data: ctx.data,
path: ctx.path,
parent: ctx,
});
}) as any;
} else {
return option._parseSync({
data: ctx.data,
path: ctx.path,
parent: ctx,
});
}) as any;
}
}

get discriminator() {
return this._def.discriminator;
}

get validDiscriminatorValues() {
return Array.from(this.options.keys());
}

get options() {
return this._def.options;
}

get optionsMap() {
return this._def.optionsMap;
}

/**
* The constructor of the discriminated union schema. Its behaviour is very similar to that of the normal z.union() constructor.
* However, it only allows a union of objects, all of which need to share a discriminator property. This property must
Expand All @@ -2227,44 +2240,45 @@ export class ZodDiscriminatedUnion<
*/
static create<
Discriminator extends string,
DiscriminatorValue extends Primitive,
Types extends [
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>,
...ZodDiscriminatedUnionOption<Discriminator, DiscriminatorValue>[]
ZodDiscriminatedUnionOption<Discriminator>,
...ZodDiscriminatedUnionOption<Discriminator>[]
]
>(
discriminator: Discriminator,
types: Types,
options: Types,
params?: RawCreateParams
): ZodDiscriminatedUnion<Discriminator, DiscriminatorValue, Types[number]> {
): ZodDiscriminatedUnion<Discriminator, Types> {
// Get all the valid discriminator values
const options: Map<DiscriminatorValue, Types[number]> = new Map();

try {
types.forEach((type) => {
const discriminatorValue = type.shape[discriminator].value;
options.set(discriminatorValue, type);
});
} catch (e) {
throw new Error(
"The discriminator value could not be extracted from all the provided schemas"
);
}

// Assert that all the discriminator values are unique
if (options.size !== types.length) {
throw new Error("Some of the discriminator values are not unique");
const optionsMap: Map<Primitive, Types[number]> = new Map();

// try {
for (const type of options) {
const discriminatorValues = getDiscriminator(type.shape[discriminator]);
if (!discriminatorValues) {
throw new Error(
`A discriminator value for key \`${discriminator}\`could not be extracted from all schema options`
);
}
for (const value of discriminatorValues) {
if (optionsMap.has(value)) {
throw new Error(
`Discriminator property ${discriminator} has duplicate value ${value}`
);
}
optionsMap.set(value, type);
}
}

return new ZodDiscriminatedUnion<
Discriminator,
DiscriminatorValue,
Types[number]
// DiscriminatorValue,
Types
>({
typeName: ZodFirstPartyTypeKind.ZodDiscriminatedUnion,
discriminator,
options,
optionsMap,
...processCreateParams(params),
});
}
Expand Down Expand Up @@ -3410,13 +3424,19 @@ export interface ZodEffectsDef<T extends ZodTypeAny = ZodTypeAny>

export class ZodEffects<
T extends ZodTypeAny,
Output = T["_output"],
Input = T["_input"]
Output = output<T>,
Input = input<T>
> extends ZodType<Output, ZodEffectsDef<T>, Input> {
innerType() {
return this._def.schema;
}

sourceType(): T {
return this._def.schema._def.typeName === ZodFirstPartyTypeKind.ZodEffects
? (this._def.schema as unknown as ZodEffects<T>).sourceType()
: (this._def.schema as T);
}

_parse(input: ParseInput): ParseReturnType<this["_output"]> {
const { status, ctx } = this._processInputParams(input);

Expand Down Expand Up @@ -3854,7 +3874,7 @@ export type ZodFirstPartySchemaTypes =
| ZodArray<any, any>
| ZodObject<any, any, any, any, any>
| ZodUnion<any>
| ZodDiscriminatedUnion<any, any, any>
| ZodDiscriminatedUnion<any, any>
| ZodIntersection<any, any>
| ZodTuple<any, any>
| ZodRecord<any, any>
Expand Down
11 changes: 5 additions & 6 deletions src/__tests__/discriminatedUnions.test.ts
Expand Up @@ -24,6 +24,9 @@ test("valid - discriminator value of various primitive types", () => {
z.object({ type: z.literal(null), val: z.literal(7) }),
z.object({ type: z.literal("undefined"), val: z.literal(8) }),
z.object({ type: z.literal(undefined), val: z.literal(9) }),
z.object({ type: z.literal("transform"), val: z.literal(10) }),
z.object({ type: z.literal("refine"), val: z.literal(11) }),
z.object({ type: z.literal("superRefine"), val: z.literal(12) }),
]);

expect(schema.parse({ type: "1", val: 1 })).toEqual({ type: "1", val: 1 });
Expand Down Expand Up @@ -125,9 +128,7 @@ test("wrong schema - missing discriminator", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"The discriminator value could not be extracted from all the provided schemas"
);
expect(e.message.includes("could not be extracted")).toBe(true);
}
});

Expand All @@ -139,9 +140,7 @@ test("wrong schema - duplicate discriminator values", () => {
]);
throw new Error();
} catch (e: any) {
expect(e.message).toEqual(
"Some of the discriminator values are not unique"
);
expect(e.message.includes("has duplicate value")).toEqual(true);
}
});

Expand Down