From 97f3072cebbbe7d99fefa5aaba66a12494d96a77 Mon Sep 17 00:00:00 2001 From: Robert Nagy Date: Thu, 17 Jun 2021 22:25:34 +0200 Subject: [PATCH] stream: add signal support to pipeline generators Generators in pipeline must be able to be aborted or pipeline can deadlock. PR-URL: https://github.com/nodejs/node/pull/39067 Reviewed-By: Matteo Collina Reviewed-By: Benjamin Gruenbaum Reviewed-By: James M Snell --- doc/api/stream.md | 51 +++++++++++++++++++++++---- lib/internal/streams/compose.js | 2 +- lib/internal/streams/duplexify.js | 18 +++++++--- lib/internal/streams/pipeline.js | 40 ++++++++++++++++++--- lib/stream.js | 2 +- lib/stream/promises.js | 18 ++-------- test/parallel/test-stream-pipeline.js | 19 ++++++++++ 7 files changed, 117 insertions(+), 33 deletions(-) diff --git a/doc/api/stream.md b/doc/api/stream.md index a700eb9233d77d..30660e4e0a855b 100644 --- a/doc/api/stream.md +++ b/doc/api/stream.md @@ -1886,16 +1886,14 @@ const { pipeline } = require('stream/promises'); async function run() { const ac = new AbortController(); - const options = { - signal: ac.signal, - }; + const signal = ac.signal; setTimeout(() => ac.abort(), 1); await pipeline( fs.createReadStream('archive.tar'), zlib.createGzip(), fs.createWriteStream('archive.tar.gz'), - options, + { signal }, ); } @@ -1911,10 +1909,10 @@ const fs = require('fs'); async function run() { await pipeline( fs.createReadStream('lowercase.txt'), - async function* (source) { + async function* (source, signal) { source.setEncoding('utf8'); // Work with strings rather than `Buffer`s. for await (const chunk of source) { - yield chunk.toUpperCase(); + yield await processChunk(chunk, { signal }); } }, fs.createWriteStream('uppercase.txt') @@ -1925,6 +1923,28 @@ async function run() { run().catch(console.error); ``` +Remember to handle the `signal` argument passed into the async generator. +Especially in the case where the async generator is the source for the +pipeline (i.e. first argument) or the pipeline will never complete. + +```js +const { pipeline } = require('stream/promises'); +const fs = require('fs'); + +async function run() { + await pipeline( + async function * (signal) { + await someLongRunningfn({ signal }); + yield 'asd'; + }, + fs.createWriteStream('uppercase.txt') + ); + console.log('Pipeline succeeded.'); +} + +run().catch(console.error); +``` + `stream.pipeline()` will call `stream.destroy(err)` on all streams except: * `Readable` streams which have emitted `'end'` or `'close'`. * `Writable` streams which have emitted `'finish'` or `'close'`. @@ -3342,13 +3362,20 @@ the `Readable.from()` utility method: ```js const { Readable } = require('stream'); +const ac = new AbortController(); +const signal = ac.signal; + async function * generate() { yield 'a'; + await someLongRunningFn({ signal }); yield 'b'; yield 'c'; } const readable = Readable.from(generate()); +readable.on('close', () => { + ac.abort(); +}); readable.on('data', (chunk) => { console.log(chunk); @@ -3368,6 +3395,11 @@ const { pipeline: pipelinePromise } = require('stream/promises'); const writable = fs.createWriteStream('./file'); +const ac = new AbortController(); +const signal = ac.signal; + +const iterator = createIterator({ signal }); + // Callback Pattern pipeline(iterator, writable, (err, value) => { if (err) { @@ -3375,6 +3407,8 @@ pipeline(iterator, writable, (err, value) => { } else { console.log(value, 'value returned'); } +}).on('close', () => { + ac.abort(); }); // Promise Pattern @@ -3382,7 +3416,10 @@ pipelinePromise(iterator, writable) .then((value) => { console.log(value, 'value returned'); }) - .catch(console.error); + .catch((err) => { + console.error(err); + ac.abort(); + }); ``` diff --git a/lib/internal/streams/compose.js b/lib/internal/streams/compose.js index 04b6c6dcb0a53e..d11a372732caab 100644 --- a/lib/internal/streams/compose.js +++ b/lib/internal/streams/compose.js @@ -1,6 +1,6 @@ 'use strict'; -const pipeline = require('internal/streams/pipeline'); +const { pipeline } = require('internal/streams/pipeline'); const Duplex = require('internal/streams/duplex'); const { destroyer } = require('internal/streams/destroy'); const { diff --git a/lib/internal/streams/duplexify.js b/lib/internal/streams/duplexify.js index 448c909fd52c01..fea0e508411c88 100644 --- a/lib/internal/streams/duplexify.js +++ b/lib/internal/streams/duplexify.js @@ -26,6 +26,7 @@ const from = require('internal/streams/from'); const { isBlob, } = require('internal/blob'); +const { AbortController } = require('internal/abort_controller'); const { FunctionPrototypeCall @@ -81,14 +82,15 @@ module.exports = function duplexify(body, name) { // } if (typeof body === 'function') { - const { value, write, final } = fromAsyncGen(body); + const { value, write, final, destroy } = fromAsyncGen(body); if (isIterable(value)) { return from(Duplexify, value, { // TODO (ronag): highWaterMark? objectMode: true, write, - final + final, + destroy }); } @@ -123,7 +125,8 @@ module.exports = function duplexify(body, name) { process.nextTick(cb, err); } }); - } + }, + destroy }); } @@ -202,15 +205,18 @@ module.exports = function duplexify(body, name) { function fromAsyncGen(fn) { let { promise, resolve } = createDeferredPromise(); + const ac = new AbortController(); + const signal = ac.signal; const value = fn(async function*() { while (true) { const { chunk, done, cb } = await promise; process.nextTick(cb); if (done) return; + if (signal.aborted) throw new AbortError(); yield chunk; ({ promise, resolve } = createDeferredPromise()); } - }()); + }(), { signal }); return { value, @@ -219,6 +225,10 @@ function fromAsyncGen(fn) { }, final(cb) { resolve({ done: true, cb }); + }, + destroy(err, cb) { + ac.abort(); + cb(err); } }; } diff --git a/lib/internal/streams/pipeline.js b/lib/internal/streams/pipeline.js index 3d39b3ac7b228a..012d99de0357f2 100644 --- a/lib/internal/streams/pipeline.js +++ b/lib/internal/streams/pipeline.js @@ -21,15 +21,20 @@ const { ERR_MISSING_ARGS, ERR_STREAM_DESTROYED, }, + AbortError, } = require('internal/errors'); -const { validateCallback } = require('internal/validators'); +const { + validateCallback, + validateAbortSignal +} = require('internal/validators'); const { isIterable, isReadableNodeStream, isNodeStream, } = require('internal/streams/utils'); +const { AbortController } = require('internal/abort_controller'); let PassThrough; let Readable; @@ -168,10 +173,26 @@ function pipeline(...streams) { streams = streams[0]; } + return pipelineImpl(streams, callback); +} + +function pipelineImpl(streams, callback, opts) { if (streams.length < 2) { throw new ERR_MISSING_ARGS('streams'); } + const ac = new AbortController(); + const signal = ac.signal; + const outerSignal = opts?.signal; + + validateAbortSignal(outerSignal, 'options.signal'); + + function abort() { + finishImpl(new AbortError()); + } + + outerSignal?.addEventListener('abort', abort); + let error; let value; const destroys = []; @@ -179,8 +200,10 @@ function pipeline(...streams) { let finishCount = 0; function finish(err) { - const final = --finishCount === 0; + finishImpl(err, --finishCount === 0); + } + function finishImpl(err, final) { if (err && (!error || error.code === 'ERR_STREAM_PREMATURE_CLOSE')) { error = err; } @@ -193,6 +216,9 @@ function pipeline(...streams) { destroys.shift()(error); } + outerSignal?.removeEventListener('abort', abort); + ac.abort(); + if (final) { callback(error, value); } @@ -211,7 +237,7 @@ function pipeline(...streams) { if (i === 0) { if (typeof stream === 'function') { - ret = stream(); + ret = stream({ signal }); if (!isIterable(ret)) { throw new ERR_INVALID_RETURN_VALUE( 'Iterable, AsyncIterable or Stream', 'source', ret); @@ -223,7 +249,7 @@ function pipeline(...streams) { } } else if (typeof stream === 'function') { ret = makeAsyncIterable(ret); - ret = stream(ret); + ret = stream(ret, { signal }); if (reading) { if (!isIterable(ret, true)) { @@ -291,7 +317,11 @@ function pipeline(...streams) { } } + if (signal?.aborted || outerSignal?.aborted) { + process.nextTick(abort); + } + return ret; } -module.exports = pipeline; +module.exports = { pipelineImpl, pipeline }; diff --git a/lib/stream.js b/lib/stream.js index 43f59788f62bc8..cc56b76e31a4a6 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -29,8 +29,8 @@ const { promisify: { custom: customPromisify }, } = require('internal/util'); -const pipeline = require('internal/streams/pipeline'); const compose = require('internal/streams/compose'); +const { pipeline } = require('internal/streams/pipeline'); const { destroyer } = require('internal/streams/destroy'); const eos = require('internal/streams/end-of-stream'); const internalBuffer = require('internal/buffer'); diff --git a/lib/stream/promises.js b/lib/stream/promises.js index 8a8e66417c6057..0db01a8b208d60 100644 --- a/lib/stream/promises.js +++ b/lib/stream/promises.js @@ -5,20 +5,12 @@ const { Promise, } = primordials; -const { - addAbortSignalNoValidate, -} = require('internal/streams/add-abort-signal'); - -const { - validateAbortSignal, -} = require('internal/validators'); - const { isIterable, isNodeStream, } = require('internal/streams/utils'); -const pl = require('internal/streams/pipeline'); +const { pipelineImpl: pl } = require('internal/streams/pipeline'); const eos = require('internal/streams/end-of-stream'); function pipeline(...streams) { @@ -29,19 +21,15 @@ function pipeline(...streams) { !isNodeStream(lastArg) && !isIterable(lastArg)) { const options = ArrayPrototypePop(streams); signal = options.signal; - validateAbortSignal(signal, 'options.signal'); } - const pipe = pl(...streams, (err, value) => { + pl(streams, (err, value) => { if (err) { reject(err); } else { resolve(value); } - }); - if (signal) { - addAbortSignalNoValidate(signal, pipe); - } + }, { signal }); }); } diff --git a/test/parallel/test-stream-pipeline.js b/test/parallel/test-stream-pipeline.js index e2e5fe2e0d561a..b21e1ce52b3cb3 100644 --- a/test/parallel/test-stream-pipeline.js +++ b/test/parallel/test-stream-pipeline.js @@ -11,10 +11,12 @@ const { Duplex, addAbortSignal, } = require('stream'); +const pipelinep = require('stream/promises').pipeline; const assert = require('assert'); const http = require('http'); const { promisify } = require('util'); const net = require('net'); +const tsp = require('timers/promises'); { let finished = false; @@ -1387,3 +1389,20 @@ const net = require('net'); assert.strictEqual(res, content); })); } + +{ + const ac = new AbortController(); + const signal = ac.signal; + pipelinep( + async function * ({ signal }) { + await tsp.setTimeout(1e6, signal); + }, + async function(source) { + + }, + { signal } + ).catch(common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + })); + ac.abort(); +}