diff --git a/packages/eslint-plugin/src/rules/require-await.ts b/packages/eslint-plugin/src/rules/require-await.ts index a1e92e8f91d..f0935b202ea 100644 --- a/packages/eslint-plugin/src/rules/require-await.ts +++ b/packages/eslint-plugin/src/rules/require-await.ts @@ -115,16 +115,22 @@ export default util.createRule({ if (node?.argument?.type === AST_NODE_TYPES.Literal) { // making this `false` as for literals we don't need to check the definition // eg : async function* run() { yield* 1 } - scopeInfo.isAsyncYield = false; + scopeInfo.isAsyncYield ||= false; } const tsNode = parserServices.esTreeNodeToTSNodeMap.get(node?.argument); const type = checker.getTypeAtLocation(tsNode); - const symbol = type.getSymbol(); - - // async function* test1() {yield* asyncGenerator() } - if (symbol?.getName() === 'AsyncGenerator') { - scopeInfo.isAsyncYield = true; + const typesToCheck = expandUnionOrIntersectionType(type); + for (const type of typesToCheck) { + const asyncIterator = tsutils.getWellKnownSymbolPropertyOfType( + type, + 'asyncIterator', + checker, + ); + if (asyncIterator !== undefined) { + scopeInfo.isAsyncYield = true; + break; + } } } @@ -230,3 +236,10 @@ function getFunctionHeadLoc( end, }; } + +function expandUnionOrIntersectionType(type: ts.Type): ts.Type[] { + if (type.isUnionOrIntersection()) { + return type.types.flatMap(expandUnionOrIntersectionType); + } + return [type]; +} diff --git a/packages/eslint-plugin/tests/rules/require-await.test.ts b/packages/eslint-plugin/tests/rules/require-await.test.ts index ec2b178ecbb..f385bc8402a 100644 --- a/packages/eslint-plugin/tests/rules/require-await.test.ts +++ b/packages/eslint-plugin/tests/rules/require-await.test.ts @@ -159,6 +159,57 @@ async function* asyncGenerator() { } async function* test1() { yield* asyncGenerator(); +} + `, + ` +async function* asyncGenerator() { + await Promise.resolve(); + yield 1; +} +async function* test1() { + yield* asyncGenerator(); + yield* 2; +} + `, + ` +async function* test(source: AsyncIterable) { + yield* source; +} + `, + ` +async function* test(source: Iterable & AsyncIterable) { + yield* source; +} + `, + ` +async function* test(source: Iterable | AsyncIterable) { + yield* source; +} + `, + ` +type MyType = { + [Symbol.iterator](): Iterator; + [Symbol.asyncIterator](): AsyncIterator; +}; +async function* test(source: MyType) { + yield* source; +} + `, + ` +type MyType = { + [Symbol.asyncIterator]: () => AsyncIterator; +}; +async function* test(source: MyType) { + yield* source; +} + `, + ` +type MyFunctionType = () => AsyncIterator; +type MyType = { + [Symbol.asyncIterator]: MyFunctionType; +}; +async function* test(source: MyType) { + yield* source; } `, 'const foo: () => void = async function* () {};', @@ -294,6 +345,41 @@ async function* asyncGenerator() { }, { code: ` +async function* asyncGenerator(source: Iterable) { + yield* source; +} + `, + errors: [ + { + messageId: 'missingAwait', + data: { + name: "Async generator function 'asyncGenerator'", + }, + }, + ], + }, + { + code: ` +function isAsyncIterable(value: unknown): value is AsyncIterable { + return true; +} +async function* asyncGenerator(source: Iterable | AsyncIterable) { + if (!isAsyncIterable(source)) { + yield* source; + } +} + `, + errors: [ + { + messageId: 'missingAwait', + data: { + name: "Async generator function 'asyncGenerator'", + }, + }, + ], + }, + { + code: ` function* syncGenerator() { yield 1; }