Skip to content

Commit

Permalink
feat: Ensure websocket parity with API Gateway (#1301)
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuss committed Feb 18, 2022
1 parent e9e8169 commit 8f02226
Show file tree
Hide file tree
Showing 9 changed files with 378 additions and 34 deletions.
87 changes: 56 additions & 31 deletions src/events/websocket/WebSocketClients.js
Expand Up @@ -101,14 +101,43 @@ export default class WebSocketClients {
clearTimeout(timeoutId)
}

async _processEvent(websocketClient, connectionId, route, event) {
let functionKey = this.#webSocketRoutes.get(route)
async verifyClient(connectionId, request) {
const route = this.#webSocketRoutes.get('$connect')
if (!route) {
return { verified: false, statusCode: 502 }
}

const connectEvent = new WebSocketConnectEvent(
connectionId,
request,
this.#options,
).create()

const lambdaFunction = this.#lambda.get(route.functionKey)
lambdaFunction.setEvent(connectEvent)

try {
const { statusCode } = await lambdaFunction.runHandler()
const verified = statusCode >= 200 && statusCode < 300
return { verified, statusCode }
} catch (err) {
if (this.log) {
this.log.debug(`Error in route handler '${route.functionKey}'`, err)
} else {
debugLog(`Error in route handler '${route.functionKey}'`, err)
}
return { verified: false, statusCode: 502 }
}
}

async _processEvent(websocketClient, connectionId, routeKey, event) {
let route = this.#webSocketRoutes.get(routeKey)

if (!functionKey && route !== '$connect' && route !== '$disconnect') {
functionKey = this.#webSocketRoutes.get('$default')
if (!route && routeKey !== '$disconnect') {
route = this.#webSocketRoutes.get('$default')
}

if (!functionKey) {
if (!route) {
return
}

Expand All @@ -123,28 +152,29 @@ export default class WebSocketClients {
)
}

// mimic AWS behaviour (close connection) when the $connect route handler throws
if (route === '$connect') {
websocketClient.close()
}

if (this.log) {
this.log.debug(`Error in route handler '${functionKey}'`, err)
this.log.debug(`Error in route handler '${route.functionKey}'`, err)
} else {
debugLog(`Error in route handler '${functionKey}'`, err)
debugLog(`Error in route handler '${route.functionKey}'`, err)
}
}

const lambdaFunction = this.#lambda.get(functionKey)
const lambdaFunction = this.#lambda.get(route.functionKey)

lambdaFunction.setEvent(event)

// let result

try {
/* result = */ await lambdaFunction.runHandler()

// TODO what to do with "result"?
const { body } = await lambdaFunction.runHandler()
if (
body &&
routeKey !== '$disconnect' &&
route.definition.routeResponseSelectionExpression === '$default'
) {
// https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-selection-expressions.html#apigateway-websocket-api-route-response-selection-expressions
// TODO: Once API gateway supports RouteResponses, this will need to change to support that functionality
// For now, send body back to the client
this.send(connectionId, body)
}
} catch (err) {
if (this.log) {
this.log.error(err)
Expand Down Expand Up @@ -176,17 +206,9 @@ export default class WebSocketClients {
return route || DEFAULT_WEBSOCKETS_ROUTE
}

addClient(webSocketClient, request, connectionId) {
addClient(webSocketClient, connectionId) {
this._addWebSocketClient(webSocketClient, connectionId)

const connectEvent = new WebSocketConnectEvent(
connectionId,
request,
this.#options,
).create()

this._processEvent(webSocketClient, connectionId, '$connect', connectEvent)

webSocketClient.on('close', () => {
if (this.log) {
this.log.debug(`disconnect:${connectionId}`)
Expand Down Expand Up @@ -233,14 +255,17 @@ export default class WebSocketClients {
})
}

addRoute(functionKey, route) {
addRoute(functionKey, definition) {
// set the route name
this.#webSocketRoutes.set(route, functionKey)
this.#webSocketRoutes.set(definition.route, {
functionKey,
definition,
})

if (this.log) {
this.log.notice(`route '${route}'`)
this.log.notice(`route '${definition}'`)
} else {
serverlessLog(`route '${route}'`)
serverlessLog(`route '${definition}'`)
}
}

Expand Down
39 changes: 36 additions & 3 deletions src/events/websocket/WebSocketServer.js
Expand Up @@ -6,6 +6,7 @@ import { createUniqueId } from '../../utils/index.js'
export default class WebSocketServer {
#options = null
#webSocketClients = null
#connectionIds = new Map()

constructor(options, webSocketClients, sharedServer, v3Utils) {
this.#options = options
Expand All @@ -20,6 +21,35 @@ export default class WebSocketServer {

const server = new Server({
server: sharedServer,
verifyClient: ({ req }, cb) => {
const connectionId = createUniqueId()
const { headers } = req
const key = headers['sec-websocket-key']

if (this.log) {
this.log.debug(`verifyClient:${key} ${connectionId}`)
} else {
debugLog(`verifyClient:${key} ${connectionId}`)
}

// Use the websocket key to coorelate connection IDs
this.#connectionIds[key] = connectionId

this.#webSocketClients
.verifyClient(connectionId, req)
.then(({ verified, statusCode }) => {
try {
if (!verified) {
cb(false, statusCode)
return
}
cb(true)
} catch (e) {
debugLog(`Error verifying`, e)
cb(false)
}
})
},
})

server.on('connection', (webSocketClient, request) => {
Expand All @@ -29,15 +59,18 @@ export default class WebSocketServer {
console.log('received connection')
}

const connectionId = createUniqueId()
const { headers } = request
const key = headers['sec-websocket-key']

const connectionId = this.#connectionIds[key]

if (this.log) {
this.log.debug(`connect:${connectionId}`)
} else {
debugLog(`connect:${connectionId}`)
}

this.#webSocketClients.addClient(webSocketClient, request, connectionId)
this.#webSocketClients.addClient(webSocketClient, connectionId)
})
}

Expand All @@ -63,7 +96,7 @@ export default class WebSocketServer {
stop() {}

addRoute(functionKey, webSocketEvent) {
this.#webSocketClients.addRoute(functionKey, webSocketEvent.route)
this.#webSocketClients.addRoute(functionKey, webSocketEvent)
// serverlessLog(`route '${route}'`)
}
}
24 changes: 24 additions & 0 deletions tests/integration/_testHelpers/websocketPromise.js
@@ -0,0 +1,24 @@
const websocketSend = (ws, data) =>
new Promise((res) => {
ws.on('open', () => {
ws.send(data, (e) => {
if (e) {
res({ err: e })
}
})
})
ws.on('close', (c) => {
res({ code: c })
})
ws.on('message', (d) => {
res({ data: d })
})
ws.on('error', (e) => {
res({ err: e })
})
setTimeout(() => {
res({})
}, 5000)
})

export default websocketSend
19 changes: 19 additions & 0 deletions tests/integration/websocket-oneway/handler.js
@@ -0,0 +1,19 @@
'use strict'

exports.handler = async (event) => {
const { body, requestContext } = event

if (
body &&
JSON.parse(body).throwError &&
requestContext &&
requestContext.routeKey === '$default'
) {
throw new Error('Throwing error from incoming message')
}

return {
statusCode: 200,
body: body || undefined,
}
}
26 changes: 26 additions & 0 deletions tests/integration/websocket-oneway/serverless.yml
@@ -0,0 +1,26 @@
service: oneway-websocket-tests

plugins:
- ../../../

provider:
memorySize: 128
name: aws
region: us-east-1 # default
runtime: nodejs12.x
stage: dev
versionFunctions: false

functions:
handler:
handler: handler.handler
events:
- http:
path: echo
method: get
- websocket:
route: $connect
- websocket:
route: $disconnect
- websocket:
route: $default
55 changes: 55 additions & 0 deletions tests/integration/websocket-oneway/websocket-oneway.test.js
@@ -0,0 +1,55 @@
import { resolve } from 'path'
import WebSocket from 'ws'
import { joinUrl, setup, teardown } from '../_testHelpers/index.js'
import websocketSend from '../_testHelpers/websocketPromise.js'

jest.setTimeout(30000)

describe('one way websocket tests', () => {
// init
beforeAll(() =>
setup({
servicePath: resolve(__dirname),
}),
)

// cleanup
afterAll(() => teardown())

test('websocket echos nothing', async () => {
const url = new URL(joinUrl(TEST_BASE_URL, '/dev'))
url.port = url.port ? '3001' : url.port
url.protocol = 'ws'

const payload = JSON.stringify({
hello: 'world',
now: new Date().toISOString(),
})

const ws = new WebSocket(url.toString())
const { data, code, err } = await websocketSend(ws, payload)

expect(code).toBeUndefined()
expect(err).toBeUndefined()
expect(data).toBeUndefined()
})

test('execution error emits Internal Server Error', async () => {
const url = new URL(joinUrl(TEST_BASE_URL, '/dev'))
url.port = url.port ? '3001' : url.port
url.protocol = 'ws'

const payload = JSON.stringify({
hello: 'world',
now: new Date().toISOString(),
throwError: true,
})

const ws = new WebSocket(url.toString())
const { data, code, err } = await websocketSend(ws, payload)

expect(code).toBeUndefined()
expect(err).toBeUndefined()
expect(JSON.parse(data).message).toEqual('Internal server error')
})
})
32 changes: 32 additions & 0 deletions tests/integration/websocket-twoway/handler.js
@@ -0,0 +1,32 @@
'use strict'

exports.handler = async (event) => {
const { body, queryStringParameters, requestContext } = event
const statusCode =
queryStringParameters && queryStringParameters.statusCode
? Number(queryStringParameters.statusCode)
: 200

if (
queryStringParameters &&
queryStringParameters.throwError &&
requestContext &&
requestContext.routeKey === '$connect'
) {
throw new Error('Throwing error during connect phase')
}

if (
body &&
JSON.parse(body).throwError &&
requestContext &&
requestContext.routeKey === '$default'
) {
throw new Error('Throwing error from incoming message')
}

return {
statusCode,
body: body || undefined,
}
}
28 changes: 28 additions & 0 deletions tests/integration/websocket-twoway/serverless.yml
@@ -0,0 +1,28 @@
service: twoway-websocket-tests

plugins:
- ../../../

provider:
memorySize: 128
name: aws
region: us-east-1 # default
runtime: nodejs12.x
stage: dev
versionFunctions: false

functions:
handler:
handler: handler.handler
events:
- http:
path: echo
method: get
- websocket:
route: $connect
- websocket:
route: $disconnect
- websocket:
route: $default
# Enable 2-way comms
routeResponseSelectionExpression: $default

0 comments on commit 8f02226

Please sign in to comment.