Skip to content

Commit

Permalink
Sendable Response (#3082)
Browse files Browse the repository at this point in the history
* Make Response Sendable

* Fix a warning

* Add preconcurrency annotations

* Remove a lock

* Bump ConsoleKit to Sendable version

* Remove preconcurrency imports for ConsoleKit

* Move all the response stuff into it's own box

* Fix up some usages

* Bypass multiple lock accesses for internal types

* Bypass multiple lock accesses for internal types

* Set the content length correctly

* Update last use of response body

* Migrate double lock access to single lock

* Update Sources/Vapor/Response/Response.swift

Co-authored-by: Mahdi Bahrami <github@mahdibm.com>

---------

Co-authored-by: Mahdi Bahrami <github@mahdibm.com>
  • Loading branch information
0xTim and MahdiBM committed Nov 1, 2023
1 parent 600d666 commit 3bf4e73
Show file tree
Hide file tree
Showing 17 changed files with 196 additions and 99 deletions.
2 changes: 1 addition & 1 deletion Package.swift
Expand Up @@ -22,7 +22,7 @@ let package = Package(
.package(url: "https://github.com/vapor/async-kit.git", from: "1.15.0"),

// 💻 APIs for creating interactive CLI tools.
.package(url: "https://github.com/vapor/console-kit.git", from: "4.0.0"),
.package(url: "https://github.com/vapor/console-kit.git", from: "4.9.0"),

// 🔑 Hashing (SHA2, HMAC), encryption (AES), public-key (RSA), and random data generation.
.package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "4.0.0"),
Expand Down
2 changes: 1 addition & 1 deletion Package@swift-5.9.swift
Expand Up @@ -22,7 +22,7 @@ let package = Package(
.package(url: "https://github.com/vapor/async-kit.git", from: "1.15.0"),

// 💻 APIs for creating interactive CLI tools.
.package(url: "https://github.com/vapor/console-kit.git", from: "4.0.0"),
.package(url: "https://github.com/vapor/console-kit.git", from: "4.9.0"),

// 🔑 Hashing (SHA2, HMAC), encryption (AES), public-key (RSA), and random data generation.
.package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "4.0.0"),
Expand Down
3 changes: 2 additions & 1 deletion Sources/Vapor/Commands/ServeCommand.swift
@@ -1,5 +1,6 @@
import Foundation
@preconcurrency import ConsoleKit
@preconcurrency import Dispatch
import ConsoleKit
import NIOConcurrencyHelpers

/// Boots the application's server. Listens for `SIGINT` and `SIGTERM` for graceful shutdown.
Expand Down
8 changes: 5 additions & 3 deletions Sources/Vapor/Concurrency/ResponseCodable+Concurrency.swift
Expand Up @@ -52,10 +52,12 @@ extension AsyncResponseEncodable {
/// - returns: Newly encoded `Response`.
public func encodeResponse(status: HTTPStatus, headers: HTTPHeaders = [:], for request: Request) async throws -> Response {
let response = try await self.encodeResponse(for: request)
for (name, value) in headers {
response.headers.replaceOrAdd(name: name, value: value)
response.responseBox.withLockedValue { box in
for (name, value) in headers {
box.headers.replaceOrAdd(name: name, value: value)
}
box.status = status
}
response.status = status
return response
}
}
Expand Down
6 changes: 4 additions & 2 deletions Sources/Vapor/Concurrency/ViewRenderer+Concurrency.swift
Expand Up @@ -13,8 +13,10 @@ public extension ViewRenderer {
extension View: AsyncResponseEncodable {
public func encodeResponse(for request: Request) async throws -> Response {
let response = Response()
response.headers.contentType = .html
response.body = .init(buffer: self.data)
response.responseBox.withLockedValue { box in
box.headers.contentType = .html
box.body = .init(buffer: self.data)
}
return response
}
}
2 changes: 1 addition & 1 deletion Sources/Vapor/Core/Core.swift
@@ -1,4 +1,4 @@
@preconcurrency import ConsoleKit
import ConsoleKit
import NIOCore
import NIOPosix
import NIOConcurrencyHelpers
Expand Down
2 changes: 1 addition & 1 deletion Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Expand Up @@ -31,7 +31,7 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
self.errorCaught(context: context, error: error)
case .success(let response):
if request.method == .HEAD {
response.forHeadRequest = true
response.responseBox.withLockedValue { $0.forHeadRequest = true }
}
self.serialize(response, for: request, context: context)
}
Expand Down
32 changes: 17 additions & 15 deletions Sources/Vapor/HTTP/Server/HTTPServerResponseEncoder.swift
Expand Up @@ -20,23 +20,25 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH

func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let response = self.unwrapOutboundIn(data)
// add a RFC1123 timestamp to the Date header to make this
// a valid request
response.headers.add(name: "date", value: self.dateCache.currentTimestamp())

if let server = self.serverHeader {
response.headers.add(name: "server", value: server)
var headOrNoContentRequest = false
response.responseBox.withLockedValue { box in
// add a RFC1123 timestamp to the Date header to make this
// a valid request
box.headers.add(name: "date", value: self.dateCache.currentTimestamp())
if let server = self.serverHeader {
box.headers.add(name: "server", value: server)
}

// begin serializing
let responseHead = HTTPResponseHead(version: box.version, status: box.status, headers: box.headers)
context.write(wrapOutboundOut(.head(responseHead)), promise: nil)

if box.status == .noContent || box.forHeadRequest {
headOrNoContentRequest = true
}
}

// begin serializing
context.write(wrapOutboundOut(.head(.init(
version: response.version,
status: response.status,
headers: response.headers
))), promise: nil)


if response.status == .noContent || response.forHeadRequest {
if headOrNoContentRequest {
// don't send bodies for 204 (no content) responses
// or HEAD requests
context.fireUserInboundEventTriggered(ResponseEndSentEvent())
Expand Down
8 changes: 6 additions & 2 deletions Sources/Vapor/HTTP/Server/HTTPServerUpgradeHandler.swift
Expand Up @@ -57,7 +57,10 @@ final class HTTPServerUpgradeHandler: ChannelDuplexHandler, RemovableChannelHand
switch self.upgradeState {
case .pending(let req, let buffer):
self.upgradeState = .upgraded
if res.status == .switchingProtocols, let upgrader = res.upgrader {
let (status, upgrader) = res.responseBox.withLockedValue { box in
return (box.status, box.upgrader)
}
if status == .switchingProtocols, let upgrader = upgrader {
let protocolUpgrader = upgrader.applyUpgrade(req: req, res: res)
let sendableBox = SendableBox(
context: context,
Expand Down Expand Up @@ -128,7 +131,8 @@ private final class UpgradeBufferHandler: ChannelInboundHandler, RemovableChanne
}

/// Conformance for any struct that performs an HTTP Upgrade
public protocol Upgrader {
@preconcurrency
public protocol Upgrader: Sendable {
func applyUpgrade(req: Request, res: Response) -> HTTPServerProtocolUpgrader
}

Expand Down
37 changes: 19 additions & 18 deletions Sources/Vapor/Middleware/CORSMiddleware.swift
Expand Up @@ -138,26 +138,27 @@ public final class CORSMiddleware: Middleware {
return response.map { response in
// Modify response headers based on CORS settings
let originBasedAccessControlAllowHeader = self.configuration.allowedOrigin.header(forRequest: request)
response.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: originBasedAccessControlAllowHeader)
response.headers.replaceOrAdd(name: .accessControlAllowHeaders, value: self.configuration.allowedHeaders)
response.headers.replaceOrAdd(name: .accessControlAllowMethods, value: self.configuration.allowedMethods)

if let exposedHeaders = self.configuration.exposedHeaders {
response.headers.replaceOrAdd(name: .accessControlExpose, value: exposedHeaders)
}

if let cacheExpiration = self.configuration.cacheExpiration {
response.headers.replaceOrAdd(name: .accessControlMaxAge, value: String(cacheExpiration))
}

if self.configuration.allowCredentials {
response.headers.replaceOrAdd(name: .accessControlAllowCredentials, value: "true")
}
response.responseBox.withLockedValue { box in
box.headers.replaceOrAdd(name: .accessControlAllowOrigin, value: originBasedAccessControlAllowHeader)
box.headers.replaceOrAdd(name: .accessControlAllowHeaders, value: self.configuration.allowedHeaders)
box.headers.replaceOrAdd(name: .accessControlAllowMethods, value: self.configuration.allowedMethods)

if let exposedHeaders = self.configuration.exposedHeaders {
box.headers.replaceOrAdd(name: .accessControlExpose, value: exposedHeaders)
}

if let cacheExpiration = self.configuration.cacheExpiration {
box.headers.replaceOrAdd(name: .accessControlMaxAge, value: String(cacheExpiration))
}

if self.configuration.allowCredentials {
box.headers.replaceOrAdd(name: .accessControlAllowCredentials, value: "true")
}

if case .originBased = self.configuration.allowedOrigin, !originBasedAccessControlAllowHeader.isEmpty {
response.headers.add(name: .vary, value: "origin")
if case .originBased = self.configuration.allowedOrigin, !originBasedAccessControlAllowHeader.isEmpty {
box.headers.add(name: .vary, value: "origin")
}
}

return response
}
}
Expand Down
12 changes: 8 additions & 4 deletions Sources/Vapor/Middleware/ErrorMiddleware.swift
Expand Up @@ -51,11 +51,15 @@ public final class ErrorMiddleware: Middleware {
// attempt to serialize the error to json
do {
let errorResponse = ErrorResponse(error: true, reason: reason)
response.body = try .init(data: JSONEncoder().encode(errorResponse), byteBufferAllocator: req.byteBufferAllocator)
response.headers.replaceOrAdd(name: .contentType, value: "application/json; charset=utf-8")
try response.responseBox.withLockedValue { box in
box.body = try .init(data: JSONEncoder().encode(errorResponse), byteBufferAllocator: req.byteBufferAllocator)
box.headers.replaceOrAdd(name: .contentType, value: "application/json; charset=utf-8")
}
} catch {
response.body = .init(string: "Oops: \(error)", byteBufferAllocator: req.byteBufferAllocator)
response.headers.replaceOrAdd(name: .contentType, value: "text/plain; charset=utf-8")
response.responseBox.withLockedValue { box in
box.body = .init(string: "Oops: \(error)", byteBufferAllocator: req.byteBufferAllocator)
box.headers.replaceOrAdd(name: .contentType, value: "text/plain; charset=utf-8")
}
}
return response
}
Expand Down
12 changes: 8 additions & 4 deletions Sources/Vapor/Request/Redirect.swift
Expand Up @@ -14,8 +14,10 @@ extension Request {
@available(*, deprecated, renamed: "redirect(to:redirectType:)")
public func redirect(to location: String, type: RedirectType) -> Response {
let response = Response()
response.status = type.status
response.headers.replaceOrAdd(name: .location, value: location)
response.responseBox.withLockedValue { box in
box.status = type.status
box.headers.replaceOrAdd(name: .location, value: location)
}
return response
}

Expand All @@ -33,8 +35,10 @@ extension Request {
/// - Returns: A response that redirects the client to the specified location
public func redirect(to location: String, redirectType: Redirect = .normal) -> Response {
let response = Response()
response.status = redirectType.status
response.headers.replaceOrAdd(name: .location, value: location)
response.responseBox.withLockedValue { box in
box.status = redirectType.status
box.headers.replaceOrAdd(name: .location, value: location)
}
return response
}
}
Expand Down
15 changes: 9 additions & 6 deletions Sources/Vapor/Response/Response+Body.swift
@@ -1,21 +1,22 @@
@preconcurrency import Dispatch
import Foundation
import NIOCore
import NIOConcurrencyHelpers

extension Response {
struct BodyStream {
struct BodyStream: Sendable {
let count: Int
let callback: (BodyStreamWriter) -> ()
let callback: @Sendable (BodyStreamWriter) -> ()
}

/// Represents a `Response`'s body.
///
/// let body = Response.Body(string: "Hello, world!")
///
/// This can contain any data (streaming or static) and should match the message's `"Content-Type"` header.
public struct Body: CustomStringConvertible, ExpressibleByStringLiteral {
public struct Body: CustomStringConvertible, ExpressibleByStringLiteral, Sendable {
/// The internal HTTP body storage enum. This is an implementation detail.
internal enum Storage {
internal enum Storage: Sendable {
/// Cases
case none
case buffer(ByteBuffer)
Expand Down Expand Up @@ -151,12 +152,14 @@ extension Response {
self.storage = .buffer(buffer)
}

public init(stream: @escaping (BodyStreamWriter) -> (), count: Int, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) {
@preconcurrency
public init(stream: @Sendable @escaping (BodyStreamWriter) -> (), count: Int, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) {
self.byteBufferAllocator = byteBufferAllocator
self.storage = .stream(.init(count: count, callback: stream))
}

public init(stream: @escaping (BodyStreamWriter) -> (), byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) {
@preconcurrency
public init(stream: @Sendable @escaping (BodyStreamWriter) -> (), byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) {
self.init(stream: stream, count: -1, byteBufferAllocator: byteBufferAllocator)
}

Expand Down

0 comments on commit 3bf4e73

Please sign in to comment.