diff --git a/packages/eslint-plugin/src/rules/return-await.ts b/packages/eslint-plugin/src/rules/return-await.ts index 2c47272cdaf..17bc9c0e54e 100644 --- a/packages/eslint-plugin/src/rules/return-await.ts +++ b/packages/eslint-plugin/src/rules/return-await.ts @@ -78,21 +78,14 @@ export default util.createRule({ function removeAwait( fixer: TSESLint.RuleFixer, - node: TSESTree.ReturnStatement | TSESTree.ArrowFunctionExpression, + node: TSESTree.Expression, ): TSESLint.RuleFix | null { - const awaitNode = - node.type === AST_NODE_TYPES.ReturnStatement - ? node.argument - : node.body; // Should always be an await node; but let's be safe. - /* istanbul ignore if */ if (!util.isAwaitExpression(awaitNode)) { + /* istanbul ignore if */ if (!util.isAwaitExpression(node)) { return null; } - const awaitToken = sourceCode.getFirstToken( - awaitNode, - util.isAwaitKeyword, - ); + const awaitToken = sourceCode.getFirstToken(node, util.isAwaitKeyword); // Should always be the case; but let's be safe. /* istanbul ignore if */ if (!awaitToken) { return null; @@ -113,24 +106,12 @@ export default util.createRule({ function insertAwait( fixer: TSESLint.RuleFixer, - node: TSESTree.ReturnStatement | TSESTree.ArrowFunctionExpression, + node: TSESTree.Expression, ): TSESLint.RuleFix | null { - const targetNode = - node.type === AST_NODE_TYPES.ReturnStatement - ? node.argument - : node.body; - // There should always be a target node; but let's be safe. - /* istanbul ignore if */ if (!targetNode) { - return null; - } - - return fixer.insertTextBefore(targetNode, 'await '); + return fixer.insertTextBefore(node, 'await '); } - function test( - node: TSESTree.ReturnStatement | TSESTree.ArrowFunctionExpression, - expression: ts.Node, - ): void { + function test(node: TSESTree.Expression, expression: ts.Node): void { let child: ts.Node; const isAwait = tsutils.isAwaitExpression(expression); @@ -201,6 +182,18 @@ export default util.createRule({ } } + function findPossiblyReturnedNodes( + node: TSESTree.Expression, + ): TSESTree.Expression[] { + if (node.type === AST_NODE_TYPES.ConditionalExpression) { + return [ + ...findPossiblyReturnedNodes(node.alternate), + ...findPossiblyReturnedNodes(node.consequent), + ]; + } + return [node]; + } + return { FunctionDeclaration: enterFunction, FunctionExpression: enterFunction, @@ -210,27 +203,20 @@ export default util.createRule({ node: TSESTree.ArrowFunctionExpression, ): void { if (node.body.type !== AST_NODE_TYPES.BlockStatement) { - const expression = parserServices.esTreeNodeToTSNodeMap.get( - node.body, - ); - - test(node, expression); + findPossiblyReturnedNodes(node.body).forEach(node => { + const tsNode = parserServices.esTreeNodeToTSNodeMap.get(node); + test(node, tsNode); + }); } }, ReturnStatement(node): void { - if (!scopeInfo || !scopeInfo.hasAsync) { + if (!scopeInfo || !scopeInfo.hasAsync || !node.argument) { return; } - - const originalNode = parserServices.esTreeNodeToTSNodeMap.get(node); - - const { expression } = originalNode; - - if (!expression) { - return; - } - - test(node, expression); + findPossiblyReturnedNodes(node.argument).forEach(node => { + const tsNode = parserServices.esTreeNodeToTSNodeMap.get(node); + test(node, tsNode); + }); }, }; }, diff --git a/packages/eslint-plugin/tests/rules/return-await.test.ts b/packages/eslint-plugin/tests/rules/return-await.test.ts index 5a98d22c917..671374fc9bd 100644 --- a/packages/eslint-plugin/tests/rules/return-await.test.ts +++ b/packages/eslint-plugin/tests/rules/return-await.test.ts @@ -1,5 +1,5 @@ import rule from '../../src/rules/return-await'; -import { getFixturesRootDir, RuleTester } from '../RuleTester'; +import { getFixturesRootDir, RuleTester, noFormat } from '../RuleTester'; const rootDir = getFixturesRootDir(); @@ -600,5 +600,156 @@ ruleTester.run('return-await', rule, { }, ], }, + { + options: ['always'], + code: noFormat` +async function foo() {} +async function bar() {} +async function baz() {} +async function qux() {} +async function buzz() { + return (await foo()) ? bar() : baz(); +} + `, + output: noFormat` +async function foo() {} +async function bar() {} +async function baz() {} +async function qux() {} +async function buzz() { + return (await foo()) ? await bar() : await baz(); +} + `, + errors: [ + { + line: 7, + messageId: 'requiredPromiseAwait', + }, + { + line: 7, + messageId: 'requiredPromiseAwait', + }, + ], + }, + { + options: ['always'], + code: noFormat` +async function foo() {} +async function bar() {} +async function baz() {} +async function qux() {} +async function buzz() { + return (await foo()) + ? ( + bar ? bar() : baz() + ) : baz ? baz() : bar(); +} + `, + output: noFormat` +async function foo() {} +async function bar() {} +async function baz() {} +async function qux() {} +async function buzz() { + return (await foo()) + ? ( + bar ? await bar() : await baz() + ) : baz ? await baz() : await bar(); +} + `, + errors: [ + { + line: 9, + messageId: 'requiredPromiseAwait', + }, + { + line: 9, + messageId: 'requiredPromiseAwait', + }, + { + line: 10, + messageId: 'requiredPromiseAwait', + }, + { + line: 10, + messageId: 'requiredPromiseAwait', + }, + ], + }, + { + options: ['always'], + code: ` +async function foo() {} +async function bar() {} +async function buzz() { + return (await foo()) ? await 1 : bar(); +} + `, + output: ` +async function foo() {} +async function bar() {} +async function buzz() { + return (await foo()) ? 1 : await bar(); +} + `, + errors: [ + { + line: 5, + messageId: 'nonPromiseAwait', + }, + { + line: 5, + messageId: 'requiredPromiseAwait', + }, + ], + }, + { + options: ['always'], + code: ` +async function foo() {} +async function bar() {} +async function baz() {} +const buzz = async () => ((await foo()) ? bar() : baz()); + `, + output: ` +async function foo() {} +async function bar() {} +async function baz() {} +const buzz = async () => ((await foo()) ? await bar() : await baz()); + `, + errors: [ + { + line: 5, + messageId: 'requiredPromiseAwait', + }, + { + line: 5, + messageId: 'requiredPromiseAwait', + }, + ], + }, + { + options: ['always'], + code: ` +async function foo() {} +async function bar() {} +const buzz = async () => ((await foo()) ? await 1 : bar()); + `, + output: ` +async function foo() {} +async function bar() {} +const buzz = async () => ((await foo()) ? 1 : await bar()); + `, + errors: [ + { + line: 4, + messageId: 'nonPromiseAwait', + }, + { + line: 4, + messageId: 'requiredPromiseAwait', + }, + ], + }, ], });