Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Middleware Async #3108

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 31 additions & 30 deletions Sources/Vapor/Middleware/CORSMiddleware.swift
Expand Up @@ -6,7 +6,7 @@ import NIOCore
///
/// - note: Make sure this middleware is inserted before all your error/abort middlewares,
/// so that even the failed request responses contain proper CORS information.
public final class CORSMiddleware: Middleware {
public final class CORSMiddleware: AsyncMiddleware {
/// Option for the allow origin header in responses for CORS requests.
///
/// - none: Disallows any origin.
Expand Down Expand Up @@ -122,45 +122,46 @@ public final class CORSMiddleware: Middleware {
public init(configuration: Configuration = .default()) {
self.configuration = configuration
}

public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture<Response> {
public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
// Check if it's valid CORS request
guard request.headers[.origin].first != nil else {
return next.respond(to: request)
return try await next.respond(to: request)
}

// Determine if the request is pre-flight.
// If it is, create empty response otherwise get response from the responder chain.
let response = request.isPreflight
? request.eventLoop.makeSucceededFuture(.init())
: next.respond(to: request)
let response: Response
if request.isPreflight {
response = .init()
} else {
response = try await next.respond(to: request)
}

return response.map { response in
// Modify response headers based on CORS settings
let originBasedAccessControlAllowHeader = self.configuration.allowedOrigin.header(forRequest: request)
response.responseBox.withLockedValue { box in
box.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: originBasedAccessControlAllowHeader)
box.headers.replaceOrAdd(name: .accessControlAllowHeaders, value: self.configuration.allowedHeaders)
box.headers.replaceOrAdd(name: .accessControlAllowMethods, value: self.configuration.allowedMethods)

if let exposedHeaders = self.configuration.exposedHeaders {
box.headers.replaceOrAdd(name: .accessControlExpose, value: exposedHeaders)
}

if let cacheExpiration = self.configuration.cacheExpiration {
box.headers.replaceOrAdd(name: .accessControlMaxAge, value: String(cacheExpiration))
}

if self.configuration.allowCredentials {
box.headers.replaceOrAdd(name: .accessControlAllowCredentials, value: "true")
}
// Modify response headers based on CORS settings
let originBasedAccessControlAllowHeader = self.configuration.allowedOrigin.header(forRequest: request)
response.responseBox.withLockedValue { box in
box.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: originBasedAccessControlAllowHeader)
box.headers.replaceOrAdd(name: .accessControlAllowHeaders, value: self.configuration.allowedHeaders)
box.headers.replaceOrAdd(name: .accessControlAllowMethods, value: self.configuration.allowedMethods)

if let exposedHeaders = self.configuration.exposedHeaders {
box.headers.replaceOrAdd(name: .accessControlExpose, value: exposedHeaders)
}

if let cacheExpiration = self.configuration.cacheExpiration {
box.headers.replaceOrAdd(name: .accessControlMaxAge, value: String(cacheExpiration))
}

if self.configuration.allowCredentials {
box.headers.replaceOrAdd(name: .accessControlAllowCredentials, value: "true")
}

if case .originBased = self.configuration.allowedOrigin, !originBasedAccessControlAllowHeader.isEmpty {
box.headers.add(name: .vary, value: "origin")
}
if case .originBased = self.configuration.allowedOrigin, !originBasedAccessControlAllowHeader.isEmpty {
box.headers.add(name: .vary, value: "origin")
}
return response
}
return response
}
}

Expand Down
8 changes: 5 additions & 3 deletions Sources/Vapor/Middleware/ErrorMiddleware.swift
Expand Up @@ -3,7 +3,7 @@ import NIOCore
import NIOHTTP1

/// Captures all errors and transforms them into an internal server error HTTP response.
public final class ErrorMiddleware: Middleware {
public final class ErrorMiddleware: AsyncMiddleware {
/// Structure of `ErrorMiddleware` default response.
internal struct ErrorResponse: Codable {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you are at it, is it possible to make ErrorResponse public?

In my project I use https://github.com/mattpolzin/VaporOpenAPI to generate the openApi.yml file from the server code and with this in turn the client facing structs - the server is the single source of truth for the models.

Right now I need to duplicate this struct and make it conform to Sampleable & OpenAPIExampleProvider for this to work, which is not ideal.

I also tried writing my own ErrorMiddleware with a public ErrorResponse struct but this failed as I couldn't use response.responseBox as its marked internal and can only be used from the vapor module.

What would be the recommended approach here?

/// Always `true` to indicate this is a non-typical JSON response.
Expand Down Expand Up @@ -76,8 +76,10 @@ public final class ErrorMiddleware: Middleware {
self.closure = closure
}

public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture<Response> {
return next.respond(to: request).flatMapErrorThrowing { error in
public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
do {
return try await next.respond(to: request)
} catch {
return self.closure(request, error)
}
}
Expand Down
23 changes: 10 additions & 13 deletions Sources/Vapor/Middleware/FileMiddleware.swift
Expand Up @@ -4,7 +4,7 @@ import NIOCore
/// Serves static files from a public directory.
///
/// `FileMiddleware` will default to `DirectoryConfig`'s working directory with `"/Public"` appended.
public final class FileMiddleware: Middleware {
public final class FileMiddleware: AsyncMiddleware {
/// The public directory. Guaranteed to end with a slash.
private let publicDirectory: String
private let defaultFile: String?
Expand Down Expand Up @@ -35,18 +35,18 @@ public final class FileMiddleware: Middleware {
self.directoryAction = directoryAction
}

public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture<Response> {
public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
// make a copy of the percent-decoded path
guard var path = request.url.path.removingPercentEncoding else {
return request.eventLoop.makeFailedFuture(Abort(.badRequest))
throw Abort(.badRequest)
}

// path must be relative.
path = path.removeLeadingSlashes()

// protect against relative paths
guard !path.contains("../") else {
return request.eventLoop.makeFailedFuture(Abort(.forbidden))
throw Abort(.forbidden)
}

// create absolute path
Expand All @@ -55,7 +55,7 @@ public final class FileMiddleware: Middleware {
// check if path exists and whether it is a directory
var isDir: ObjCBool = false
guard FileManager.default.fileExists(atPath: absPath, isDirectory: &isDir) else {
return next.respond(to: request)
return try await next.respond(to: request)
}

if isDir.boolValue {
Expand All @@ -64,17 +64,15 @@ public final class FileMiddleware: Middleware {
case .redirect:
var redirectUrl = request.url
redirectUrl.path += "/"
return request.eventLoop.future(
request.redirect(to: redirectUrl.string, redirectType: .permanent)
)
return request.redirect(to: redirectUrl.string, redirectType: .permanent)
case .none:
return next.respond(to: request)
return try await next.respond(to: request)
}
}

// If a directory, check for the default file
guard let defaultFile = defaultFile else {
return next.respond(to: request)
return try await next.respond(to: request)
}

if defaultFile.isAbsolute() {
Expand All @@ -85,13 +83,12 @@ public final class FileMiddleware: Middleware {

// If the default file doesn't exist, pass on request
guard FileManager.default.fileExists(atPath: absPath) else {
return next.respond(to: request)
return try await next.respond(to: request)
}
}

// stream the file
let res = request.fileio.streamFile(at: absPath)
return request.eventLoop.makeSucceededFuture(res)
return request.fileio.streamFile(at: absPath)
}

/// Creates a new `FileMiddleware` for a server contained in an Xcode Project.
Expand Down
6 changes: 3 additions & 3 deletions Sources/Vapor/Middleware/RouteLoggingMiddleware.swift
Expand Up @@ -3,15 +3,15 @@ import Logging

/// Emits a log message containing the request method and path to a `Request`'s logger.
/// The log level of the message is configurable.
public final class RouteLoggingMiddleware: Middleware {
public final class RouteLoggingMiddleware: AsyncMiddleware {
public let logLevel: Logger.Level

public init(logLevel: Logger.Level = .info) {
self.logLevel = logLevel
}

public func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture<Response> {
public func respond(to request: Request, chainingTo next: AsyncResponder) async throws -> Response {
request.logger.log(level: self.logLevel, "\(request.method) \(request.url.path.removingPercentEncoding ?? request.url.path)")
return next.respond(to: request)
return try await next.respond(to: request)
}
}