diff --git a/src/services/codefixes/inferFromUsage.ts b/src/services/codefixes/inferFromUsage.ts index 417ed5c625cfb..2eb9eaca7b797 100644 --- a/src/services/codefixes/inferFromUsage.ts +++ b/src/services/codefixes/inferFromUsage.ts @@ -205,12 +205,7 @@ namespace ts.codefix { return; } - const references = inferFunctionReferencesFromUsage(containingFunction, sourceFile, program, cancellationToken); - const parameterInferences = InferFromReference.inferTypeForParametersFromReferences(references, containingFunction, program, cancellationToken) || - containingFunction.parameters.map(p => ({ - declaration: p, - type: isIdentifier(p.name) ? inferTypeForVariableFromUsage(p.name, program, cancellationToken) : program.getTypeChecker().getAnyType() - })); + const parameterInferences = inferTypeForParametersFromUsage(containingFunction, sourceFile, program, cancellationToken); Debug.assert(containingFunction.parameters.length === parameterInferences.length, "Parameter count and inference count should match"); if (isInJSFile(containingFunction)) { @@ -229,16 +224,11 @@ namespace ts.codefix { } function annotateThis(changes: textChanges.ChangeTracker, sourceFile: SourceFile, containingFunction: textChanges.ThisTypeAnnotatable, program: Program, host: LanguageServiceHost, cancellationToken: CancellationToken) { - const references = inferFunctionReferencesFromUsage(containingFunction, sourceFile, program, cancellationToken); - if (!references) { + const references = getFunctionReferences(containingFunction, sourceFile, program, cancellationToken); + if (!references || !references.length) { return; } - - const thisInference = InferFromReference.inferTypeForThisFromReferences(references, program, cancellationToken); - if (!thisInference) { - return; - } - + const thisInference = inferTypeFromReferences(program, references, cancellationToken).thisParameter(); const typeNode = getTypeNodeIfAccessible(thisInference, containingFunction, program, host); if (!typeNode) { return; @@ -357,12 +347,19 @@ namespace ts.codefix { function inferTypeForVariableFromUsage(token: Identifier, program: Program, cancellationToken: CancellationToken): Type { const references = getReferences(token, program, cancellationToken); - const checker = program.getTypeChecker(); - const types = InferFromReference.inferTypesFromReferences(references, checker, cancellationToken); - return InferFromReference.unifyFromUsage(types, checker); + return inferTypeFromReferences(program, references, cancellationToken).single(); } - function inferFunctionReferencesFromUsage(containingFunction: FunctionLike, sourceFile: SourceFile, program: Program, cancellationToken: CancellationToken): ReadonlyArray | undefined { + function inferTypeForParametersFromUsage(func: SignatureDeclaration, sourceFile: SourceFile, program: Program, cancellationToken: CancellationToken) { + const references = getFunctionReferences(func, sourceFile, program, cancellationToken); + return references && inferTypeFromReferences(program, references, cancellationToken).parameters(func) || + func.parameters.map(p => ({ + declaration: p, + type: isIdentifier(p.name) ? inferTypeForVariableFromUsage(p.name, program, cancellationToken) : program.getTypeChecker().getAnyType() + })); + } + + function getFunctionReferences(containingFunction: FunctionLike, sourceFile: SourceFile, program: Program, cancellationToken: CancellationToken): ReadonlyArray | undefined { let searchToken; switch (containingFunction.kind) { case SyntaxKind.Constructor: @@ -394,7 +391,14 @@ namespace ts.codefix { readonly isOptional?: boolean; } - namespace InferFromReference { + function inferTypeFromReferences(program: Program, references: ReadonlyArray, cancellationToken: CancellationToken) { + const checker = program.getTypeChecker(); + return { + single, + parameters, + thisParameter, + }; + interface CallUsage { argumentTypes: Type[]; returnType: Usage; @@ -415,25 +419,19 @@ namespace ts.codefix { candidateThisTypes?: Type[]; } - export function inferTypesFromReferences(references: ReadonlyArray, checker: TypeChecker, cancellationToken: CancellationToken): Type[] { - const usage: Usage = {}; - for (const reference of references) { - cancellationToken.throwIfCancellationRequested(); - calculateUsageOfNode(reference, checker, usage); - } - return inferFromUsage(usage, checker); + function single(): Type { + return unifyFromUsage(inferTypesFromReferencesSingle(references)); } - export function inferTypeForParametersFromReferences(references: ReadonlyArray | undefined, declaration: FunctionLike, program: Program, cancellationToken: CancellationToken): ParameterInference[] | undefined { - if (references === undefined || references.length === 0 || !declaration.parameters) { + function parameters(declaration: FunctionLike): ParameterInference[] | undefined { + if (references.length === 0 || !declaration.parameters) { return undefined; } - const checker = program.getTypeChecker(); const usage: Usage = {}; for (const reference of references) { cancellationToken.throwIfCancellationRequested(); - calculateUsageOfNode(reference, checker, usage); + calculateUsageOfNode(reference, usage); } const calls = [...usage.constructs || [], ...usage.calls || []]; return declaration.parameters.map((parameter, parameterIndex): ParameterInference => { @@ -455,10 +453,10 @@ namespace ts.codefix { } } if (isIdentifier(parameter.name)) { - const inferred = inferTypesFromReferences(getReferences(parameter.name, program, cancellationToken), checker, cancellationToken); + const inferred = inferTypesFromReferencesSingle(getReferences(parameter.name, program, cancellationToken)); types.push(...(isRest ? mapDefined(inferred, checker.getElementTypeOfArrayType) : inferred)); } - const type = unifyFromUsage(types, checker); + const type = unifyFromUsage(types); return { type: isRest ? checker.createArrayType(type) : type, isOptional: isOptional && !isRest, @@ -467,23 +465,26 @@ namespace ts.codefix { }); } - export function inferTypeForThisFromReferences(references: ReadonlyArray, program: Program, cancellationToken: CancellationToken) { - if (references.length === 0) { - return undefined; + function thisParameter() { + const usage: Usage = {}; + for (const reference of references) { + cancellationToken.throwIfCancellationRequested(); + calculateUsageOfNode(reference, usage); } - const checker = program.getTypeChecker(); - const usage: Usage = {}; + return unifyFromUsage(usage.candidateThisTypes || emptyArray); + } + function inferTypesFromReferencesSingle(references: readonly Identifier[]): Type[] { + const usage: Usage = {}; for (const reference of references) { cancellationToken.throwIfCancellationRequested(); - calculateUsageOfNode(reference, checker, usage); + calculateUsageOfNode(reference, usage); } - - return unifyFromUsage(usage.candidateThisTypes || emptyArray, checker); + return inferFromUsage(usage); } - function calculateUsageOfNode(node: Expression, checker: TypeChecker, usage: Usage): void { + function calculateUsageOfNode(node: Expression, usage: Usage): void { while (isRightSideOfQualifiedNameOrPropertyAccess(node)) { node = node.parent; } @@ -496,33 +497,33 @@ namespace ts.codefix { inferTypeFromPrefixUnaryExpression(node.parent, usage); break; case SyntaxKind.BinaryExpression: - inferTypeFromBinaryExpression(node, node.parent, checker, usage); + inferTypeFromBinaryExpression(node, node.parent, usage); break; case SyntaxKind.CaseClause: case SyntaxKind.DefaultClause: - inferTypeFromSwitchStatementLabel(node.parent, checker, usage); + inferTypeFromSwitchStatementLabel(node.parent, usage); break; case SyntaxKind.CallExpression: case SyntaxKind.NewExpression: if ((node.parent).expression === node) { - inferTypeFromCallExpression(node.parent, checker, usage); + inferTypeFromCallExpression(node.parent, usage); } else { - inferTypeFromContextualType(node, checker, usage); + inferTypeFromContextualType(node, usage); } break; case SyntaxKind.PropertyAccessExpression: - inferTypeFromPropertyAccessExpression(node.parent, checker, usage); + inferTypeFromPropertyAccessExpression(node.parent, usage); break; case SyntaxKind.ElementAccessExpression: - inferTypeFromPropertyElementExpression(node.parent, node, checker, usage); + inferTypeFromPropertyElementExpression(node.parent, node, usage); break; case SyntaxKind.PropertyAssignment: case SyntaxKind.ShorthandPropertyAssignment: - inferTypeFromPropertyAssignment(node.parent, checker, usage); + inferTypeFromPropertyAssignment(node.parent, usage); break; case SyntaxKind.PropertyDeclaration: - inferTypeFromPropertyDeclaration(node.parent, checker, usage); + inferTypeFromPropertyDeclaration(node.parent, usage); break; case SyntaxKind.VariableDeclaration: { const { name, initializer } = node.parent as VariableDeclaration; @@ -535,11 +536,11 @@ namespace ts.codefix { } // falls through default: - return inferTypeFromContextualType(node, checker, usage); + return inferTypeFromContextualType(node, usage); } } - function inferTypeFromContextualType(node: Expression, checker: TypeChecker, usage: Usage): void { + function inferTypeFromContextualType(node: Expression, usage: Usage): void { if (isExpressionNode(node)) { addCandidateType(usage, checker.getContextualType(node)); } @@ -563,7 +564,7 @@ namespace ts.codefix { } } - function inferTypeFromBinaryExpression(node: Expression, parent: BinaryExpression, checker: TypeChecker, usage: Usage): void { + function inferTypeFromBinaryExpression(node: Expression, parent: BinaryExpression, usage: Usage): void { switch (parent.operatorToken.kind) { // ExponentiationOperator case SyntaxKind.AsteriskAsteriskToken: @@ -663,11 +664,11 @@ namespace ts.codefix { } } - function inferTypeFromSwitchStatementLabel(parent: CaseOrDefaultClause, checker: TypeChecker, usage: Usage): void { + function inferTypeFromSwitchStatementLabel(parent: CaseOrDefaultClause, usage: Usage): void { addCandidateType(usage, checker.getTypeAtLocation(parent.parent.parent.expression)); } - function inferTypeFromCallExpression(parent: CallExpression | NewExpression, checker: TypeChecker, usage: Usage): void { + function inferTypeFromCallExpression(parent: CallExpression | NewExpression, usage: Usage): void { const call: CallUsage = { argumentTypes: [], returnType: {} @@ -679,7 +680,7 @@ namespace ts.codefix { } } - calculateUsageOfNode(parent, checker, call.returnType); + calculateUsageOfNode(parent, call.returnType); if (parent.kind === SyntaxKind.CallExpression) { (usage.calls || (usage.calls = [])).push(call); } @@ -688,17 +689,17 @@ namespace ts.codefix { } } - function inferTypeFromPropertyAccessExpression(parent: PropertyAccessExpression, checker: TypeChecker, usage: Usage): void { + function inferTypeFromPropertyAccessExpression(parent: PropertyAccessExpression, usage: Usage): void { const name = escapeLeadingUnderscores(parent.name.text); if (!usage.properties) { usage.properties = createUnderscoreEscapedMap(); } const propertyUsage = usage.properties.get(name) || { }; - calculateUsageOfNode(parent, checker, propertyUsage); + calculateUsageOfNode(parent, propertyUsage); usage.properties.set(name, propertyUsage); } - function inferTypeFromPropertyElementExpression(parent: ElementAccessExpression, node: Expression, checker: TypeChecker, usage: Usage): void { + function inferTypeFromPropertyElementExpression(parent: ElementAccessExpression, node: Expression, usage: Usage): void { if (node === parent.argumentExpression) { usage.isNumberOrString = true; return; @@ -706,7 +707,7 @@ namespace ts.codefix { else { const indexType = checker.getTypeAtLocation(parent.argumentExpression); const indexUsage = {}; - calculateUsageOfNode(parent, checker, indexUsage); + calculateUsageOfNode(parent, indexUsage); if (indexType.flags & TypeFlags.NumberLike) { usage.numberIndex = indexUsage; } @@ -716,14 +717,14 @@ namespace ts.codefix { } } - function inferTypeFromPropertyAssignment(assignment: PropertyAssignment | ShorthandPropertyAssignment, checker: TypeChecker, usage: Usage) { + function inferTypeFromPropertyAssignment(assignment: PropertyAssignment | ShorthandPropertyAssignment, usage: Usage) { const nodeWithRealType = isVariableDeclaration(assignment.parent.parent) ? assignment.parent.parent : assignment.parent; addCandidateThisType(usage, checker.getTypeAtLocation(nodeWithRealType)); } - function inferTypeFromPropertyDeclaration(declaration: PropertyDeclaration, checker: TypeChecker, usage: Usage) { + function inferTypeFromPropertyDeclaration(declaration: PropertyDeclaration, usage: Usage) { addCandidateThisType(usage, checker.getTypeAtLocation(declaration.parent)); } @@ -745,7 +746,7 @@ namespace ts.codefix { return inferences.filter(i => toRemove.every(f => !f(i))); } - export function unifyFromUsage(inferences: ReadonlyArray, checker: TypeChecker, fallback = checker.getAnyType()): Type { + function unifyFromUsage(inferences: ReadonlyArray, fallback = checker.getAnyType()): Type { if (!inferences.length) return fallback; // 1. string or number individually override string | number @@ -769,12 +770,12 @@ namespace ts.codefix { const anons = good.filter(i => checker.getObjectFlags(i) & ObjectFlags.Anonymous) as AnonymousType[]; if (anons.length) { good = good.filter(i => !(checker.getObjectFlags(i) & ObjectFlags.Anonymous)); - good.push(unifyAnonymousTypes(anons, checker)); + good.push(unifyAnonymousTypes(anons)); } return checker.getWidenedType(checker.getUnionType(good)); } - function unifyAnonymousTypes(anons: AnonymousType[], checker: TypeChecker) { + function unifyAnonymousTypes(anons: AnonymousType[]) { if (anons.length === 1) { return anons[0]; } @@ -815,7 +816,7 @@ namespace ts.codefix { numberIndices.length ? checker.createIndexInfo(checker.getUnionType(numberIndices), numberIndexReadonly) : undefined); } - function inferFromUsage(usage: Usage, checker: TypeChecker) { + function inferFromUsage(usage: Usage) { const types = []; if (usage.isNumber) { @@ -831,12 +832,12 @@ namespace ts.codefix { types.push(...(usage.candidateTypes || []).map(t => checker.getBaseTypeOfLiteralType(t))); if (usage.properties && hasCalls(usage.properties.get("then" as __String))) { - const paramType = getParameterTypeFromCalls(0, usage.properties.get("then" as __String)!.calls!, /*isRestParameter*/ false, checker)!; // TODO: GH#18217 - const types = paramType.getCallSignatures().map(c => c.getReturnType()); + const paramType = getParameterTypeFromCalls(0, usage.properties.get("then" as __String)!.calls!, /*isRestParameter*/ false)!; // TODO: GH#18217 + const types = paramType.getCallSignatures().map(sig => sig.getReturnType()); types.push(checker.createPromiseType(types.length ? checker.getUnionType(types, UnionReduction.Subtype) : checker.getAnyType())); } else if (usage.properties && hasCalls(usage.properties.get("push" as __String))) { - types.push(checker.createArrayType(getParameterTypeFromCalls(0, usage.properties.get("push" as __String)!.calls!, /*isRestParameter*/ false, checker)!)); + types.push(checker.createArrayType(getParameterTypeFromCalls(0, usage.properties.get("push" as __String)!.calls!, /*isRestParameter*/ false)!)); } if (usage.numberIndex) { @@ -858,13 +859,13 @@ namespace ts.codefix { if (usage.calls) { for (const call of usage.calls) { - callSignatures.push(getSignatureFromCall(call, checker)); + callSignatures.push(getSignatureFromCall(call)); } } if (usage.constructs) { for (const construct of usage.constructs) { - constructSignatures.push(getSignatureFromCall(construct, checker)); + constructSignatures.push(getSignatureFromCall(construct)); } } @@ -877,11 +878,11 @@ namespace ts.codefix { return types; function recur(innerUsage: Usage): Type { - return unifyFromUsage(inferFromUsage(innerUsage, checker), checker); + return unifyFromUsage(inferFromUsage(innerUsage)); } } - function getParameterTypeFromCalls(parameterIndex: number, calls: CallUsage[], isRestParameter: boolean, checker: TypeChecker) { + function getParameterTypeFromCalls(parameterIndex: number, calls: CallUsage[], isRestParameter: boolean) { let types: Type[] = []; if (calls) { for (const call of calls) { @@ -903,14 +904,14 @@ namespace ts.codefix { return undefined; } - function getSignatureFromCall(call: CallUsage, checker: TypeChecker): Signature { + function getSignatureFromCall(call: CallUsage): Signature { const parameters: Symbol[] = []; for (let i = 0; i < call.argumentTypes.length; i++) { const symbol = checker.createSymbol(SymbolFlags.FunctionScopedVariable, escapeLeadingUnderscores(`arg${i}`)); symbol.type = checker.getWidenedType(checker.getBaseTypeOfLiteralType(call.argumentTypes[i])); parameters.push(symbol); } - const returnType = unifyFromUsage(inferFromUsage(call.returnType, checker), checker, checker.getVoidType()); + const returnType = unifyFromUsage(inferFromUsage(call.returnType), checker.getVoidType()); // TODO: GH#18217 return checker.createSignature(/*declaration*/ undefined!, /*typeParameters*/ undefined, /*thisParameter*/ undefined, parameters, returnType, /*typePredicate*/ undefined, call.argumentTypes.length, /*hasRestParameter*/ false, /*hasLiteralTypes*/ false); }