diff --git a/src/events/websocket/WebSocketClients.js b/src/events/websocket/WebSocketClients.js index 5bde5ea3e..515fa75b9 100644 --- a/src/events/websocket/WebSocketClients.js +++ b/src/events/websocket/WebSocketClients.js @@ -101,14 +101,43 @@ export default class WebSocketClients { clearTimeout(timeoutId) } - async _processEvent(websocketClient, connectionId, route, event) { - let functionKey = this.#webSocketRoutes.get(route) + async verifyClient(connectionId, request) { + const route = this.#webSocketRoutes.get('$connect') + if (!route) { + return { verified: false, statusCode: 502 } + } + + const connectEvent = new WebSocketConnectEvent( + connectionId, + request, + this.#options, + ).create() + + const lambdaFunction = this.#lambda.get(route.functionKey) + lambdaFunction.setEvent(connectEvent) + + try { + const { statusCode } = await lambdaFunction.runHandler() + const verified = statusCode >= 200 && statusCode < 300 + return { verified, statusCode } + } catch (err) { + if (this.log) { + this.log.debug(`Error in route handler '${route.functionKey}'`, err) + } else { + debugLog(`Error in route handler '${route.functionKey}'`, err) + } + return { verified: false, statusCode: 502 } + } + } + + async _processEvent(websocketClient, connectionId, routeKey, event) { + let route = this.#webSocketRoutes.get(routeKey) - if (!functionKey && route !== '$connect' && route !== '$disconnect') { - functionKey = this.#webSocketRoutes.get('$default') + if (!route && routeKey !== '$disconnect') { + route = this.#webSocketRoutes.get('$default') } - if (!functionKey) { + if (!route) { return } @@ -123,28 +152,29 @@ export default class WebSocketClients { ) } - // mimic AWS behaviour (close connection) when the $connect route handler throws - if (route === '$connect') { - websocketClient.close() - } - if (this.log) { - this.log.debug(`Error in route handler '${functionKey}'`, err) + this.log.debug(`Error in route handler '${route.functionKey}'`, err) } else { - debugLog(`Error in route handler '${functionKey}'`, err) + debugLog(`Error in route handler '${route.functionKey}'`, err) } } - const lambdaFunction = this.#lambda.get(functionKey) + const lambdaFunction = this.#lambda.get(route.functionKey) lambdaFunction.setEvent(event) - // let result - try { - /* result = */ await lambdaFunction.runHandler() - - // TODO what to do with "result"? + const { body } = await lambdaFunction.runHandler() + if ( + body && + routeKey !== '$disconnect' && + route.definition.routeResponseSelectionExpression === '$default' + ) { + // https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-selection-expressions.html#apigateway-websocket-api-route-response-selection-expressions + // TODO: Once API gateway supports RouteResponses, this will need to change to support that functionality + // For now, send body back to the client + this.send(connectionId, body) + } } catch (err) { if (this.log) { this.log.error(err) @@ -176,17 +206,9 @@ export default class WebSocketClients { return route || DEFAULT_WEBSOCKETS_ROUTE } - addClient(webSocketClient, request, connectionId) { + addClient(webSocketClient, connectionId) { this._addWebSocketClient(webSocketClient, connectionId) - const connectEvent = new WebSocketConnectEvent( - connectionId, - request, - this.#options, - ).create() - - this._processEvent(webSocketClient, connectionId, '$connect', connectEvent) - webSocketClient.on('close', () => { if (this.log) { this.log.debug(`disconnect:${connectionId}`) @@ -233,14 +255,17 @@ export default class WebSocketClients { }) } - addRoute(functionKey, route) { + addRoute(functionKey, definition) { // set the route name - this.#webSocketRoutes.set(route, functionKey) + this.#webSocketRoutes.set(definition.route, { + functionKey, + definition, + }) if (this.log) { - this.log.notice(`route '${route}'`) + this.log.notice(`route '${definition}'`) } else { - serverlessLog(`route '${route}'`) + serverlessLog(`route '${definition}'`) } } diff --git a/src/events/websocket/WebSocketServer.js b/src/events/websocket/WebSocketServer.js index 17d016df8..012c5806d 100644 --- a/src/events/websocket/WebSocketServer.js +++ b/src/events/websocket/WebSocketServer.js @@ -6,6 +6,7 @@ import { createUniqueId } from '../../utils/index.js' export default class WebSocketServer { #options = null #webSocketClients = null + #connectionIds = new Map() constructor(options, webSocketClients, sharedServer, v3Utils) { this.#options = options @@ -20,6 +21,35 @@ export default class WebSocketServer { const server = new Server({ server: sharedServer, + verifyClient: ({ req }, cb) => { + const connectionId = createUniqueId() + const { headers } = req + const key = headers['sec-websocket-key'] + + if (this.log) { + this.log.debug(`verifyClient:${key} ${connectionId}`) + } else { + debugLog(`verifyClient:${key} ${connectionId}`) + } + + // Use the websocket key to coorelate connection IDs + this.#connectionIds[key] = connectionId + + this.#webSocketClients + .verifyClient(connectionId, req) + .then(({ verified, statusCode }) => { + try { + if (!verified) { + cb(false, statusCode) + return + } + cb(true) + } catch (e) { + debugLog(`Error verifying`, e) + cb(false) + } + }) + }, }) server.on('connection', (webSocketClient, request) => { @@ -29,7 +59,10 @@ export default class WebSocketServer { console.log('received connection') } - const connectionId = createUniqueId() + const { headers } = request + const key = headers['sec-websocket-key'] + + const connectionId = this.#connectionIds[key] if (this.log) { this.log.debug(`connect:${connectionId}`) @@ -37,7 +70,7 @@ export default class WebSocketServer { debugLog(`connect:${connectionId}`) } - this.#webSocketClients.addClient(webSocketClient, request, connectionId) + this.#webSocketClients.addClient(webSocketClient, connectionId) }) } @@ -63,7 +96,7 @@ export default class WebSocketServer { stop() {} addRoute(functionKey, webSocketEvent) { - this.#webSocketClients.addRoute(functionKey, webSocketEvent.route) + this.#webSocketClients.addRoute(functionKey, webSocketEvent) // serverlessLog(`route '${route}'`) } } diff --git a/tests/integration/_testHelpers/websocketPromise.js b/tests/integration/_testHelpers/websocketPromise.js new file mode 100644 index 000000000..d8f695941 --- /dev/null +++ b/tests/integration/_testHelpers/websocketPromise.js @@ -0,0 +1,24 @@ +const websocketSend = (ws, data) => + new Promise((res) => { + ws.on('open', () => { + ws.send(data, (e) => { + if (e) { + res({ err: e }) + } + }) + }) + ws.on('close', (c) => { + res({ code: c }) + }) + ws.on('message', (d) => { + res({ data: d }) + }) + ws.on('error', (e) => { + res({ err: e }) + }) + setTimeout(() => { + res({}) + }, 5000) + }) + +export default websocketSend diff --git a/tests/integration/websocket-oneway/handler.js b/tests/integration/websocket-oneway/handler.js new file mode 100644 index 000000000..68d71d7d2 --- /dev/null +++ b/tests/integration/websocket-oneway/handler.js @@ -0,0 +1,19 @@ +'use strict' + +exports.handler = async (event) => { + const { body, requestContext } = event + + if ( + body && + JSON.parse(body).throwError && + requestContext && + requestContext.routeKey === '$default' + ) { + throw new Error('Throwing error from incoming message') + } + + return { + statusCode: 200, + body: body || undefined, + } +} diff --git a/tests/integration/websocket-oneway/serverless.yml b/tests/integration/websocket-oneway/serverless.yml new file mode 100644 index 000000000..c133f9605 --- /dev/null +++ b/tests/integration/websocket-oneway/serverless.yml @@ -0,0 +1,26 @@ +service: oneway-websocket-tests + +plugins: + - ../../../ + +provider: + memorySize: 128 + name: aws + region: us-east-1 # default + runtime: nodejs12.x + stage: dev + versionFunctions: false + +functions: + handler: + handler: handler.handler + events: + - http: + path: echo + method: get + - websocket: + route: $connect + - websocket: + route: $disconnect + - websocket: + route: $default diff --git a/tests/integration/websocket-oneway/websocket-oneway.test.js b/tests/integration/websocket-oneway/websocket-oneway.test.js new file mode 100644 index 000000000..304f31c60 --- /dev/null +++ b/tests/integration/websocket-oneway/websocket-oneway.test.js @@ -0,0 +1,55 @@ +import { resolve } from 'path' +import WebSocket from 'ws' +import { joinUrl, setup, teardown } from '../_testHelpers/index.js' +import websocketSend from '../_testHelpers/websocketPromise.js' + +jest.setTimeout(30000) + +describe('one way websocket tests', () => { + // init + beforeAll(() => + setup({ + servicePath: resolve(__dirname), + }), + ) + + // cleanup + afterAll(() => teardown()) + + test('websocket echos nothing', async () => { + const url = new URL(joinUrl(TEST_BASE_URL, '/dev')) + url.port = url.port ? '3001' : url.port + url.protocol = 'ws' + + const payload = JSON.stringify({ + hello: 'world', + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url.toString()) + const { data, code, err } = await websocketSend(ws, payload) + + expect(code).toBeUndefined() + expect(err).toBeUndefined() + expect(data).toBeUndefined() + }) + + test('execution error emits Internal Server Error', async () => { + const url = new URL(joinUrl(TEST_BASE_URL, '/dev')) + url.port = url.port ? '3001' : url.port + url.protocol = 'ws' + + const payload = JSON.stringify({ + hello: 'world', + now: new Date().toISOString(), + throwError: true, + }) + + const ws = new WebSocket(url.toString()) + const { data, code, err } = await websocketSend(ws, payload) + + expect(code).toBeUndefined() + expect(err).toBeUndefined() + expect(JSON.parse(data).message).toEqual('Internal server error') + }) +}) diff --git a/tests/integration/websocket-twoway/handler.js b/tests/integration/websocket-twoway/handler.js new file mode 100644 index 000000000..cc5b1ed59 --- /dev/null +++ b/tests/integration/websocket-twoway/handler.js @@ -0,0 +1,32 @@ +'use strict' + +exports.handler = async (event) => { + const { body, queryStringParameters, requestContext } = event + const statusCode = + queryStringParameters && queryStringParameters.statusCode + ? Number(queryStringParameters.statusCode) + : 200 + + if ( + queryStringParameters && + queryStringParameters.throwError && + requestContext && + requestContext.routeKey === '$connect' + ) { + throw new Error('Throwing error during connect phase') + } + + if ( + body && + JSON.parse(body).throwError && + requestContext && + requestContext.routeKey === '$default' + ) { + throw new Error('Throwing error from incoming message') + } + + return { + statusCode, + body: body || undefined, + } +} diff --git a/tests/integration/websocket-twoway/serverless.yml b/tests/integration/websocket-twoway/serverless.yml new file mode 100644 index 000000000..d48dd3a4f --- /dev/null +++ b/tests/integration/websocket-twoway/serverless.yml @@ -0,0 +1,28 @@ +service: twoway-websocket-tests + +plugins: + - ../../../ + +provider: + memorySize: 128 + name: aws + region: us-east-1 # default + runtime: nodejs12.x + stage: dev + versionFunctions: false + +functions: + handler: + handler: handler.handler + events: + - http: + path: echo + method: get + - websocket: + route: $connect + - websocket: + route: $disconnect + - websocket: + route: $default + # Enable 2-way comms + routeResponseSelectionExpression: $default diff --git a/tests/integration/websocket-twoway/websocket-twoway.test.js b/tests/integration/websocket-twoway/websocket-twoway.test.js new file mode 100644 index 000000000..ae8da1187 --- /dev/null +++ b/tests/integration/websocket-twoway/websocket-twoway.test.js @@ -0,0 +1,102 @@ +import { resolve } from 'path' +import WebSocket from 'ws' +import { joinUrl, setup, teardown } from '../_testHelpers/index.js' +import websocketSend from '../_testHelpers/websocketPromise.js' + +jest.setTimeout(30000) + +describe('two way websocket tests', () => { + // init + beforeAll(() => + setup({ + servicePath: resolve(__dirname), + }), + ) + + // cleanup + afterAll(() => teardown()) + + test('websocket echos sent message', async () => { + const url = new URL(joinUrl(TEST_BASE_URL, '/dev')) + url.port = url.port ? '3001' : url.port + url.protocol = 'ws' + + const payload = JSON.stringify({ + hello: 'world', + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url.toString()) + const { data, code, err } = await websocketSend(ws, payload) + + expect(code).toBeUndefined() + expect(err).toBeUndefined() + expect(data).toEqual(payload) + }) + + test.each([401, 500, 501, 502])( + 'websocket connection emits status code %s', + async (statusCode) => { + const url = new URL(joinUrl(TEST_BASE_URL, '/dev')) + url.port = url.port ? '3001' : url.port + url.searchParams.set('statusCode', statusCode) + url.protocol = 'ws' + + const payload = JSON.stringify({ + hello: 'world', + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url.toString()) + const { data, code, err } = await websocketSend(ws, payload) + + expect(code).toBeUndefined() + + if (statusCode >= 200 && statusCode < 300) { + expect(err).toBeUndefined() + expect(data).toEqual(payload) + } else { + expect(err.message).toEqual(`Unexpected server response: ${statusCode}`) + expect(data).toBeUndefined() + } + }, + ) + + test('websocket emits 502 on connection error', async () => { + const url = new URL(joinUrl(TEST_BASE_URL, '/dev')) + url.port = url.port ? '3001' : url.port + url.searchParams.set('throwError', 'true') + url.protocol = 'ws' + + const payload = JSON.stringify({ + hello: 'world', + now: new Date().toISOString(), + }) + + const ws = new WebSocket(url.toString()) + const { data, code, err } = await websocketSend(ws, payload) + + expect(code).toBeUndefined() + expect(err.message).toEqual('Unexpected server response: 502') + expect(data).toBeUndefined() + }) + + test('execution error emits Internal Server Error', async () => { + const url = new URL(joinUrl(TEST_BASE_URL, '/dev')) + url.port = url.port ? '3001' : url.port + url.protocol = 'ws' + + const payload = JSON.stringify({ + hello: 'world', + now: new Date().toISOString(), + throwError: true, + }) + + const ws = new WebSocket(url.toString()) + const { data, code, err } = await websocketSend(ws, payload) + + expect(code).toBeUndefined() + expect(err).toBeUndefined() + expect(JSON.parse(data).message).toEqual('Internal server error') + }) +})