Skip to content

Commit

Permalink
feat: support inferred types in conditionals (#1265)
Browse files Browse the repository at this point in the history
  • Loading branch information
daanboer committed Jul 31, 2022
1 parent aa63c0a commit 8547b12
Show file tree
Hide file tree
Showing 30 changed files with 595 additions and 33 deletions.
3 changes: 3 additions & 0 deletions factory/parser.ts
Expand Up @@ -22,6 +22,7 @@ import { FunctionNodeParser } from "../src/NodeParser/FunctionNodeParser";
import { FunctionParser } from "../src/NodeParser/FunctionParser";
import { HiddenNodeParser } from "../src/NodeParser/HiddenTypeNodeParser";
import { IndexedAccessTypeNodeParser } from "../src/NodeParser/IndexedAccessTypeNodeParser";
import { InferTypeNodeParser } from "../src/NodeParser/InferTypeNodeParser";
import { InterfaceAndClassNodeParser } from "../src/NodeParser/InterfaceAndClassNodeParser";
import { IntersectionNodeParser } from "../src/NodeParser/IntersectionNodeParser";
import { IntrinsicNodeParser } from "../src/NodeParser/IntrinsicNodeParser";
Expand Down Expand Up @@ -122,6 +123,8 @@ export function createParser(program: ts.Program, config: Config, augmentor?: Pa
.addNodeParser(new TypeReferenceNodeParser(typeChecker, chainNodeParser))
.addNodeParser(new ExpressionWithTypeArgumentsNodeParser(typeChecker, chainNodeParser))

.addNodeParser(new InferTypeNodeParser(typeChecker, chainNodeParser))

.addNodeParser(new IndexedAccessTypeNodeParser(chainNodeParser))
.addNodeParser(new TypeofNodeParser(typeChecker, chainNodeParser))
.addNodeParser(new MappedTypeNodeParser(chainNodeParser, mergedConfig.additionalProperties))
Expand Down
48 changes: 34 additions & 14 deletions src/NodeParser/ConditionalTypeNodeParser.ts
Expand Up @@ -6,6 +6,10 @@ import { isAssignableTo } from "../Utils/isAssignableTo";
import { narrowType } from "../Utils/narrowType";
import { UnionType } from "../Type/UnionType";

class CheckType {
constructor(public parameterName: string, public type: BaseType | undefined) {}
}

export class ConditionalTypeNodeParser implements SubNodeParser {
public constructor(protected typeChecker: ts.TypeChecker, protected childNodeParser: NodeParser) {}

Expand All @@ -18,22 +22,27 @@ export class ConditionalTypeNodeParser implements SubNodeParser {
const extendsType = this.childNodeParser.createType(node.extendsType, context);
const checkTypeParameterName = this.getTypeParameterName(node.checkType);

const inferMap = new Map();

// If check-type is not a type parameter then condition is very simple, no type narrowing needed
if (checkTypeParameterName == null) {
const result = isAssignableTo(extendsType, checkType);
return this.childNodeParser.createType(result ? node.trueType : node.falseType, context);
const result = isAssignableTo(extendsType, checkType, inferMap);
return this.childNodeParser.createType(
result ? node.trueType : node.falseType,
this.createSubContext(node, context, undefined, result ? inferMap : new Map())
);
}

// Narrow down check type for both condition branches
const trueCheckType = narrowType(checkType, (type) => isAssignableTo(extendsType, type));
const trueCheckType = narrowType(checkType, (type) => isAssignableTo(extendsType, type, inferMap));
const falseCheckType = narrowType(checkType, (type) => !isAssignableTo(extendsType, type));

// Follow the relevant branches and return the results from them
const results: BaseType[] = [];
if (trueCheckType !== undefined) {
const result = this.childNodeParser.createType(
node.trueType,
this.createSubContext(node, checkTypeParameterName, trueCheckType, context)
this.createSubContext(node, context, new CheckType(checkTypeParameterName, trueCheckType), inferMap)
);
if (result) {
results.push(result);
Expand All @@ -42,7 +51,7 @@ export class ConditionalTypeNodeParser implements SubNodeParser {
if (falseCheckType !== undefined) {
const result = this.childNodeParser.createType(
node.falseType,
this.createSubContext(node, checkTypeParameterName, falseCheckType, context)
this.createSubContext(node, context, new CheckType(checkTypeParameterName, falseCheckType))
);
if (result) {
results.push(result);
Expand Down Expand Up @@ -72,25 +81,36 @@ export class ConditionalTypeNodeParser implements SubNodeParser {
* the check-type is a type parameter which is then narrowed down by the extends-type.
*
* @param node - The reference node for the new context.
* @param checkTypeParameterName - The type parameter name of the check-type.
* @param narrowedCheckType - The narrowed down check type to use for the type parameter in sub parsers.
* @param checkType - An object containing the type parameter name of the check-type, and the narrowed
* down check type to use for the type parameter in sub parsers.
* @param inferMap - A map that links parameter names to their inferred types.
* @return The created sub context.
*/
protected createSubContext(
node: ts.ConditionalTypeNode,
checkTypeParameterName: string,
narrowedCheckType: BaseType,
parentContext: Context
parentContext: Context,
checkType?: CheckType,
inferMap: Map<string, BaseType> = new Map()
): Context {
const subContext = new Context(node);

// Set new narrowed type for check type parameter
subContext.pushParameter(checkTypeParameterName);
subContext.pushArgument(narrowedCheckType);
// Newly inferred types take precedence over check and parent types.
inferMap.forEach((value, key) => {
subContext.pushParameter(key);
subContext.pushArgument(value);
});

if (checkType !== undefined) {
// Set new narrowed type for check type parameter
if (!(checkType.parameterName in inferMap)) {
subContext.pushParameter(checkType.parameterName);
subContext.pushArgument(checkType.type);
}
}

// Copy all other type parameters from parent context
parentContext.getParameters().forEach((parentParameter) => {
if (parentParameter !== checkTypeParameterName) {
if (parentParameter !== checkType?.parameterName && !(parentParameter in inferMap)) {
subContext.pushParameter(parentParameter);
subContext.pushArgument(parentContext.getArgument(parentParameter));
}
Expand Down
17 changes: 17 additions & 0 deletions src/NodeParser/InferTypeNodeParser.ts
@@ -0,0 +1,17 @@
import ts from "typescript";
import { Context, NodeParser } from "../NodeParser";
import { SubNodeParser } from "../SubNodeParser";
import { BaseType } from "../Type/BaseType";
import { InferType } from "../Type/InferType";

export class InferTypeNodeParser implements SubNodeParser {
public constructor(protected typeChecker: ts.TypeChecker, protected childNodeParser: NodeParser) {}

public supportsNode(node: ts.InferTypeNode): boolean {
return node.kind === ts.SyntaxKind.InferType;
}

public createType(node: ts.InferTypeNode, _context: Context): BaseType | undefined {
return new InferType(node.typeParameter.name.escapedText.toString());
}
}
2 changes: 1 addition & 1 deletion src/NodeParser/MappedTypeNodeParser.ts
Expand Up @@ -87,7 +87,7 @@ export class MappedTypeNodeParser implements SubNodeParser {
protected getProperties(node: ts.MappedTypeNode, keyListType: UnionType, context: Context): ObjectProperty[] {
return keyListType
.getTypes()
.filter((type) => type instanceof LiteralType)
.filter((type): type is LiteralType => type instanceof LiteralType)
.reduce((result: ObjectProperty[], key: LiteralType) => {
const namedKey = this.mapKey(node, key, context);
const propertyType = this.childNodeParser.createType(
Expand Down
4 changes: 3 additions & 1 deletion src/NodeParser/RestTypeNodeParser.ts
Expand Up @@ -3,14 +3,16 @@ import { Context, NodeParser } from "../NodeParser";
import { SubNodeParser } from "../SubNodeParser";
import { ArrayType } from "../Type/ArrayType";
import { BaseType } from "../Type/BaseType";
import { InferType } from "../Type/InferType";
import { RestType } from "../Type/RestType";
import { TupleType } from "../Type/TupleType";

export class RestTypeNodeParser implements SubNodeParser {
public constructor(protected childNodeParser: NodeParser) {}
public supportsNode(node: ts.RestTypeNode): boolean {
return node.kind === ts.SyntaxKind.RestType;
}
public createType(node: ts.RestTypeNode, context: Context): BaseType {
return new RestType(this.childNodeParser.createType(node.type, context) as ArrayType);
return new RestType(this.childNodeParser.createType(node.type, context) as ArrayType | InferType | TupleType);
}
}
11 changes: 11 additions & 0 deletions src/Type/InferType.ts
@@ -0,0 +1,11 @@
import { BaseType } from "./BaseType";

export class InferType extends BaseType {
constructor(private id: string) {
super();
}

public getId(): string {
return this.id;
}
}
6 changes: 4 additions & 2 deletions src/Type/RestType.ts
@@ -1,8 +1,10 @@
import { ArrayType } from "./ArrayType";
import { BaseType } from "./BaseType";
import { InferType } from "./InferType";
import { TupleType } from "./TupleType";

export class RestType extends BaseType {
public constructor(private item: ArrayType, private title: string | null = null) {
public constructor(private item: ArrayType | InferType | TupleType, private title: string | null = null) {
super();
}

Expand All @@ -14,7 +16,7 @@ export class RestType extends BaseType {
return this.title;
}

public getType(): ArrayType {
public getType(): ArrayType | InferType | TupleType {
return this.item;
}
}
29 changes: 27 additions & 2 deletions src/Type/TupleType.ts
@@ -1,15 +1,40 @@
import { derefType } from "../Utils/derefType";
import { ArrayType } from "./ArrayType";
import { BaseType } from "./BaseType";
import { InferType } from "./InferType";
import { RestType } from "./RestType";

function normalize(types: Readonly<Array<BaseType | undefined>>): Array<BaseType | undefined> {
let normalized: Array<BaseType | undefined> = [];

for (const type of types) {
if (type instanceof RestType) {
const inner_type = derefType(type.getType()) as ArrayType | InferType | TupleType;
normalized = [
...normalized,
...(inner_type instanceof TupleType ? normalize(inner_type.getTypes()) : [type]),
];
} else {
normalized.push(type);
}
}
return normalized;
}

export class TupleType extends BaseType {
public constructor(private types: readonly (BaseType | undefined)[]) {
private types: Readonly<Array<BaseType | undefined>>;

public constructor(types: Readonly<Array<BaseType | undefined>>) {
super();

this.types = normalize(types);
}

public getId(): string {
return `[${this.types.map((item) => item?.getId() ?? "never").join(",")}]`;
}

public getTypes(): readonly (BaseType | undefined)[] {
public getTypes(): Readonly<Array<BaseType | undefined>> {
return this.types;
}
}
19 changes: 18 additions & 1 deletion src/TypeFormatter/TupleTypeFormatter.ts
@@ -1,5 +1,6 @@
import { Definition } from "../Schema/Definition";
import { SubTypeFormatter } from "../SubTypeFormatter";
import { ArrayType } from "../Type/ArrayType";
import { BaseType } from "../Type/BaseType";
import { OptionalType } from "../Type/OptionalType";
import { RestType } from "../Type/RestType";
Expand All @@ -8,6 +9,21 @@ import { TypeFormatter } from "../TypeFormatter";
import { notUndefined } from "../Utils/notUndefined";
import { uniqueArray } from "../Utils/uniqueArray";

function uniformRestType(type: RestType, check_type: BaseType): boolean {
const inner = type.getType();
return (
(inner instanceof ArrayType && inner.getItem().getId() === check_type.getId()) ||
(inner instanceof TupleType &&
inner.getTypes().every((tuple_type) => {
if (tuple_type instanceof RestType) {
return uniformRestType(tuple_type, check_type);
} else {
return tuple_type?.getId() === check_type.getId();
}
}))
);
}

export class TupleTypeFormatter implements SubTypeFormatter {
public constructor(protected childTypeFormatter: TypeFormatter) {}

Expand All @@ -20,6 +36,7 @@ export class TupleTypeFormatter implements SubTypeFormatter {

const requiredElements = subTypes.filter((t) => !(t instanceof OptionalType) && !(t instanceof RestType));
const optionalElements = subTypes.filter((t): t is OptionalType => t instanceof OptionalType);
// NOTE: A maximum of one rest type is assumed.
const restType = subTypes.find((t): t is RestType => t instanceof RestType);
const firstItemType = requiredElements.length > 0 ? requiredElements[0] : optionalElements[0]?.getType();

Expand All @@ -32,7 +49,7 @@ export class TupleTypeFormatter implements SubTypeFormatter {
firstItemType &&
requiredElements.every((item) => item.getId() === firstItemType.getId()) &&
optionalElements.every((item) => item.getType().getId() === firstItemType.getId()) &&
(!restType || restType.getType().getItem().getId() === firstItemType.getId());
(!restType || uniformRestType(restType, firstItemType));

// If so, generate a simple array with minItems (and possibly maxItems) instead.
if (isUniformArray) {
Expand Down

0 comments on commit 8547b12

Please sign in to comment.