Skip to content

Commit

Permalink
Fix NIOLoopBound issues (#3081)
Browse files Browse the repository at this point in the history
* Add test to demo crash caused by incorrect event loop

* Fix crash

* Don't need our own ELG

* Add crash for streaming off event loop

* Ensure body streaming is done on the event loop

* Work around issues

* Much better event loop guarantees

* Fix compilation on older Swift versions
  • Loading branch information
0xTim committed Oct 6, 2023
1 parent c17d9b9 commit e38dfe4
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 13 deletions.
12 changes: 12 additions & 0 deletions Sources/Vapor/HTTP/BodyStream.swift
Expand Up @@ -46,6 +46,18 @@ public protocol BodyStreamWriter: Sendable {

extension BodyStreamWriter {
public func write(_ result: BodyStreamResult) -> EventLoopFuture<Void> {
// We need to ensure we're on the event loop here for write as there's
// no guarantee that users will be on the event loop
if self.eventLoop.inEventLoop {
return write0(result)
} else {
return self.eventLoop.flatSubmit {
self.write0(result)
}
}
}

private func write0(_ result: BodyStreamResult) -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.write(result, promise: promise)
return promise.futureResult
Expand Down
3 changes: 2 additions & 1 deletion Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Expand Up @@ -18,7 +18,8 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let box = NIOLoopBound((context, self), eventLoop: context.eventLoop)
let request = self.unwrapInboundIn(data)
self.responder.respond(to: request).whenComplete { response in
// hop(to:) is required here to ensure we're on the correct event loop
self.responder.respond(to: request).hop(to: context.eventLoop).whenComplete { response in
let (context, handler) = box.value
handler.serialize(response, for: request, context: context)
}
Expand Down
129 changes: 117 additions & 12 deletions Tests/VaporTests/PipelineTests.swift
Expand Up @@ -5,10 +5,18 @@ import NIOEmbedded
import NIOCore

final class PipelineTests: XCTestCase {
var app: Application!

override func setUp() async throws {
app = Application(.testing)
}

override func tearDown() async throws {
app.shutdown()
}


func testEchoHandlers() throws {
let app = Application(.testing)
defer { app.shutdown() }

app.on(.POST, "echo", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
request.body.drain { body in
Expand Down Expand Up @@ -59,9 +67,6 @@ final class PipelineTests: XCTestCase {
}

func testEOFFraming() throws {
let app = Application(.testing)
defer { app.shutdown() }

app.on(.POST, "echo", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
request.body.drain { body in
Expand Down Expand Up @@ -89,9 +94,6 @@ final class PipelineTests: XCTestCase {
}

func testBadStreamLength() throws {
let app = Application(.testing)
defer { app.shutdown() }

app.on(.POST, "echo", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
writer.write(.buffer(.init(string: "a")), promise: nil)
Expand All @@ -117,9 +119,6 @@ final class PipelineTests: XCTestCase {
}

func testInvalidHttp() throws {
let app = Application(.testing)
defer { app.shutdown() }

let channel = EmbeddedChannel()
try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait()
try channel.pipeline.addVaporHTTP1Handlers(
Expand All @@ -140,6 +139,112 @@ final class PipelineTests: XCTestCase {
XCTAssertEqual(channel.isActive, false)
try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string)
}

func testReturningResponseOnDifferentEventLoopDosentCrashLoopBoundBox() async throws {
struct ResponseThing: ResponseEncodable {
let eventLoop: EventLoop

func encodeResponse(for request: Vapor.Request) -> NIOCore.EventLoopFuture<Vapor.Response> {
let response = Response(status: .ok)
return eventLoop.future(response)
}
}

let eventLoop = app!.eventLoopGroup.next()
app.get("dont-crash") { req in
return ResponseThing(eventLoop: eventLoop)
}

try app.test(.GET, "dont-crash") { res in
XCTAssertEqual(res.status, .ok)
}

app.environment.arguments = ["serve"]
app.http.server.configuration.port = 0
try app.start()

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

let res = try await app.client.get("http://localhost:\(port)/dont-crash")
XCTAssertEqual(res.status, .ok)
}

func testReturningResponseFromMiddlewareOnDifferentEventLoopDosentCrashLoopBoundBox() async throws {
struct WrongEventLoopMiddleware: Middleware {
func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture<Response> {
next.respond(to: request).hop(to: request.application.eventLoopGroup.next())
}
}

app.grouped(WrongEventLoopMiddleware()).get("dont-crash") { req in
return "OK"
}

try app.test(.GET, "dont-crash") { res in
XCTAssertEqual(res.status, .ok)
}

app.environment.arguments = ["serve"]
app.http.server.configuration.port = 0
try app.start()

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

let res = try await app.client.get("http://localhost:\(port)/dont-crash")
XCTAssertEqual(res.status, .ok)
}

func testStreamingOffEventLoop() async throws {
let eventLoop = app.eventLoopGroup.next()
app.on(.POST, "stream", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
request.body.drain { body in
switch body {
case .buffer(let buffer):
return writer.write(.buffer(buffer)).hop(to: eventLoop)
case .error(let error):
return writer.write(.error(error)).hop(to: eventLoop)
case .end:
return writer.write(.end).hop(to: eventLoop)
}
}
}))
}

app.environment.arguments = ["serve"]
app.http.server.configuration.port = 0
try app.start()

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

struct ABody: Content {
let hello: String

init() {
self.hello = "hello"
}
}

let res = try await app.client.post("http://localhost:\(port)/stream", beforeSend: {
try $0.content.encode(ABody())
})
XCTAssertEqual(res.status, .ok)
}

override class func setUp() {
XCTAssert(isLoggingConfigured)
Expand Down

0 comments on commit e38dfe4

Please sign in to comment.