diff --git a/docs/api/BalancedPool.md b/docs/api/BalancedPool.md index 9ffc83d81b5..4580f07efb4 100644 --- a/docs/api/BalancedPool.md +++ b/docs/api/BalancedPool.md @@ -11,14 +11,15 @@ Requests are not guaranteed to be dispatched in order of invocation. Arguments: * **upstreams** `URL | string | string[]` - It should only include the **protocol, hostname, and port**. -* **options** `PoolOptions` (optional) +* **options** `BalancedPoolOptions` (optional) -### Parameter: `PoolOptions` +### Parameter: `BalancedPoolOptions` -The `PoolOptions` are passed to each of the `Pool` instances being created. +Extends: [`PoolOptions`](Pool.md#parameter-pooloptions) -See: [`PoolOptions`](Pool.md#parameter-pooloptions) +* **factory** `(origin: URL, opts: Object) => Dispatcher` - Default: `(origin, opts) => new Pool(origin, opts)` +The `PoolOptions` are passed to each of the `Pool` instances being created. ## Instance Properties ### `BalancedPool.upstreams` diff --git a/lib/balanced-pool.js b/lib/balanced-pool.js index 4ad0bf48602..641c813b152 100644 --- a/lib/balanced-pool.js +++ b/lib/balanced-pool.js @@ -1,7 +1,8 @@ 'use strict' const { - BalancedPoolMissingUpstreamError + BalancedPoolMissingUpstreamError, + InvalidArgumentError } = require('./core/errors') const { PoolBase, @@ -13,11 +14,16 @@ const { } = require('./pool-base') const Pool = require('./pool') const { kUrl } = require('./core/symbols') +const kFactory = Symbol('factory') const kOptions = Symbol('options') +function defaultFactory (origin, opts) { + return new Pool(origin, opts); +} + class BalancedPool extends PoolBase { - constructor (upstreams = [], opts = {}) { + constructor (upstreams = [], { factory = defaultFactory, ...opts } = {}) { super() this[kOptions] = opts @@ -26,6 +32,12 @@ class BalancedPool extends PoolBase { upstreams = [upstreams] } + if (typeof factory !== 'function') { + throw new InvalidArgumentError('factory must be a function.') + } + + this[kFactory] = factory + for (const upstream of upstreams) { this.addUpstream(upstream) } @@ -40,7 +52,7 @@ class BalancedPool extends PoolBase { return this } - this[kAddClient](new Pool(upstream, Object.assign({}, this[kOptions]))) + this[kAddClient](this[kFactory](upstream, Object.assign({}, this[kOptions]))) return this } diff --git a/test/balanced-pool.js b/test/balanced-pool.js index 473522ac9f2..9349ce96573 100644 --- a/test/balanced-pool.js +++ b/test/balanced-pool.js @@ -1,7 +1,7 @@ 'use strict' const { test } = require('tap') -const { BalancedPool, Client, errors } = require('..') +const { BalancedPool, Client, errors, Pool } = require('..') const { createServer } = require('http') const { promisify } = require('util') @@ -163,3 +163,63 @@ test('busy', (t) => { } }) }) + +test('invalid options throws', (t) => { + t.plan(2) + + try { + new BalancedPool(null, { factory: '' }) // eslint-disable-line + } catch (err) { + t.type(err, errors.InvalidArgumentError) + t.equal(err.message, 'factory must be a function.') + } +}) + +test('factory option with basic get request', async(t) => { + t.plan(12) + + let factoryCalled = 0 + const opts = { + factory: (origin, opts) => { + factoryCalled ++ + return new Pool(origin, opts) + } + } + + const client = new BalancedPool([], opts) // eslint-disable-line + + let serverCalled = 0 + const server = createServer((req, res) => { + serverCalled++ + t.equal('/', req.url) + t.equal('GET', req.method) + res.setHeader('content-type', 'text/plain') + res.end('hello') + }) + t.teardown(server.close.bind(server)) + + await promisify(server.listen).call(server, 0) + + client.addUpstream(`http://localhost:${server.address().port}`) + + t.same(client.upstreams, [`http://localhost:${server.address().port}`]) + + t.teardown(client.destroy.bind(client)) + + { + const { statusCode, headers, body } = await client.request({ path: '/', method: 'GET' }) + t.equal(statusCode, 200) + t.equal(headers['content-type'], 'text/plain') + t.equal('hello', await body.text()) + } + + t.equal(serverCalled, 1) + t.equal(factoryCalled, 1) + + t.equal(client.destroyed, false) + t.equal(client.closed, false) + await client.close() + t.equal(client.destroyed, true) + t.equal(client.closed, true) + +})