Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Ensure websocket parity with API Gateway #1301

Merged
merged 5 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 56 additions & 31 deletions src/events/websocket/WebSocketClients.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,43 @@ export default class WebSocketClients {
clearTimeout(timeoutId)
}

async _processEvent(websocketClient, connectionId, route, event) {
let functionKey = this.#webSocketRoutes.get(route)
async verifyClient(connectionId, request) {
const route = this.#webSocketRoutes.get('$connect')
if (!route) {
return { verified: false, statusCode: 502 }
}

const connectEvent = new WebSocketConnectEvent(
connectionId,
request,
this.#options,
).create()

const lambdaFunction = this.#lambda.get(route.functionKey)
lambdaFunction.setEvent(connectEvent)

try {
const { statusCode } = await lambdaFunction.runHandler()
const verified = statusCode >= 200 && statusCode < 300
return { verified, statusCode }
} catch (err) {
if (this.log) {
this.log.debug(`Error in route handler '${route.functionKey}'`, err)
} else {
debugLog(`Error in route handler '${route.functionKey}'`, err)
}
return { verified: false, statusCode: 502 }
}
}

async _processEvent(websocketClient, connectionId, routeKey, event) {
let route = this.#webSocketRoutes.get(routeKey)

if (!functionKey && route !== '$connect' && route !== '$disconnect') {
functionKey = this.#webSocketRoutes.get('$default')
if (!route && routeKey !== '$disconnect') {
route = this.#webSocketRoutes.get('$default')
}

if (!functionKey) {
if (!route) {
return
}

Expand All @@ -123,28 +152,29 @@ export default class WebSocketClients {
)
}

// mimic AWS behaviour (close connection) when the $connect route handler throws
if (route === '$connect') {
websocketClient.close()
}

if (this.log) {
this.log.debug(`Error in route handler '${functionKey}'`, err)
this.log.debug(`Error in route handler '${route.functionKey}'`, err)
} else {
debugLog(`Error in route handler '${functionKey}'`, err)
debugLog(`Error in route handler '${route.functionKey}'`, err)
}
}

const lambdaFunction = this.#lambda.get(functionKey)
const lambdaFunction = this.#lambda.get(route.functionKey)

lambdaFunction.setEvent(event)

// let result

try {
/* result = */ await lambdaFunction.runHandler()

// TODO what to do with "result"?
const { body } = await lambdaFunction.runHandler()
if (
body &&
routeKey !== '$disconnect' &&
route.definition.routeResponseSelectionExpression === '$default'
) {
// https://docs.aws.amazon.com/apigateway/latest/developerguide/apigateway-websocket-api-selection-expressions.html#apigateway-websocket-api-route-response-selection-expressions
// TODO: Once API gateway supports RouteResponses, this will need to change to support that functionality
// For now, send body back to the client
this.send(connectionId, body)
}
} catch (err) {
if (this.log) {
this.log.error(err)
Expand Down Expand Up @@ -176,17 +206,9 @@ export default class WebSocketClients {
return route || DEFAULT_WEBSOCKETS_ROUTE
}

addClient(webSocketClient, request, connectionId) {
addClient(webSocketClient, connectionId) {
this._addWebSocketClient(webSocketClient, connectionId)

const connectEvent = new WebSocketConnectEvent(
connectionId,
request,
this.#options,
).create()

this._processEvent(webSocketClient, connectionId, '$connect', connectEvent)

webSocketClient.on('close', () => {
if (this.log) {
this.log.debug(`disconnect:${connectionId}`)
Expand Down Expand Up @@ -233,14 +255,17 @@ export default class WebSocketClients {
})
}

addRoute(functionKey, route) {
addRoute(functionKey, definition) {
// set the route name
this.#webSocketRoutes.set(route, functionKey)
this.#webSocketRoutes.set(definition.route, {
functionKey,
definition,
})

if (this.log) {
this.log.notice(`route '${route}'`)
this.log.notice(`route '${definition}'`)
} else {
serverlessLog(`route '${route}'`)
serverlessLog(`route '${definition}'`)
}
}

Expand Down
39 changes: 36 additions & 3 deletions src/events/websocket/WebSocketServer.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { createUniqueId } from '../../utils/index.js'
export default class WebSocketServer {
#options = null
#webSocketClients = null
#connectionIds = new Map()

constructor(options, webSocketClients, sharedServer, v3Utils) {
this.#options = options
Expand All @@ -20,6 +21,35 @@ export default class WebSocketServer {

const server = new Server({
server: sharedServer,
verifyClient: ({ req }, cb) => {
const connectionId = createUniqueId()
const { headers } = req
const key = headers['sec-websocket-key']

if (this.log) {
this.log.debug(`verifyClient:${key} ${connectionId}`)
} else {
debugLog(`verifyClient:${key} ${connectionId}`)
}

// Use the websocket key to coorelate connection IDs
this.#connectionIds[key] = connectionId

this.#webSocketClients
.verifyClient(connectionId, req)
.then(({ verified, statusCode }) => {
try {
if (!verified) {
cb(false, statusCode)
return
}
cb(true)
} catch (e) {
debugLog(`Error verifying`, e)
cb(false)
}
})
},
})

server.on('connection', (webSocketClient, request) => {
Expand All @@ -29,15 +59,18 @@ export default class WebSocketServer {
console.log('received connection')
}

const connectionId = createUniqueId()
const { headers } = request
const key = headers['sec-websocket-key']

const connectionId = this.#connectionIds[key]

if (this.log) {
this.log.debug(`connect:${connectionId}`)
} else {
debugLog(`connect:${connectionId}`)
}

this.#webSocketClients.addClient(webSocketClient, request, connectionId)
this.#webSocketClients.addClient(webSocketClient, connectionId)
})
}

Expand All @@ -63,7 +96,7 @@ export default class WebSocketServer {
stop() {}

addRoute(functionKey, webSocketEvent) {
this.#webSocketClients.addRoute(functionKey, webSocketEvent.route)
this.#webSocketClients.addRoute(functionKey, webSocketEvent)
// serverlessLog(`route '${route}'`)
}
}