Skip to content

Commit

Permalink
Make Async Request Body actually work (#3096)
Browse files Browse the repository at this point in the history
* Add an actual failing test for async request body

* Fix event loop crash in async body

* Fix logic bug with backpressure in async request body

* Add backpressure test

* Update Tests/AsyncTests/AsyncRequestTests.swift

Co-authored-by: Gwynne Raskind <gwynne@vapor.codes>

* Try working out failing test

* Tell the delegate we stopped

* Add test to ensure we clean up correctly

* Disable dodgy test

---------

Co-authored-by: Gwynne Raskind <gwynne@vapor.codes>
  • Loading branch information
0xTim and gwynne committed Nov 7, 2023
1 parent 1d075c8 commit d682e05
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 21 deletions.
9 changes: 9 additions & 0 deletions NOTICES.txt
Expand Up @@ -19,3 +19,12 @@ from Swift Metrics.
* https://www.apache.org/licenses/LICENSE-2.0
* HOMEPAGE:
* https://github.com/apple/swift-metrics

This product contains an implementation of AsyncLazySequence taken from Async
HTTP Client

* LICENSE (Apache License 2.0):
* https://www.apache.org/licenses/LICENSE-2.0
* HOMEPAGE:
* https://github.com/swift-server/async-http-client

@@ -1,4 +1,3 @@
#if compiler(>=5.7)
import NIOCore
import NIOConcurrencyHelpers

Expand All @@ -13,11 +12,12 @@ extension Request.Body {
/// in `Request.Body/makeAsyncIterator()` method.
fileprivate final class AsyncSequenceDelegate: @unchecked Sendable, NIOAsyncSequenceProducerDelegate {
private enum State {
case notCalledYet
case noSignalReceived
case waitingForSignalFromConsumer(EventLoopPromise<Void>)
}

private var _state: State = .noSignalReceived
private var _state: State = .notCalledYet
private let eventLoop: any EventLoop

init(eventLoop: any EventLoop) {
Expand All @@ -27,6 +27,9 @@ extension Request.Body {
private func produceMore0() {
self.eventLoop.preconditionInEventLoop()
switch self._state {
case .notCalledYet:
// We can just return here to sign to the producer that we want more data
break
case .noSignalReceived:
preconditionFailure()
case .waitingForSignalFromConsumer(let promise):
Expand All @@ -38,6 +41,9 @@ extension Request.Body {
private func didTerminate0() {
self.eventLoop.preconditionInEventLoop()
switch self._state {
case .notCalledYet:
// Means didn't hit the backpressure limits, so just return
break
case .noSignalReceived:
// we will inform the producer, since the next write will fail.
break
Expand All @@ -50,7 +56,7 @@ extension Request.Body {
func registerBackpressurePromise(_ promise: EventLoopPromise<Void>) {
self.eventLoop.preconditionInEventLoop()
switch self._state {
case .noSignalReceived:
case .noSignalReceived, .notCalledYet:
self._state = .waitingForSignalFromConsumer(promise)
case .waitingForSignalFromConsumer:
preconditionFailure()
Expand Down Expand Up @@ -140,6 +146,7 @@ extension Request.Body: AsyncSequence {
// The consumer dropped the sequence.
// Inform the producer that we don't want more data
// by returning an error in the future.
delegate.didTerminate()
return request.eventLoop.makeFailedFuture(CancellationError())
case .stopProducing:
// The consumer is too slow.
Expand All @@ -166,4 +173,3 @@ extension Request.Body: AsyncSequence {
return AsyncIterator(underlying: producer.sequence.makeAsyncIterator())
}
}
#endif
21 changes: 15 additions & 6 deletions Sources/Vapor/Request/Request+BodyStream.swift
Expand Up @@ -27,8 +27,17 @@ extension Request {
self.allocator = byteBufferAllocator
}

/// `read(_:)` **must** be called when on an `EventLoop`
func read(_ handler: @escaping (BodyStreamResult, EventLoopPromise<Void>?) -> ()) {
func read(_ handler: @escaping @Sendable (BodyStreamResult, EventLoopPromise<Void>?) -> ()) {
if self.eventLoop.inEventLoop {
read0(handler)
} else {
self.eventLoop.execute {
self.read0(handler)
}
}
}

func read0(_ handler: @escaping @Sendable (BodyStreamResult, EventLoopPromise<Void>?) -> ()) {
self.eventLoop.preconditionInEventLoop()
self.handlerBuffer.value.handler = handler
for (result, promise) in self.handlerBuffer.value.buffer {
Expand Down Expand Up @@ -72,17 +81,17 @@ extension Request {
// See https://github.com/vapor/vapor/issues/2906
return eventLoop.flatSubmit {
let promise = eventLoop.makePromise(of: ByteBuffer.self)
var data = self.allocator.buffer(capacity: 0)
let data = NIOLoopBoundBox(self.allocator.buffer(capacity: 0), eventLoop: eventLoop)
self.read { chunk, next in
switch chunk {
case .buffer(var buffer):
if let max = max, data.readableBytes + buffer.readableBytes >= max {
if let max = max, data.value.readableBytes + buffer.readableBytes >= max {
promise.fail(Abort(.payloadTooLarge))
} else {
data.writeBuffer(&buffer)
data.value.writeBuffer(&buffer)
}
case .error(let error): promise.fail(error)
case .end: promise.succeed(data)
case .end: promise.succeed(data.value)
}
next?.succeed(())
}
Expand Down
230 changes: 219 additions & 11 deletions Tests/AsyncTests/AsyncRequestTests.swift
@@ -1,8 +1,10 @@
#if compiler(>=5.7) && canImport(_Concurrency)
import XCTVapor
import XCTest
import Vapor
import NIOCore
import AsyncHTTPClient
import Atomics
import NIOConcurrencyHelpers

fileprivate extension String {
static func randomDigits(length: Int = 999) -> String {
Expand All @@ -16,9 +18,22 @@ fileprivate extension String {

final class AsyncRequestTests: XCTestCase {

func testStreamingRequest() throws {
let app = Application(.testing)
defer { app.shutdown() }
var app: Application!
var eventLoopGroup: EventLoopGroup!

override func setUp() async throws {
eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 4)
app = Application(.testing, .shared(eventLoopGroup))
}

override func tearDown() async throws {
app.shutdown()
try await eventLoopGroup.shutdownGracefully()
}

func testStreamingRequest() async throws {
app.http.server.configuration.hostname = "127.0.0.1"
app.http.server.configuration.port = 0

let testValue = String.randomDigits()

Expand All @@ -33,13 +48,206 @@ final class AsyncRequestTests: XCTestCase {
return string
}

try app.testable().test(.POST, "/stream", beforeRequest: { req in
req.body = ByteBuffer(string: testValue)
}) { res in
XCTAssertEqual(res.status, .ok)
let returnedString = try XCTUnwrap(try res.content.decode(String.self))
XCTAssertEqual(testValue, returnedString)
app.environment.arguments = ["serve"]
XCTAssertNoThrow(try app.start())

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let ip = localAddress.ipAddress,
let port = localAddress.port else {
return XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
}

var request = HTTPClientRequest(url: "http://\(ip):\(port)/stream")
request.method = .POST
request.body = .stream(testValue.utf8.async, length: .unknown)

let response: HTTPClientResponse = try await app.http.client.shared.execute(request, timeout: .seconds(5))
XCTAssertEqual(response.status, .ok)
let body = try await response.body.collect(upTo: 1024 * 1024)
XCTAssertEqual(body.string, testValue)
}

func testStreamingRequestBodyCleansUp() async throws {
app.http.server.configuration.hostname = "127.0.0.1"
app.http.server.configuration.port = 0

let bytesTheServerRead = ManagedAtomic<Int>(0)

app.on(.POST, "hello", body: .stream) { req async throws -> Response in
var bodyIterator = req.body.makeAsyncIterator()
let firstChunk = try await bodyIterator.next()
bytesTheServerRead.wrappingIncrement(by: firstChunk?.readableBytes ?? 0, ordering: .relaxed)
throw Abort(.internalServerError)
}

app.environment.arguments = ["serve"]
XCTAssertNoThrow(try app.start())

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let ip = localAddress.ipAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

var oneMBBB = ByteBuffer(repeating: 0x41, count: 1024 * 1024)
let oneMB = try XCTUnwrap(oneMBBB.readData(length: oneMBBB.readableBytes))
var request = HTTPClientRequest(url: "http://\(ip):\(port)/hello")
request.method = .POST
request.body = .stream(oneMB.async, length: .known(oneMB.count))
let response = try await app.http.client.shared.execute(request, timeout: .seconds(5))

XCTAssertGreaterThan(bytesTheServerRead.load(ordering: .relaxed), 0)
XCTAssertEqual(response.status, .internalServerError)
}

// TODO: Re-enable once it reliably works and doesn't cause issues with trying to shut the application down
// This may require some work in Vapor
func _testRequestBodyBackpressureWorksWithAsyncStreaming() async throws {
app.http.server.configuration.hostname = "127.0.0.1"
app.http.server.configuration.port = 0

let numberOfTimesTheServerGotOfferedBytes = ManagedAtomic<Int>(0)
let bytesTheServerSaw = ManagedAtomic<Int>(0)
let bytesTheClientSent = ManagedAtomic<Int>(0)
let serverSawEnd = ManagedAtomic<Bool>(false)
let serverSawRequest = ManagedAtomic<Bool>(false)

let requestHandlerTask: NIOLockedValueBox<Task<Response, Error>?> = .init(nil)

app.on(.POST, "hello", body: .stream) { req async throws -> Response in
requestHandlerTask.withLockedValue {
$0 = Task {
XCTAssertTrue(serverSawRequest.compareExchange(expected: false, desired: true, ordering: .relaxed).exchanged)
var bodyIterator = req.body.makeAsyncIterator()
let firstChunk = try await bodyIterator.next() // read only first chunk
numberOfTimesTheServerGotOfferedBytes.wrappingIncrement(ordering: .relaxed)
bytesTheServerSaw.wrappingIncrement(by: firstChunk?.readableBytes ?? 0, ordering: .relaxed)
defer {
_ = bodyIterator // make sure to not prematurely cancelling the sequence
}
try await Task.sleep(nanoseconds: 10_000_000_000) // wait "forever"
serverSawEnd.store(true, ordering: .relaxed)
return Response(status: .ok)
}
}

do {
let task = requestHandlerTask.withLockedValue { $0 }
return try await task!.value
} catch {
throw Abort(.internalServerError)
}
}

app.environment.arguments = ["serve"]
XCTAssertNoThrow(try app.start())

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let ip = localAddress.ipAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

final class ResponseDelegate: HTTPClientResponseDelegate {
typealias Response = Void

private let bytesTheClientSent: ManagedAtomic<Int>

init(bytesTheClientSent: ManagedAtomic<Int>) {
self.bytesTheClientSent = bytesTheClientSent
}

func didFinishRequest(task: HTTPClient.Task<Response>) throws -> Response {
return ()
}

func didSendRequestPart(task: HTTPClient.Task<Response>, _ part: IOData) {
self.bytesTheClientSent.wrappingIncrement(by: part.readableBytes, ordering: .relaxed)
}
}

let tenMB = ByteBuffer(repeating: 0x41, count: 10 * 1024 * 1024)
let request = try! HTTPClient.Request(url: "http://\(ip):\(port)/hello",
method: .POST,
headers: [:],
body: .byteBuffer(tenMB))
let delegate = ResponseDelegate(bytesTheClientSent: bytesTheClientSent)
let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup))
XCTAssertThrowsError(try httpClient.execute(request: request,
delegate: delegate,
deadline: .now() + .milliseconds(500)).wait()) { error in
if let error = error as? HTTPClientError {
XCTAssert(error == .readTimeout || error == .deadlineExceeded)
} else {
XCTFail("unexpected error: \(error)")
}
}

XCTAssertEqual(1, numberOfTimesTheServerGotOfferedBytes.load(ordering: .relaxed))
XCTAssertGreaterThan(tenMB.readableBytes, bytesTheServerSaw.load(ordering: .relaxed))
XCTAssertGreaterThan(tenMB.readableBytes, bytesTheClientSent.load(ordering: .relaxed))
XCTAssertEqual(0, bytesTheClientSent.load(ordering: .relaxed)) // We'd only see this if we sent the full 10 MB.
XCTAssertFalse(serverSawEnd.load(ordering: .relaxed))
XCTAssertTrue(serverSawRequest.load(ordering: .relaxed))

requestHandlerTask.withLockedValue { $0?.cancel() }
try await httpClient.shutdown()
}
}

// This was taken from AsyncHTTPClients's AsyncRequestTests.swift code.
// The license for the original work is reproduced below. See NOTICES.txt for
// more.

//===----------------------------------------------------------------------===//
//
// This source file is part of the AsyncHTTPClient open source project
//
// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

struct AsyncLazySequence<Base: Sequence>: AsyncSequence {
typealias Element = Base.Element
struct AsyncIterator: AsyncIteratorProtocol {
var iterator: Base.Iterator
init(iterator: Base.Iterator) {
self.iterator = iterator
}

mutating func next() async throws -> Base.Element? {
self.iterator.next()
}
}

var base: Base

init(base: Base) {
self.base = base
}

func makeAsyncIterator() -> AsyncIterator {
.init(iterator: self.base.makeIterator())
}
}

extension AsyncLazySequence: Sendable where Base: Sendable {}
extension AsyncLazySequence.AsyncIterator: Sendable where Base.Iterator: Sendable {}

extension Sequence {
/// Turns `self` into an `AsyncSequence` by vending each element of `self` asynchronously.
var async: AsyncLazySequence<Self> {
.init(base: self)
}
}
#endif

0 comments on commit d682e05

Please sign in to comment.