Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HTTP streaming improvements #2404

Merged
merged 11 commits into from Jun 24, 2020
16 changes: 8 additions & 8 deletions Sources/Vapor/HTTP/Server/HTTPServer.swift
Expand Up @@ -336,7 +336,7 @@ final class HTTPServerErrorHandler: ChannelInboundHandler {
}
}

private extension ChannelPipeline {
extension ChannelPipeline {
func addVaporHTTP2Handlers(
application: Application,
responder: Responder,
Expand Down Expand Up @@ -414,19 +414,19 @@ private extension ChannelPipeline {
case .disabled:
break
}

// add NIO -> HTTP request decoder
let serverReqDecoder = HTTPServerRequestDecoder(
application: application
)
handlers.append(serverReqDecoder)


// add NIO -> HTTP response encoder
let serverResEncoder = HTTPServerResponseEncoder(
serverHeader: configuration.serverName,
dateCache: .eventLoop(self.eventLoop)
)
handlers.append(serverResEncoder)

// add NIO -> HTTP request decoder
let serverReqDecoder = HTTPServerRequestDecoder(
application: application
)
handlers.append(serverReqDecoder)
// add server request -> response delegate
let handler = HTTPServerHandler(responder: responder)

Expand Down
28 changes: 9 additions & 19 deletions Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Expand Up @@ -13,22 +13,7 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)
self.responder.respond(to: request).whenComplete { response in
if case .stream(let stream) = request.bodyStorage, !stream.isClosed {
// If streaming request body has not been closed yet,
// drain it before sending the response.
stream.read { (result, promise) in
switch result {
case .buffer: break
case .end:
self.serialize(response, for: request, context: context)
case .error(let error):
self.serialize(.failure(error), for: request, context: context)
}
promise?.succeed(())
}
} else {
self.serialize(response, for: request, context: context)
}
self.serialize(response, for: request, context: context)
}
}

Expand All @@ -51,9 +36,14 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
default:
response.headers.add(name: .connection, value: request.isKeepAlive ? "keep-alive" : "close")
let done = context.write(self.wrapOutboundOut(response))
if !request.isKeepAlive {
done.whenComplete { _ in
context.close(mode: .output, promise: nil)
done.whenComplete { result in
switch result {
case .success:
if !request.isKeepAlive {
context.close(mode: .output, promise: nil)
}
case .failure(let error):
self.errorCaught(context: context, error: error)
}
}
}
Expand Down
55 changes: 48 additions & 7 deletions Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift
@@ -1,7 +1,7 @@
import NIO
import NIOHTTP1

final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHandler {
final class HTTPServerRequestDecoder: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPServerRequestPart
typealias InboundOut = Request
typealias OutboundIn = Never
Expand All @@ -11,6 +11,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
case awaitingBody(Request)
case awaitingEnd(Request, ByteBuffer)
case streamingBody(Request.BodyStream)
case skipping
}

var requestState: RequestState
Expand Down Expand Up @@ -79,6 +80,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
self.bodyStreamState.didReadBytes(buffer),
stream: stream
)
case .skipping: break
}
case .end(let tailHeaders):
assert(tailHeaders == nil, "Tail headers are not supported.")
Expand All @@ -95,6 +97,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
self.bodyStreamState.didEnd(),
stream: stream
)
case .skipping: break
}
self.requestState = .ready
}
Expand Down Expand Up @@ -127,6 +130,20 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
context.fireErrorCaught(error)
}

func channelInactive(context: ChannelHandlerContext) {
switch self.requestState {
case .streamingBody(let stream):
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didEnd(),
stream: stream
)
default:
break
}
context.fireChannelInactive()
}

func handleBodyStreamStateResult(
context: ChannelHandlerContext,
_ result: HTTPBodyStreamState.Result,
Expand All @@ -143,13 +160,13 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
self.bodyStreamState.didError(error),
stream: stream
)
case .success:
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didWrite(),
stream: stream
)
case .success: break
}
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didWrite(),
stream: stream
)
}
case .close(let maybeError):
if let error = maybeError {
Expand All @@ -162,6 +179,30 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
context.read()
}
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if event is HTTPServerResponseEncoder.ResponseEndSentEvent {
switch self.requestState {
case .streamingBody(let bodyStream):
// Response ended during request stream.
if !bodyStream.isBeingRead {
self.logger.trace("Response already sent, draining unhandled request stream.")
bodyStream.read { _, promise in
promise?.succeed(())
}
}
case .awaitingEnd, .awaitingBody:
// Response ended before request started streaming.
self.logger.trace("Response already sent, skipping request body.")
self.requestState = .skipping
case .ready:
// Response ended after request had been read.
break
default:
fatalError("Unexpected request state: \(self.requestState)")
}
}
}
}

struct HTTPBodyStreamState: CustomStringConvertible {
Expand Down
60 changes: 55 additions & 5 deletions Sources/Vapor/HTTP/Server/HTTPServerResponseEncoder.swift
Expand Up @@ -8,6 +8,8 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
/// Optional server header.
private let serverHeader: String?
private let dateCache: RFC1123DateCache

struct ResponseEndSentEvent { }

init(serverHeader: String?, dateCache: RFC1123DateCache) {
self.serverHeader = serverHeader
Expand All @@ -30,14 +32,17 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
status: response.status,
headers: response.headers
))), promise: nil)


if response.status == .noContent || response.forHeadRequest {
// don't send bodies for 204 (no content) responses
// or HEAD requests
context.fireUserInboundEventTriggered(ResponseEndSentEvent())
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: promise)
} else {
switch response.body.storage {
case .none:
context.fireUserInboundEventTriggered(ResponseEndSentEvent())
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: promise)
case .buffer(let buffer):
self.writeAndflush(buffer: buffer, context: context, promise: promise)
Expand All @@ -58,7 +63,12 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
buffer.writeDispatchData(data)
self.writeAndflush(buffer: buffer, context: context, promise: promise)
case .stream(let stream):
let channelStream = ChannelResponseBodyStream(context: context, handler: self, promise: promise)
let channelStream = ChannelResponseBodyStream(
context: context,
handler: self,
promise: promise,
count: stream.count == -1 ? nil : stream.count
)
stream.callback(channelStream)
}
}
Expand All @@ -67,31 +77,71 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
/// Writes a `ByteBuffer` to the context.
private func writeAndflush(buffer: ByteBuffer, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
if buffer.readableBytes > 0 {
_ = context.write(wrapOutboundOut(.body(.byteBuffer(buffer))))
context.write(wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil)
}
context.fireUserInboundEventTriggered(ResponseEndSentEvent())
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: promise)
}
}

private struct ChannelResponseBodyStream: BodyStreamWriter {
private final class ChannelResponseBodyStream: BodyStreamWriter {
let context: ChannelHandlerContext
let handler: HTTPServerResponseEncoder
let promise: EventLoopPromise<Void>?
let count: Int?
var currentCount: Int
var isComplete: Bool

var eventLoop: EventLoop {
return self.context.eventLoop
}

enum Error: Swift.Error {
case tooManyBytes
case notEnoughBytes
}

init(
context: ChannelHandlerContext,
handler: HTTPServerResponseEncoder,
promise: EventLoopPromise<Void>?,
count: Int?
) {
self.context = context
self.handler = handler
self.promise = promise
self.count = count
self.currentCount = 0
self.isComplete = false
}

func write(_ result: BodyStreamResult, promise: EventLoopPromise<Void>?) {
switch result {
case .buffer(let buffer):
self.context.writeAndFlush(self.handler.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise)
self.currentCount += buffer.readableBytes
if let count = self.count, self.currentCount > count {
self.promise?.fail(Error.tooManyBytes)
promise?.fail(Error.notEnoughBytes)
}
case .end:
self.promise?.succeed(())
self.isComplete = true
if let count = self.count, self.currentCount != count {
self.promise?.fail(Error.notEnoughBytes)
promise?.fail(Error.notEnoughBytes)
}
self.context.fireUserInboundEventTriggered(HTTPServerResponseEncoder.ResponseEndSentEvent())
self.context.writeAndFlush(self.handler.wrapOutboundOut(.end(nil)), promise: promise)
self.promise?.succeed(())
case .error(let error):
self.promise?.fail(error)
self.isComplete = true
self.context.fireUserInboundEventTriggered(HTTPServerResponseEncoder.ResponseEndSentEvent())
self.context.writeAndFlush(self.handler.wrapOutboundOut(.end(nil)), promise: promise)
self.promise?.fail(error)
}
}

deinit {
assert(self.isComplete, "Response body stream writer deinitialized before .end or .error was sent.")
}
}
3 changes: 2 additions & 1 deletion Sources/Vapor/Request/Request+Body.swift
Expand Up @@ -30,7 +30,8 @@ extension Request {
case .collected(let buffer):
_ = handler(.buffer(buffer))
_ = handler(.end)
case .none: break
case .none:
_ = handler(.end)
}
}

Expand Down
4 changes: 4 additions & 0 deletions Sources/Vapor/Request/Request+BodyStream.swift
Expand Up @@ -7,6 +7,10 @@ extension Request {

let eventLoop: EventLoop

var isBeingRead: Bool {
self.handler != nil
}

init(on eventLoop: EventLoop) {
self.eventLoop = eventLoop
self.isClosed = false
Expand Down
6 changes: 5 additions & 1 deletion Sources/Vapor/Response/Response+Body.swift
Expand Up @@ -37,7 +37,7 @@ extension Response {
}

/// The size of the HTTP body's data.
/// `nil` is a stream.
/// `-1` is a chunked stream.
public var count: Int {
switch self.storage {
case .data(let data): return data.count
Expand Down Expand Up @@ -147,6 +147,10 @@ extension Response {
public init(stream: @escaping (BodyStreamWriter) -> (), count: Int) {
self.storage = .stream(.init(count: count, callback: stream))
}

public init(stream: @escaping (BodyStreamWriter) -> ()) {
self.init(stream: stream, count: -1)
}

/// `ExpressibleByStringLiteral` conformance.
public init(stringLiteral value: String) {
Expand Down
14 changes: 11 additions & 3 deletions Sources/Vapor/Response/Response.swift
Expand Up @@ -162,9 +162,17 @@ public final class Response: CustomStringConvertible {
extension HTTPHeaders {
mutating func updateContentLength(_ contentLength: Int) {
let count = contentLength.description
self.remove(name: .transferEncoding)
if count != self[.contentLength].first {
self.replaceOrAdd(name: .contentLength, value: count)
switch contentLength {
case -1:
self.remove(name: .contentLength)
if "chunked" != self.first(name: .transferEncoding) {
self.add(name: .transferEncoding, value: "chunked")
}
default:
self.remove(name: .transferEncoding)
if count != self.first(name: .contentLength) {
self.replaceOrAdd(name: .contentLength, value: count)
}
}
}
}
6 changes: 1 addition & 5 deletions Tests/VaporTests/ClientTests.swift
Expand Up @@ -71,11 +71,7 @@ final class ClientTests: XCTestCase {
}

func testBoilerplateClient() throws {
let app = Application(.init(
name: "xctest",
arguments: ["vapor", "serve", "-b", "localhost:8080", "--log", "trace"]
))
try LoggingSystem.bootstrap(from: &app.environment)
let app = Application(.testing)
defer { app.shutdown() }

app.get("foo") { req -> EventLoopFuture<String> in
Expand Down