Skip to content

Commit

Permalink
Refactor request body streaming state handling (#2357)
Browse files Browse the repository at this point in the history
* Refactor request body streaming state handling

* test early exit streaming response

* add body stream tests
  • Loading branch information
tanner0101 committed May 27, 2020
1 parent a049aff commit 77b84cf
Show file tree
Hide file tree
Showing 7 changed files with 461 additions and 46 deletions.
30 changes: 30 additions & 0 deletions Sources/Development/routes.swift
Expand Up @@ -176,6 +176,36 @@ public func routes(_ app: Application) throws {
print("in route")
return .ok
}

app.on(.POST, "upload", body: .stream) { req -> EventLoopFuture<HTTPStatus> in
return req.application.fileio.openFile(
path: "/Users/tanner/Desktop/foo.txt",
mode: .write,
flags: .allowFileCreation(),
eventLoop: req.eventLoop
).flatMap { fileHandle in
let promise = req.eventLoop.makePromise(of: HTTPStatus.self)
req.body.drain { part in
switch part {
case .buffer(let buffer):
return req.application.fileio.write(
fileHandle: fileHandle,
buffer: buffer,
eventLoop: req.eventLoop
)
case .error(let error):
promise.fail(error)
try! fileHandle.close()
return req.eventLoop.makeSucceededFuture(())
case .end:
promise.succeed(.ok)
try! fileHandle.close()
return req.eventLoop.makeSucceededFuture(())
}
}
return promise.futureResult
}
}
}

struct TestError: AbortError {
Expand Down
27 changes: 27 additions & 0 deletions Sources/Vapor/HTTP/BodyStream.swift
Expand Up @@ -10,6 +10,33 @@ public enum BodyStreamResult {
case end
}

extension BodyStreamResult: CustomStringConvertible {
public var description: String {
switch self {
case .buffer(let buffer):
return "buffer(\(buffer.readableBytes) bytes)"
case .error(let error):
return "error(\(error))"
case .end:
return "end"
}
}
}

extension BodyStreamResult: CustomDebugStringConvertible {
public var debugDescription: String {
switch self {
case .buffer(let buffer):
let value = String(decoding: buffer.readableBytesView, as: UTF8.self)
return "buffer(\(value))"
case .error(let error):
return "error(\(error))"
case .end:
return "end"
}
}
}

public protocol BodyStreamWriter {
var eventLoop: EventLoop { get }
func write(_ result: BodyStreamResult, promise: EventLoopPromise<Void>?)
Expand Down
33 changes: 25 additions & 8 deletions Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Expand Up @@ -12,20 +12,37 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)

// query delegate for response
self.responder.respond(to: request).whenComplete { response in
switch response {
case .failure(let error):
self.errorCaught(context: context, error: error)
case .success(let response):
if request.method == .HEAD {
response.forHeadRequest = true
if case .stream(let stream) = request.bodyStorage, !stream.isClosed {
// If streaming request body has not been closed yet,
// drain it before sending the response.
stream.read { (result, promise) in
switch result {
case .buffer: break
case .end:
self.serialize(response, for: request, context: context)
case .error(let error):
self.serialize(.failure(error), for: request, context: context)
}
promise?.succeed(())
}
} else {
self.serialize(response, for: request, context: context)
}
}
}

func serialize(_ response: Result<Response, Error>, for request: Request, context: ChannelHandlerContext) {
switch response {
case .failure(let error):
self.errorCaught(context: context, error: error)
case .success(let response):
if request.method == .HEAD {
response.forHeadRequest = true
}
self.serialize(response, for: request, context: context)
}
}

func serialize(_ response: Response, for request: Request, context: ChannelHandlerContext) {
switch request.version.major {
Expand Down
218 changes: 195 additions & 23 deletions Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift
Expand Up @@ -14,17 +14,16 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
}

var requestState: RequestState
var bodyStreamState: HTTPBodyStreamState

private let logger: Logger
var pendingWriteCount: Int
var hasReadPending: Bool
var application: Application

init(application: Application) {
self.application = application
self.requestState = .ready
self.logger = Logger(label: "codes.vapor.server")
self.pendingWriteCount = 0
self.hasReadPending = false
self.bodyStreamState = .init()
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
Expand Down Expand Up @@ -63,11 +62,23 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
let stream = Request.BodyStream(on: context.eventLoop)
request.bodyStorage = .stream(stream)
context.fireChannelRead(self.wrapInboundOut(request))
self.write(.buffer(previousBuffer), to: stream, context: context)
self.write(.buffer(buffer), to: stream, context: context)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReadBytes(previousBuffer),
stream: stream
)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReadBytes(buffer),
stream: stream
)
self.requestState = .streamingBody(stream)
case .streamingBody(let stream):
self.write(.buffer(buffer), to: stream, context: context)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReadBytes(buffer),
stream: stream
)
}
case .end(let tailHeaders):
assert(tailHeaders == nil, "Tail headers are not supported.")
Expand All @@ -79,43 +90,204 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
request.bodyStorage = .collected(buffer)
context.fireChannelRead(self.wrapInboundOut(request))
case .streamingBody(let stream):
self.write(.end, to: stream, context: context)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didEnd(),
stream: stream
)
}
self.requestState = .ready
}
}

func read(context: ChannelHandlerContext) {
if self.pendingWriteCount <= 0 {
switch self.requestState {
case .streamingBody(let stream):
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReceiveReadRequest(),
stream: stream
)
default:
context.read()
} else {
self.hasReadPending = true
}
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
switch self.requestState {
case .streamingBody(let stream):
stream.write(.error(error), promise: nil)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didError(error),
stream: stream
)
default:
break
}
context.fireErrorCaught(error)
}

func write(_ part: BodyStreamResult, to stream: Request.BodyStream, context: ChannelHandlerContext) {
self.pendingWriteCount += 1
stream.write(part).whenComplete { result in
self.pendingWriteCount -= 1
if self.hasReadPending {
self.hasReadPending = false
self.read(context: context)
func handleBodyStreamStateResult(
context: ChannelHandlerContext,
_ result: HTTPBodyStreamState.Result,
stream: Request.BodyStream
) {
switch result.action {
case .nothing: break
case .write(let buffer):
stream.write(.buffer(buffer)).whenComplete { writeResult in
switch writeResult {
case .failure(let error):
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didError(error),
stream: stream
)
case .success:
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didWrite(),
stream: stream
)
}
}
switch result {
case .failure(let error):
self.logger.error("Could not write body: \(error)")
case .success: break
case .close(let maybeError):
if let error = maybeError {
stream.write(.error(error), promise: nil)
} else {
stream.write(.end, promise: nil)
}
}
if result.callRead {
context.read()
}
}
}

struct HTTPBodyStreamState: CustomStringConvertible {
struct Result {
enum Action {
case nothing
case write(ByteBuffer)
case close(Error?)
}
let action: Action
let callRead: Bool
}

private struct BufferState {
var bufferedWrites: CircularBuffer<ByteBuffer>
var heldUpRead: Bool
var hasClosed: Bool

mutating func append(_ buffer: ByteBuffer) {
self.bufferedWrites.append(buffer)
}

var isEmpty: Bool {
return self.bufferedWrites.isEmpty
}

mutating func removeFirst() -> ByteBuffer {
return self.bufferedWrites.removeFirst()
}
}

private enum State {
case idle
case writing(BufferState)
case error(Error)
}

private var state: State

var description: String {
"\(self.state)"
}

init() {
self.state = .idle
}

mutating func didReadBytes(_ buffer: ByteBuffer) -> Result {
switch self.state {
case .idle:
self.state = .writing(.init(
bufferedWrites: .init(),
heldUpRead: false,
hasClosed: false
))
return .init(action: .write(buffer), callRead: false)
case .writing(var buffers):
buffers.append(buffer)
self.state = .writing(buffers)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didReceiveReadRequest() -> Result {
switch self.state {
case .idle:
return .init(action: .nothing, callRead: true)
case .writing(var buffers):
buffers.heldUpRead = true
self.state = .writing(buffers)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didEnd() -> Result {
switch self.state {
case .idle:
return .init(action: .close(nil), callRead: false)
case .writing(var buffers):
buffers.hasClosed = true
self.state = .writing(buffers)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didError(_ error: Error) -> Result {
switch self.state {
case .idle:
self.state = .error(error)
return .init(action: .close(error), callRead: false)
case .writing:
self.state = .error(error)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didWrite() -> Result {
switch self.state {
case .idle:
self.illegalTransition()
case .writing(var buffers):
if buffers.isEmpty {
self.state = .idle
return .init(
action: buffers.hasClosed ? .close(nil) : .nothing,
callRead: buffers.heldUpRead
)
} else {
let first = buffers.removeFirst()
self.state = .writing(buffers)
return .init(action: .write(first), callRead: false)
}
case .error(let error):
return .init(action: .close(error), callRead: false)
}
}

private func illegalTransition(_ function: String = #function) -> Never {
preconditionFailure("illegal transition \(function) in \(self)")
}
}

0 comments on commit 77b84cf

Please sign in to comment.