Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor request body streaming state handling #2357

Merged
merged 3 commits into from May 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 30 additions & 0 deletions Sources/Development/routes.swift
Expand Up @@ -176,6 +176,36 @@ public func routes(_ app: Application) throws {
print("in route")
return .ok
}

app.on(.POST, "upload", body: .stream) { req -> EventLoopFuture<HTTPStatus> in
return req.application.fileio.openFile(
path: "/Users/tanner/Desktop/foo.txt",
mode: .write,
flags: .allowFileCreation(),
eventLoop: req.eventLoop
).flatMap { fileHandle in
let promise = req.eventLoop.makePromise(of: HTTPStatus.self)
req.body.drain { part in
switch part {
case .buffer(let buffer):
return req.application.fileio.write(
fileHandle: fileHandle,
buffer: buffer,
eventLoop: req.eventLoop
)
case .error(let error):
promise.fail(error)
try! fileHandle.close()
return req.eventLoop.makeSucceededFuture(())
case .end:
promise.succeed(.ok)
try! fileHandle.close()
return req.eventLoop.makeSucceededFuture(())
}
}
return promise.futureResult
}
}
}

struct TestError: AbortError {
Expand Down
27 changes: 27 additions & 0 deletions Sources/Vapor/HTTP/BodyStream.swift
Expand Up @@ -10,6 +10,33 @@ public enum BodyStreamResult {
case end
}

extension BodyStreamResult: CustomStringConvertible {
public var description: String {
switch self {
case .buffer(let buffer):
return "buffer(\(buffer.readableBytes) bytes)"
case .error(let error):
return "error(\(error))"
case .end:
return "end"
}
}
}

extension BodyStreamResult: CustomDebugStringConvertible {
public var debugDescription: String {
switch self {
case .buffer(let buffer):
let value = String(decoding: buffer.readableBytesView, as: UTF8.self)
return "buffer(\(value))"
case .error(let error):
return "error(\(error))"
case .end:
return "end"
}
}
}

public protocol BodyStreamWriter {
var eventLoop: EventLoop { get }
func write(_ result: BodyStreamResult, promise: EventLoopPromise<Void>?)
Expand Down
33 changes: 25 additions & 8 deletions Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Expand Up @@ -12,20 +12,37 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)

// query delegate for response
self.responder.respond(to: request).whenComplete { response in
switch response {
case .failure(let error):
self.errorCaught(context: context, error: error)
case .success(let response):
if request.method == .HEAD {
response.forHeadRequest = true
if case .stream(let stream) = request.bodyStorage, !stream.isClosed {
// If streaming request body has not been closed yet,
// drain it before sending the response.
stream.read { (result, promise) in
switch result {
case .buffer: break
case .end:
self.serialize(response, for: request, context: context)
case .error(let error):
self.serialize(.failure(error), for: request, context: context)
}
promise?.succeed(())
}
} else {
self.serialize(response, for: request, context: context)
}
}
}

func serialize(_ response: Result<Response, Error>, for request: Request, context: ChannelHandlerContext) {
switch response {
case .failure(let error):
self.errorCaught(context: context, error: error)
case .success(let response):
if request.method == .HEAD {
response.forHeadRequest = true
}
self.serialize(response, for: request, context: context)
}
}

func serialize(_ response: Response, for request: Request, context: ChannelHandlerContext) {
switch request.version.major {
Expand Down
218 changes: 195 additions & 23 deletions Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift
Expand Up @@ -14,17 +14,16 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
}

var requestState: RequestState
var bodyStreamState: HTTPBodyStreamState

private let logger: Logger
var pendingWriteCount: Int
var hasReadPending: Bool
var application: Application

init(application: Application) {
self.application = application
self.requestState = .ready
self.logger = Logger(label: "codes.vapor.server")
self.pendingWriteCount = 0
self.hasReadPending = false
self.bodyStreamState = .init()
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
Expand Down Expand Up @@ -63,11 +62,23 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
let stream = Request.BodyStream(on: context.eventLoop)
request.bodyStorage = .stream(stream)
context.fireChannelRead(self.wrapInboundOut(request))
self.write(.buffer(previousBuffer), to: stream, context: context)
self.write(.buffer(buffer), to: stream, context: context)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReadBytes(previousBuffer),
stream: stream
)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReadBytes(buffer),
stream: stream
)
self.requestState = .streamingBody(stream)
case .streamingBody(let stream):
self.write(.buffer(buffer), to: stream, context: context)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReadBytes(buffer),
stream: stream
)
}
case .end(let tailHeaders):
assert(tailHeaders == nil, "Tail headers are not supported.")
Expand All @@ -79,43 +90,204 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
request.bodyStorage = .collected(buffer)
context.fireChannelRead(self.wrapInboundOut(request))
case .streamingBody(let stream):
self.write(.end, to: stream, context: context)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didEnd(),
stream: stream
)
}
self.requestState = .ready
}
}

func read(context: ChannelHandlerContext) {
if self.pendingWriteCount <= 0 {
switch self.requestState {
case .streamingBody(let stream):
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didReceiveReadRequest(),
stream: stream
)
default:
context.read()
} else {
self.hasReadPending = true
}
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
switch self.requestState {
case .streamingBody(let stream):
stream.write(.error(error), promise: nil)
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didError(error),
stream: stream
)
default:
break
}
context.fireErrorCaught(error)
}

func write(_ part: BodyStreamResult, to stream: Request.BodyStream, context: ChannelHandlerContext) {
self.pendingWriteCount += 1
stream.write(part).whenComplete { result in
self.pendingWriteCount -= 1
if self.hasReadPending {
self.hasReadPending = false
self.read(context: context)
func handleBodyStreamStateResult(
context: ChannelHandlerContext,
_ result: HTTPBodyStreamState.Result,
stream: Request.BodyStream
) {
switch result.action {
case .nothing: break
case .write(let buffer):
stream.write(.buffer(buffer)).whenComplete { writeResult in
switch writeResult {
case .failure(let error):
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didError(error),
stream: stream
)
case .success:
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didWrite(),
stream: stream
)
}
}
switch result {
case .failure(let error):
self.logger.error("Could not write body: \(error)")
case .success: break
case .close(let maybeError):
if let error = maybeError {
stream.write(.error(error), promise: nil)
} else {
stream.write(.end, promise: nil)
}
}
if result.callRead {
context.read()
}
}
}

struct HTTPBodyStreamState: CustomStringConvertible {
struct Result {
enum Action {
case nothing
case write(ByteBuffer)
case close(Error?)
}
let action: Action
let callRead: Bool
}

private struct BufferState {
var bufferedWrites: CircularBuffer<ByteBuffer>
var heldUpRead: Bool
var hasClosed: Bool

mutating func append(_ buffer: ByteBuffer) {
self.bufferedWrites.append(buffer)
}

var isEmpty: Bool {
return self.bufferedWrites.isEmpty
}

mutating func removeFirst() -> ByteBuffer {
return self.bufferedWrites.removeFirst()
}
}

private enum State {
case idle
case writing(BufferState)
case error(Error)
}

private var state: State

var description: String {
"\(self.state)"
}

init() {
self.state = .idle
}

mutating func didReadBytes(_ buffer: ByteBuffer) -> Result {
switch self.state {
case .idle:
self.state = .writing(.init(
bufferedWrites: .init(),
heldUpRead: false,
hasClosed: false
))
return .init(action: .write(buffer), callRead: false)
case .writing(var buffers):
buffers.append(buffer)
self.state = .writing(buffers)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didReceiveReadRequest() -> Result {
switch self.state {
case .idle:
return .init(action: .nothing, callRead: true)
case .writing(var buffers):
buffers.heldUpRead = true
self.state = .writing(buffers)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didEnd() -> Result {
switch self.state {
case .idle:
return .init(action: .close(nil), callRead: false)
case .writing(var buffers):
buffers.hasClosed = true
self.state = .writing(buffers)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didError(_ error: Error) -> Result {
switch self.state {
case .idle:
self.state = .error(error)
return .init(action: .close(error), callRead: false)
case .writing:
self.state = .error(error)
return .init(action: .nothing, callRead: false)
case .error:
return .init(action: .nothing, callRead: false)
}
}

mutating func didWrite() -> Result {
switch self.state {
case .idle:
self.illegalTransition()
case .writing(var buffers):
if buffers.isEmpty {
self.state = .idle
return .init(
action: buffers.hasClosed ? .close(nil) : .nothing,
callRead: buffers.heldUpRead
)
} else {
let first = buffers.removeFirst()
self.state = .writing(buffers)
return .init(action: .write(first), callRead: false)
}
case .error(let error):
return .init(action: .close(error), callRead: false)
}
}

private func illegalTransition(_ function: String = #function) -> Never {
preconditionFailure("illegal transition \(function) in \(self)")
}
}