diff --git a/lib/websocket.js b/lib/websocket.js index 68405dceb..3132cc150 100644 --- a/lib/websocket.js +++ b/lib/websocket.js @@ -771,8 +771,11 @@ function initAsClient(websocket, address, protocols, options) { if (opts.followRedirects) { if (websocket._redirects === 0) { + websocket._originalUnixSocket = isUnixSocket; websocket._originalSecure = isSecure; - websocket._originalHost = parsedUrl.host; + websocket._originalHostOrSocketPath = isUnixSocket + ? opts.socketPath + : parsedUrl.host; const headers = options && options.headers; @@ -788,7 +791,13 @@ function initAsClient(websocket, address, protocols, options) { } } } else if (websocket.listenerCount('redirect') === 0) { - const isSameHost = parsedUrl.host === websocket._originalHost; + const isSameHost = isUnixSocket + ? websocket._originalUnixSocket + ? opts.socketPath === websocket._originalHostOrSocketPath + : false + : websocket._originalUnixSocket + ? false + : parsedUrl.host === websocket._originalHostOrSocketPath; if (!isSameHost || (websocket._originalSecure && !isSecure)) { // diff --git a/test/websocket.test.js b/test/websocket.test.js index 64b5bf224..f5fbf1650 100644 --- a/test/websocket.test.js +++ b/test/websocket.test.js @@ -6,8 +6,10 @@ const assert = require('assert'); const crypto = require('crypto'); const https = require('https'); const http = require('http'); +const path = require('path'); const net = require('net'); const tls = require('tls'); +const os = require('os'); const fs = require('fs'); const { URL } = require('url'); @@ -1477,7 +1479,9 @@ describe('WebSocket', () => { }); }); - it('drops the Authorization, Cookie and Host headers', (done) => { + it('drops the Authorization, Cookie and Host headers (1/4)', (done) => { + // Test the `ws:` to `ws:` case. + const wss = new WebSocket.Server({ port: 0 }, () => { const port = wss.address().port; @@ -1531,6 +1535,221 @@ describe('WebSocket', () => { ws.close(); }); }); + + it('drops the Authorization, Cookie and Host headers (2/4)', function (done) { + if (process.platform === 'win32') return this.skip(); + + // Test the `ws:` to `ws+unix:` case. + + const socketPath = path.join( + os.tmpdir(), + `ws.${crypto.randomBytes(16).toString('hex')}.sock` + ); + + server.once('upgrade', (req, socket) => { + socket.end( + `HTTP/1.1 302 Found\r\nLocation: ws+unix://${socketPath}\r\n\r\n` + ); + }); + + const redirectedServer = http.createServer(); + const wss = new WebSocket.Server({ server: redirectedServer }); + + wss.on('connection', (ws, req) => { + assert.strictEqual(req.headers.authorization, undefined); + assert.strictEqual(req.headers.cookie, undefined); + assert.strictEqual(req.headers.host, 'localhost'); + + ws.close(); + }); + + redirectedServer.listen(socketPath, () => { + const headers = { + authorization: 'Basic Zm9vOmJhcg==', + cookie: 'foo=bar', + host: 'foo' + }; + + const ws = new WebSocket( + `ws://localhost:${server.address().port}`, + { followRedirects: true, headers } + ); + + const firstRequest = ws._req; + + assert.strictEqual( + firstRequest.getHeader('Authorization'), + headers.authorization + ); + assert.strictEqual( + firstRequest.getHeader('Cookie'), + headers.cookie + ); + assert.strictEqual(firstRequest.getHeader('Host'), headers.host); + + ws.on('close', (code) => { + assert.strictEqual(code, 1005); + assert.strictEqual(ws.url, `ws+unix://${socketPath}`); + assert.strictEqual(ws._redirects, 1); + + redirectedServer.close(done); + }); + }); + }); + + it('drops the Authorization, Cookie and Host headers (3/4)', function (done) { + if (process.platform === 'win32') return this.skip(); + + // Test the `ws+unix:` to `ws+unix:` case. + + const redirectingServerSocketPath = path.join( + os.tmpdir(), + `ws.${crypto.randomBytes(16).toString('hex')}.sock` + ); + const redirectedServerSocketPath = path.join( + os.tmpdir(), + `ws.${crypto.randomBytes(16).toString('hex')}.sock` + ); + + const redirectingServer = http.createServer(); + + redirectingServer.on('upgrade', (req, socket) => { + socket.end( + 'HTTP/1.1 302 Found\r\n' + + `Location: ws+unix://${redirectedServerSocketPath}\r\n\r\n` + ); + }); + + const redirectedServer = http.createServer(); + const wss = new WebSocket.Server({ server: redirectedServer }); + + wss.on('connection', (ws, req) => { + assert.strictEqual(req.headers.authorization, undefined); + assert.strictEqual(req.headers.cookie, undefined); + assert.strictEqual(req.headers.host, 'localhost'); + + ws.close(); + }); + + redirectingServer.listen(redirectingServerSocketPath, listening); + redirectedServer.listen(redirectedServerSocketPath, listening); + + let callCount = 0; + + function listening() { + if (++callCount !== 2) return; + + const headers = { + authorization: 'Basic Zm9vOmJhcg==', + cookie: 'foo=bar', + host: 'foo' + }; + + const ws = new WebSocket( + `ws+unix://${redirectingServerSocketPath}`, + { followRedirects: true, headers } + ); + + const firstRequest = ws._req; + + assert.strictEqual( + firstRequest.getHeader('Authorization'), + headers.authorization + ); + assert.strictEqual( + firstRequest.getHeader('Cookie'), + headers.cookie + ); + assert.strictEqual(firstRequest.getHeader('Host'), headers.host); + + ws.on('close', (code) => { + assert.strictEqual(code, 1005); + assert.strictEqual( + ws.url, + `ws+unix://${redirectedServerSocketPath}` + ); + assert.strictEqual(ws._redirects, 1); + + redirectingServer.close(); + redirectedServer.close(done); + }); + } + }); + + it('drops the Authorization, Cookie and Host headers (4/4)', function (done) { + if (process.platform === 'win32') return this.skip(); + + // Test the `ws+unix:` to `ws:` case. + + const redirectingServer = http.createServer(); + const redirectedServer = http.createServer(); + const wss = new WebSocket.Server({ server: redirectedServer }); + + wss.on('connection', (ws, req) => { + assert.strictEqual(req.headers.authorization, undefined); + assert.strictEqual(req.headers.cookie, undefined); + assert.strictEqual( + req.headers.host, + `localhost:${redirectedServer.address().port}` + ); + + ws.close(); + }); + + const socketPath = path.join( + os.tmpdir(), + `ws.${crypto.randomBytes(16).toString('hex')}.sock` + ); + + redirectingServer.listen(socketPath, listening); + redirectedServer.listen(0, listening); + + let callCount = 0; + + function listening() { + if (++callCount !== 2) return; + + const port = redirectedServer.address().port; + + redirectingServer.on('upgrade', (req, socket) => { + socket.end( + `HTTP/1.1 302 Found\r\nLocation: ws://localhost:${port}\r\n\r\n` + ); + }); + + const headers = { + authorization: 'Basic Zm9vOmJhcg==', + cookie: 'foo=bar', + host: 'foo' + }; + + const ws = new WebSocket(`ws+unix://${socketPath}`, { + followRedirects: true, + headers + }); + + const firstRequest = ws._req; + + assert.strictEqual( + firstRequest.getHeader('Authorization'), + headers.authorization + ); + assert.strictEqual( + firstRequest.getHeader('Cookie'), + headers.cookie + ); + assert.strictEqual(firstRequest.getHeader('Host'), headers.host); + + ws.on('close', (code) => { + assert.strictEqual(code, 1005); + assert.strictEqual(ws.url, `ws://localhost:${port}/`); + assert.strictEqual(ws._redirects, 1); + + redirectingServer.close(); + redirectedServer.close(done); + }); + } + }); }); describe("If there is at least one 'redirect' event listener", () => {