From 633d3f292d6a2820b32442b72b7b1b500d9d2c51 Mon Sep 17 00:00:00 2001 From: Debadree Chatterjee Date: Fri, 17 Feb 2023 16:39:08 +0530 Subject: [PATCH] stream: add abort signal for ReadableStream and WritableStream Refs: https://github.com/nodejs/node/issues/39316 PR-URL: https://github.com/nodejs/node/pull/46273 Reviewed-By: Matteo Collina Reviewed-By: Benjamin Gruenbaum Reviewed-By: James M Snell --- doc/api/stream.md | 42 ++++- lib/internal/streams/add-abort-signal.js | 25 ++- lib/internal/streams/utils.js | 2 + lib/internal/webstreams/readablestream.js | 4 +- lib/internal/webstreams/writablestream.js | 4 + .../test-webstreams-abort-controller.js | 168 ++++++++++++++++++ 6 files changed, 233 insertions(+), 12 deletions(-) create mode 100644 test/parallel/test-webstreams-abort-controller.js diff --git a/doc/api/stream.md b/doc/api/stream.md index a9f72c120190d2..4b4b3b586770cd 100644 --- a/doc/api/stream.md +++ b/doc/api/stream.md @@ -3226,17 +3226,24 @@ readable.getReader().read().then((result) => { * `signal` {AbortSignal} A signal representing possible cancellation -* `stream` {Stream} a stream to attach a signal to +* `stream` {Stream|ReadableStream|WritableStream} + +A stream to attach a signal to. Attaches an AbortSignal to a readable or writeable stream. This lets code control stream destruction using an `AbortController`. Calling `abort` on the `AbortController` corresponding to the passed `AbortSignal` will behave the same way as calling `.destroy(new AbortError())` -on the stream. +on the stream, and `controller.error(new AbortError())` for webstreams. ```js const fs = require('node:fs'); @@ -3274,6 +3281,37 @@ const stream = addAbortSignal( })(); ``` +Or using an `AbortSignal` with a ReadableStream: + +```js +const controller = new AbortController(); +const rs = new ReadableStream({ + start(controller) { + controller.enqueue('hello'); + controller.enqueue('world'); + controller.close(); + }, +}); + +addAbortSignal(controller.signal, rs); + +finished(rs, (err) => { + if (err) { + if (err.name === 'AbortError') { + // The operation was cancelled + } + } +}); + +const reader = rs.getReader(); + +reader.read().then(({ value, done }) => { + console.log(value); // hello + console.log(done); // false + controller.abort(); +}); +``` + ## API for stream implementers diff --git a/lib/internal/streams/add-abort-signal.js b/lib/internal/streams/add-abort-signal.js index 9bcd202ec63c1e..d6c8ca4c9c7842 100644 --- a/lib/internal/streams/add-abort-signal.js +++ b/lib/internal/streams/add-abort-signal.js @@ -5,6 +5,12 @@ const { codes, } = require('internal/errors'); +const { + isNodeStream, + isWebStream, + kControllerErrorFunction, +} = require('internal/streams/utils'); + const eos = require('internal/streams/end-of-stream'); const { ERR_INVALID_ARG_TYPE } = codes; @@ -18,24 +24,25 @@ const validateAbortSignal = (signal, name) => { } }; -function isNodeStream(obj) { - return !!(obj && typeof obj.pipe === 'function'); -} - module.exports.addAbortSignal = function addAbortSignal(signal, stream) { validateAbortSignal(signal, 'signal'); - if (!isNodeStream(stream)) { - throw new ERR_INVALID_ARG_TYPE('stream', 'stream.Stream', stream); + if (!isNodeStream(stream) && !isWebStream(stream)) { + throw new ERR_INVALID_ARG_TYPE('stream', ['ReadableStream', 'WritableStream', 'Stream'], stream); } return module.exports.addAbortSignalNoValidate(signal, stream); }; + module.exports.addAbortSignalNoValidate = function(signal, stream) { if (typeof signal !== 'object' || !('aborted' in signal)) { return stream; } - const onAbort = () => { - stream.destroy(new AbortError(undefined, { cause: signal.reason })); - }; + const onAbort = isNodeStream(stream) ? + () => { + stream.destroy(new AbortError(undefined, { cause: signal.reason })); + } : + () => { + stream[kControllerErrorFunction](new AbortError(undefined, { cause: signal.reason })); + }; if (signal.aborted) { onAbort(); } else { diff --git a/lib/internal/streams/utils.js b/lib/internal/streams/utils.js index 74faca5fe9bb2a..c9e61ca8cdd8eb 100644 --- a/lib/internal/streams/utils.js +++ b/lib/internal/streams/utils.js @@ -13,6 +13,7 @@ const kIsReadable = Symbol('kIsReadable'); const kIsDisturbed = Symbol('kIsDisturbed'); const kIsClosedPromise = SymbolFor('nodejs.webstream.isClosedPromise'); +const kControllerErrorFunction = SymbolFor('nodejs.webstream.controllerErrorFunction'); function isReadableNodeStream(obj, strict = false) { return !!( @@ -305,6 +306,7 @@ module.exports = { isReadable, kIsReadable, kIsClosedPromise, + kControllerErrorFunction, isClosed, isDestroyed, isDuplexNodeStream, diff --git a/lib/internal/webstreams/readablestream.js b/lib/internal/webstreams/readablestream.js index cc683e956db5f5..5f4370d48e40bd 100644 --- a/lib/internal/webstreams/readablestream.js +++ b/lib/internal/webstreams/readablestream.js @@ -83,6 +83,7 @@ const { kIsErrored, kIsReadable, kIsClosedPromise, + kControllerErrorFunction, } = require('internal/streams/utils'); const { @@ -260,6 +261,7 @@ class ReadableStream { }; this[kIsClosedPromise] = createDeferredPromise(); + this[kControllerErrorFunction] = () => {}; // The spec requires handling of the strategy first // here. Specifically, if getting the size and @@ -1891,7 +1893,6 @@ function readableStreamClose(stream) { assert(stream[kState].state === 'readable'); stream[kState].state = 'closed'; stream[kIsClosedPromise].resolve(); - const { reader, } = stream[kState]; @@ -2330,6 +2331,7 @@ function setupReadableStreamDefaultController( stream, }; stream[kState].controller = controller; + stream[kControllerErrorFunction] = FunctionPrototypeBind(controller.error, controller); const startResult = startAlgorithm(); diff --git a/lib/internal/webstreams/writablestream.js b/lib/internal/webstreams/writablestream.js index a8922c08456358..04a4db15300682 100644 --- a/lib/internal/webstreams/writablestream.js +++ b/lib/internal/webstreams/writablestream.js @@ -71,6 +71,7 @@ const { const { kIsClosedPromise, + kControllerErrorFunction, } = require('internal/streams/utils'); const { @@ -199,6 +200,7 @@ class WritableStream { }; this[kIsClosedPromise] = createDeferredPromise(); + this[kControllerErrorFunction] = () => {}; const size = extractSizeAlgorithm(strategy?.size); const highWaterMark = extractHighWaterMark(strategy?.highWaterMark, 1); @@ -370,6 +372,7 @@ function TransferredWritableStream() { }, }; this[kIsClosedPromise] = createDeferredPromise(); + this[kControllerErrorFunction] = () => {}; }, [], WritableStream)); } @@ -1282,6 +1285,7 @@ function setupWritableStreamDefaultController( writeAlgorithm, }; stream[kState].controller = controller; + stream[kControllerErrorFunction] = FunctionPrototypeBind(controller.error, controller); writableStreamUpdateBackpressure( stream, diff --git a/test/parallel/test-webstreams-abort-controller.js b/test/parallel/test-webstreams-abort-controller.js new file mode 100644 index 00000000000000..1468e418c6cb4a --- /dev/null +++ b/test/parallel/test-webstreams-abort-controller.js @@ -0,0 +1,168 @@ +'use strict'; + +const common = require('../common'); +const { finished, addAbortSignal } = require('stream'); +const { ReadableStream, WritableStream } = require('stream/web'); +const assert = require('assert'); + +function createTestReadableStream() { + return new ReadableStream({ + start(controller) { + controller.enqueue('a'); + controller.enqueue('b'); + controller.enqueue('c'); + controller.close(); + } + }); +} + +function createTestWritableStream(values) { + return new WritableStream({ + write(chunk) { + values.push(chunk); + } + }); +} + +{ + const rs = createTestReadableStream(); + + const reader = rs.getReader(); + + const ac = new AbortController(); + + addAbortSignal(ac.signal, rs); + + finished(rs, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.rejects(reader.read(), /AbortError/).then(common.mustCall()); + assert.rejects(reader.closed, /AbortError/).then(common.mustCall()); + })); + + reader.read().then(common.mustCall((result) => { + assert.strictEqual(result.value, 'a'); + ac.abort(); + })); +} + +{ + const rs = createTestReadableStream(); + + const ac = new AbortController(); + + addAbortSignal(ac.signal, rs); + + assert.rejects((async () => { + for await (const chunk of rs) { + if (chunk === 'b') { + ac.abort(); + } + } + })(), /AbortError/).then(common.mustCall()); +} + +{ + const rs1 = createTestReadableStream(); + + const rs2 = createTestReadableStream(); + + const ac = new AbortController(); + + addAbortSignal(ac.signal, rs1); + addAbortSignal(ac.signal, rs2); + + const reader1 = rs1.getReader(); + const reader2 = rs2.getReader(); + + finished(rs1, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.rejects(reader1.read(), /AbortError/).then(common.mustCall()); + assert.rejects(reader1.closed, /AbortError/).then(common.mustCall()); + })); + + finished(rs2, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.rejects(reader2.read(), /AbortError/).then(common.mustCall()); + assert.rejects(reader2.closed, /AbortError/).then(common.mustCall()); + })); + + ac.abort(); +} + +{ + const rs = createTestReadableStream(); + + const { 0: rs1, 1: rs2 } = rs.tee(); + + const ac = new AbortController(); + + addAbortSignal(ac.signal, rs); + + const reader1 = rs1.getReader(); + const reader2 = rs2.getReader(); + + finished(rs1, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.rejects(reader1.read(), /AbortError/).then(common.mustCall()); + assert.rejects(reader1.closed, /AbortError/).then(common.mustCall()); + })); + + finished(rs2, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.rejects(reader2.read(), /AbortError/).then(common.mustCall()); + assert.rejects(reader2.closed, /AbortError/).then(common.mustCall()); + })); + + ac.abort(); +} + +{ + const values = []; + const ws = createTestWritableStream(values); + + const ac = new AbortController(); + + addAbortSignal(ac.signal, ws); + + const writer = ws.getWriter(); + + finished(ws, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.deepStrictEqual(values, ['a']); + assert.rejects(writer.write('b'), /AbortError/).then(common.mustCall()); + assert.rejects(writer.closed, /AbortError/).then(common.mustCall()); + })); + + writer.write('a').then(() => { + ac.abort(); + }); +} + +{ + const values = []; + + const ws1 = createTestWritableStream(values); + const ws2 = createTestWritableStream(values); + + const ac = new AbortController(); + + addAbortSignal(ac.signal, ws1); + addAbortSignal(ac.signal, ws2); + + const writer1 = ws1.getWriter(); + const writer2 = ws2.getWriter(); + + finished(ws1, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.rejects(writer1.write('a'), /AbortError/).then(common.mustCall()); + assert.rejects(writer1.closed, /AbortError/).then(common.mustCall()); + })); + + finished(ws2, common.mustCall((err) => { + assert.strictEqual(err.name, 'AbortError'); + assert.rejects(writer2.write('a'), /AbortError/).then(common.mustCall()); + assert.rejects(writer2.closed, /AbortError/).then(common.mustCall()); + })); + + ac.abort(); +}