Skip to content

Commit

Permalink
Add JWKSCache helper for storing JWKS (#112)
Browse files Browse the repository at this point in the history
* Added an atomic and thread-safe method to download JWKS

* Renamed AtomicJwks to AtomicJWKS

* Style changes requested by Tanner

* Fixed a unit test

* Renamed AtomicJWKS to JWKSCache

* Updates the package dependency versions

* renamed getKeys(on:) to keys(on:)

* refactor JWKSCache for atomicity (#113)

* refactor JWKSCache for atomicity

* add missing el hop

* sync access to clearing current request

* fix @grosch comments

Co-authored-by: grosch <scott.grosch@icloud.com>
  • Loading branch information
tanner0101 and grosch committed Feb 14, 2020
1 parent 9e121e0 commit 0ecb6fa
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ let package = Package(
.library(name: "JWT", targets: ["JWT"]),
],
dependencies: [
.package(url: "https://github.com/vapor/jwt-kit.git", from: "4.0.0-beta.2"),
.package(url: "https://github.com/vapor/vapor.git", from: "4.0.0-beta.2"),
.package(url: "https://github.com/vapor/jwt-kit.git", from: "4.0.0-beta.2.3"),
.package(url: "https://github.com/vapor/vapor.git", from: "4.0.0-beta.3.17"),
],
targets: [
.target(name: "JWT", dependencies: ["JWTKit", "Vapor"]),
Expand Down
140 changes: 140 additions & 0 deletions Sources/JWT/JWKSCache.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import Vapor

/// A thread-safe and atomic class for retrieving JSON Web Key Sets which honors the
/// HTTP `Cache-Control`, `Expires` and `Etag` headers.
public final class JWKSCache {
public enum Error: Swift.Error {
case missingCache
case unexpctedResponseStatus(HTTPStatus, uri: URI)
}
private let uri: URI
private let client: Client
private let sync: Lock

struct CachedJWKS {
var cacheUntil: Date
var jwks: JWKS
}

private var cachedETag: String?
private var cachedJWKS: CachedJWKS?
private var currentRequest: EventLoopFuture<JWKS>?

/// Creates a new `JWKSCache`.
/// - Parameters:
/// - keyURL: The URL to the JWKS data.
/// - application: The Vapor `Application`.
public init(keyURL: String, client: Client) {
self.uri = URI(string: keyURL)
self.client = client
self.sync = .init()
}

/// Downloads the JSON Web Key Set, taking into account `Cache-Control`, `Expires` and `Etag` headers..
/// - Parameters:
/// - req: The Vapor `Request` object.
public func keys(on request: Request) -> EventLoopFuture<JWKS> {
self.keys(logger: request.logger, on: request.eventLoop)
}

/// Downloads the JSON Web Key Set, taking into account `Cache-Control`, `Expires` and `Etag` headers..
/// - Parameters:
/// - logger: For logging debug messages.
/// - eventLoop: Event loop to be called back on.
public func keys(logger: Logger, on eventLoop: EventLoop) -> EventLoopFuture<JWKS> {
// Synchronize access to shared state.
self.sync.lock()
defer { self.sync.unlock() }

// Check if we have cached keys that are still valid.
if let cachedJWKS = self.cachedJWKS, Date() < cachedJWKS.cacheUntil {
return eventLoop.makeSucceededFuture(cachedJWKS.jwks)
}

// Check if there is already a request happening
// to fetch keys.
if let keys = self.currentRequest {
// The current key request may be happening on a
// different event loop.
return keys.hop(to: eventLoop)
}

// Create a new key request and store it.
logger.debug("Requesting JWKS from \(self.uri).")
let keys = self.requestKeys(logger: logger)
self.currentRequest = keys

// Once the key request finishes, clear the current
// request and return the keys.
return keys.map { keys in
// Synchronize access to shared state.
self.sync.lock()
defer { self.sync.unlock() }
self.currentRequest = nil
return keys
}.hop(to: eventLoop)
}

private func requestKeys(logger: Logger) -> EventLoopFuture<JWKS> {
// Add cached eTag header to this request if we have it.
var headers: HTTPHeaders = [:]
if let eTag = self.cachedETag {
headers.add(name: .ifNoneMatch, value: eTag)
}

// Store the requested-at date to calculate expiration date.
let requestSentAt = Date()

// Send the GET request for the JWKs.
return self.client.get(
self.uri, headers: headers
).flatMapThrowing { response in
// Synchronize access to shared state.
self.sync.lock()
defer { self.sync.unlock() }

let expirationDate = response.headers.expirationDate(requestSentAt: requestSentAt)
self.cachedETag = response.headers.firstValue(name: .eTag)
switch response.status {
case .notModified:
// The cached JWKS are still the latest version.
logger.debug("Cached JWKS are still valid.")
guard var cachedJWKS = self.cachedJWKS else {
throw Error.missingCache
}

// Update the JWKS cache if there is an expiration date.
if let expirationDate = expirationDate {
// Update the cache metadata.
cachedJWKS.cacheUntil = expirationDate
self.cachedJWKS = cachedJWKS
} else {
self.cachedJWKS = nil
}
return cachedJWKS.jwks
case .ok:
// New JWKS have been returned.
logger.debug("New JWKS have been returned.")
let jwks = try response.content.decode(JWKS.self)

// Cache the JWKS if there is an expiration date.
if let expirationDate = expirationDate {
if var cachedJWKS = self.cachedJWKS {
// Update the existing cache.
cachedJWKS.cacheUntil = expirationDate
cachedJWKS.jwks = jwks
self.cachedJWKS = cachedJWKS
} else {
// Create a new cache.
self.cachedJWKS = .init(cacheUntil: expirationDate, jwks: jwks)
}
} else {
self.cachedJWKS = nil
}
return jwks
default:
throw Error.unexpctedResponseStatus(response.status, uri: self.uri)
}
}
}
}
38 changes: 38 additions & 0 deletions Tests/JWTTests/JWTTests.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import JWT
import JWTKit
import XCTVapor

class JWTKitTests: XCTestCase {
Expand Down Expand Up @@ -118,8 +119,45 @@ class JWTKitTests: XCTestCase {
}
}

func testJWKSDownload() throws {
// creates a new application for testing
let app = Application(.testing)
defer { app.shutdown() }

app.client.configuration.ignoreUncleanSSLShutdown = true

let google = JWKSCache(
keyURL: "https://www.googleapis.com/oauth2/v3/certs",
client: app.client
)

app.get("keys") { req in
google.keys(on: req).map { jwks in
jwks.keys.count
}
}

try app.test(.GET, "keys") { res in
XCTAssertEqual(res.status, .ok)
XCTAssertEqual(res.body.string, "2")
}
}

override func setUp() {
XCTAssert(isLoggingConfigured)
}
}

let isLoggingConfigured: Bool = {
LoggingSystem.bootstrap { label in
var handler = StreamLogHandler.standardOutput(label: label)
handler.logLevel = .debug
return handler
}
return true
}()


struct LoginResponse: Content {
var token: String
}
Expand Down

0 comments on commit 0ecb6fa

Please sign in to comment.