Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support inferred types in conditionals #1265

Merged
merged 19 commits into from Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions factory/parser.ts
Expand Up @@ -22,6 +22,7 @@ import { FunctionNodeParser } from "../src/NodeParser/FunctionNodeParser";
import { FunctionParser } from "../src/NodeParser/FunctionParser";
import { HiddenNodeParser } from "../src/NodeParser/HiddenTypeNodeParser";
import { IndexedAccessTypeNodeParser } from "../src/NodeParser/IndexedAccessTypeNodeParser";
import { InferTypeNodeParser } from "../src/NodeParser/InferTypeNodeParser";
import { InterfaceAndClassNodeParser } from "../src/NodeParser/InterfaceAndClassNodeParser";
import { IntersectionNodeParser } from "../src/NodeParser/IntersectionNodeParser";
import { IntrinsicNodeParser } from "../src/NodeParser/IntrinsicNodeParser";
Expand Down Expand Up @@ -122,6 +123,8 @@ export function createParser(program: ts.Program, config: Config, augmentor?: Pa
.addNodeParser(new TypeReferenceNodeParser(typeChecker, chainNodeParser))
.addNodeParser(new ExpressionWithTypeArgumentsNodeParser(typeChecker, chainNodeParser))

.addNodeParser(new InferTypeNodeParser(typeChecker, chainNodeParser))

.addNodeParser(new IndexedAccessTypeNodeParser(chainNodeParser))
.addNodeParser(new TypeofNodeParser(typeChecker, chainNodeParser))
.addNodeParser(new MappedTypeNodeParser(chainNodeParser, mergedConfig.additionalProperties))
Expand Down
48 changes: 34 additions & 14 deletions src/NodeParser/ConditionalTypeNodeParser.ts
Expand Up @@ -6,6 +6,10 @@ import { isAssignableTo } from "../Utils/isAssignableTo";
import { narrowType } from "../Utils/narrowType";
import { UnionType } from "../Type/UnionType";

class CheckType {
constructor(public parameterName: string, public type: BaseType | undefined) {}
}

export class ConditionalTypeNodeParser implements SubNodeParser {
public constructor(protected typeChecker: ts.TypeChecker, protected childNodeParser: NodeParser) {}

Expand All @@ -18,22 +22,27 @@ export class ConditionalTypeNodeParser implements SubNodeParser {
const extendsType = this.childNodeParser.createType(node.extendsType, context);
const checkTypeParameterName = this.getTypeParameterName(node.checkType);

const inferMap = new Map();

// If check-type is not a type parameter then condition is very simple, no type narrowing needed
if (checkTypeParameterName == null) {
const result = isAssignableTo(extendsType, checkType);
return this.childNodeParser.createType(result ? node.trueType : node.falseType, context);
const result = isAssignableTo(extendsType, checkType, inferMap);
return this.childNodeParser.createType(
result ? node.trueType : node.falseType,
this.createSubContext(node, context, undefined, result ? inferMap : new Map())
);
}

// Narrow down check type for both condition branches
const trueCheckType = narrowType(checkType, (type) => isAssignableTo(extendsType, type));
const trueCheckType = narrowType(checkType, (type) => isAssignableTo(extendsType, type, inferMap));
const falseCheckType = narrowType(checkType, (type) => !isAssignableTo(extendsType, type));

// Follow the relevant branches and return the results from them
const results: BaseType[] = [];
if (trueCheckType !== undefined) {
const result = this.childNodeParser.createType(
node.trueType,
this.createSubContext(node, checkTypeParameterName, trueCheckType, context)
this.createSubContext(node, context, new CheckType(checkTypeParameterName, trueCheckType), inferMap)
);
if (result) {
results.push(result);
Expand All @@ -42,7 +51,7 @@ export class ConditionalTypeNodeParser implements SubNodeParser {
if (falseCheckType !== undefined) {
const result = this.childNodeParser.createType(
node.falseType,
this.createSubContext(node, checkTypeParameterName, falseCheckType, context)
this.createSubContext(node, context, new CheckType(checkTypeParameterName, falseCheckType))
);
if (result) {
results.push(result);
Expand Down Expand Up @@ -72,25 +81,36 @@ export class ConditionalTypeNodeParser implements SubNodeParser {
* the check-type is a type parameter which is then narrowed down by the extends-type.
*
* @param node - The reference node for the new context.
* @param checkTypeParameterName - The type parameter name of the check-type.
* @param narrowedCheckType - The narrowed down check type to use for the type parameter in sub parsers.
* @param checkType - An object containing the type parameter name of the check-type, and the narrowed
* down check type to use for the type parameter in sub parsers.
* @param inferMap - A map that links parameter names to their inferred types.
* @return The created sub context.
*/
protected createSubContext(
node: ts.ConditionalTypeNode,
checkTypeParameterName: string,
narrowedCheckType: BaseType,
parentContext: Context
parentContext: Context,
checkType?: CheckType,
inferMap: Map<string, BaseType> = new Map()
): Context {
const subContext = new Context(node);

// Set new narrowed type for check type parameter
subContext.pushParameter(checkTypeParameterName);
subContext.pushArgument(narrowedCheckType);
// Newly inferred types take precedence over check and parent types.
inferMap.forEach((value, key) => {
subContext.pushParameter(key);
subContext.pushArgument(value);
});

if (checkType !== undefined) {
// Set new narrowed type for check type parameter
if (!(checkType.parameterName in inferMap)) {
subContext.pushParameter(checkType.parameterName);
subContext.pushArgument(checkType.type);
}
}

// Copy all other type parameters from parent context
parentContext.getParameters().forEach((parentParameter) => {
if (parentParameter !== checkTypeParameterName) {
if (parentParameter !== checkType?.parameterName && !(parentParameter in inferMap)) {
subContext.pushParameter(parentParameter);
subContext.pushArgument(parentContext.getArgument(parentParameter));
}
Expand Down
17 changes: 17 additions & 0 deletions src/NodeParser/InferTypeNodeParser.ts
@@ -0,0 +1,17 @@
import ts from "typescript";
import { Context, NodeParser } from "../NodeParser";
import { SubNodeParser } from "../SubNodeParser";
import { BaseType } from "../Type/BaseType";
import { InferType } from "../Type/InferType";

export class InferTypeNodeParser implements SubNodeParser {
public constructor(protected typeChecker: ts.TypeChecker, protected childNodeParser: NodeParser) {}

public supportsNode(node: ts.InferTypeNode): boolean {
return node.kind === ts.SyntaxKind.InferType;
}

public createType(node: ts.InferTypeNode, _context: Context): BaseType | undefined {
return new InferType(node.typeParameter.name.escapedText.toString());
}
}
2 changes: 1 addition & 1 deletion src/NodeParser/MappedTypeNodeParser.ts
Expand Up @@ -87,7 +87,7 @@ export class MappedTypeNodeParser implements SubNodeParser {
protected getProperties(node: ts.MappedTypeNode, keyListType: UnionType, context: Context): ObjectProperty[] {
return keyListType
.getTypes()
.filter((type) => type instanceof LiteralType)
.filter((type): type is LiteralType => type instanceof LiteralType)
.reduce((result: ObjectProperty[], key: LiteralType) => {
const namedKey = this.mapKey(node, key, context);
const propertyType = this.childNodeParser.createType(
Expand Down
4 changes: 3 additions & 1 deletion src/NodeParser/RestTypeNodeParser.ts
Expand Up @@ -3,14 +3,16 @@ import { Context, NodeParser } from "../NodeParser";
import { SubNodeParser } from "../SubNodeParser";
import { ArrayType } from "../Type/ArrayType";
import { BaseType } from "../Type/BaseType";
import { InferType } from "../Type/InferType";
import { RestType } from "../Type/RestType";
import { TupleType } from "../Type/TupleType";

export class RestTypeNodeParser implements SubNodeParser {
public constructor(protected childNodeParser: NodeParser) {}
public supportsNode(node: ts.RestTypeNode): boolean {
return node.kind === ts.SyntaxKind.RestType;
}
public createType(node: ts.RestTypeNode, context: Context): BaseType {
return new RestType(this.childNodeParser.createType(node.type, context) as ArrayType);
return new RestType(this.childNodeParser.createType(node.type, context) as ArrayType | InferType | TupleType);
}
}
11 changes: 11 additions & 0 deletions src/Type/InferType.ts
@@ -0,0 +1,11 @@
import { BaseType } from "./BaseType";

export class InferType extends BaseType {
constructor(private id: string) {
super();
}

public getId(): string {
return this.id;
}
}
6 changes: 4 additions & 2 deletions src/Type/RestType.ts
@@ -1,8 +1,10 @@
import { ArrayType } from "./ArrayType";
import { BaseType } from "./BaseType";
import { InferType } from "./InferType";
import { TupleType } from "./TupleType";

export class RestType extends BaseType {
public constructor(private item: ArrayType, private title: string | null = null) {
public constructor(private item: ArrayType | InferType | TupleType, private title: string | null = null) {
super();
}

Expand All @@ -14,7 +16,7 @@ export class RestType extends BaseType {
return this.title;
}

public getType(): ArrayType {
public getType(): ArrayType | InferType | TupleType {
return this.item;
}
}
29 changes: 27 additions & 2 deletions src/Type/TupleType.ts
@@ -1,15 +1,40 @@
import { derefType } from "../Utils/derefType";
import { ArrayType } from "./ArrayType";
import { BaseType } from "./BaseType";
import { InferType } from "./InferType";
import { RestType } from "./RestType";

function normalize(types: Readonly<Array<BaseType | undefined>>): Array<BaseType | undefined> {
let normalized: Array<BaseType | undefined> = [];

for (const type of types) {
if (type instanceof RestType) {
const inner_type = derefType(type.getType()) as ArrayType | InferType | TupleType;
normalized = [
...normalized,
...(inner_type instanceof TupleType ? normalize(inner_type.getTypes()) : [type]),
];
} else {
normalized.push(type);
}
}
return normalized;
}

export class TupleType extends BaseType {
public constructor(private types: readonly (BaseType | undefined)[]) {
private types: Readonly<Array<BaseType | undefined>>;

public constructor(types: Readonly<Array<BaseType | undefined>>) {
super();

this.types = normalize(types);
}

public getId(): string {
return `[${this.types.map((item) => item?.getId() ?? "never").join(",")}]`;
}

public getTypes(): readonly (BaseType | undefined)[] {
public getTypes(): Readonly<Array<BaseType | undefined>> {
return this.types;
}
}
19 changes: 18 additions & 1 deletion src/TypeFormatter/TupleTypeFormatter.ts
@@ -1,5 +1,6 @@
import { Definition } from "../Schema/Definition";
import { SubTypeFormatter } from "../SubTypeFormatter";
import { ArrayType } from "../Type/ArrayType";
import { BaseType } from "../Type/BaseType";
import { OptionalType } from "../Type/OptionalType";
import { RestType } from "../Type/RestType";
Expand All @@ -8,6 +9,21 @@ import { TypeFormatter } from "../TypeFormatter";
import { notUndefined } from "../Utils/notUndefined";
import { uniqueArray } from "../Utils/uniqueArray";

function uniformRestType(type: RestType, check_type: BaseType): boolean {
const inner = type.getType();
return (
(inner instanceof ArrayType && inner.getItem().getId() === check_type.getId()) ||
(inner instanceof TupleType &&
inner.getTypes().every((tuple_type) => {
if (tuple_type instanceof RestType) {
return uniformRestType(tuple_type, check_type);
} else {
return tuple_type?.getId() === check_type.getId();
}
}))
);
}

export class TupleTypeFormatter implements SubTypeFormatter {
public constructor(protected childTypeFormatter: TypeFormatter) {}

Expand All @@ -20,6 +36,7 @@ export class TupleTypeFormatter implements SubTypeFormatter {

const requiredElements = subTypes.filter((t) => !(t instanceof OptionalType) && !(t instanceof RestType));
const optionalElements = subTypes.filter((t): t is OptionalType => t instanceof OptionalType);
// NOTE: A maximum of one rest type is assumed.
const restType = subTypes.find((t): t is RestType => t instanceof RestType);
const firstItemType = requiredElements.length > 0 ? requiredElements[0] : optionalElements[0]?.getType();

Expand All @@ -32,7 +49,7 @@ export class TupleTypeFormatter implements SubTypeFormatter {
firstItemType &&
requiredElements.every((item) => item.getId() === firstItemType.getId()) &&
optionalElements.every((item) => item.getType().getId() === firstItemType.getId()) &&
(!restType || restType.getType().getItem().getId() === firstItemType.getId());
(!restType || uniformRestType(restType, firstItemType));

// If so, generate a simple array with minItems (and possibly maxItems) instead.
if (isUniformArray) {
Expand Down