Skip to content

Commit

Permalink
HTTP streaming fixes (#2404)
Browse files Browse the repository at this point in the history
* echo server

* warning

* fix tests

* drain unhandled request after response end

* rm commented code

* add handler tests

* test EOF framing, fixes #2391

* test incorrect response body stream count, fixes #2392

* http fixes
  • Loading branch information
tanner0101 committed Jun 24, 2020
1 parent de17edc commit 2f9be8b
Show file tree
Hide file tree
Showing 12 changed files with 393 additions and 49 deletions.
16 changes: 8 additions & 8 deletions Sources/Vapor/HTTP/Server/HTTPServer.swift
Expand Up @@ -336,7 +336,7 @@ final class HTTPServerErrorHandler: ChannelInboundHandler {
}
}

private extension ChannelPipeline {
extension ChannelPipeline {
func addVaporHTTP2Handlers(
application: Application,
responder: Responder,
Expand Down Expand Up @@ -414,19 +414,19 @@ private extension ChannelPipeline {
case .disabled:
break
}

// add NIO -> HTTP request decoder
let serverReqDecoder = HTTPServerRequestDecoder(
application: application
)
handlers.append(serverReqDecoder)


// add NIO -> HTTP response encoder
let serverResEncoder = HTTPServerResponseEncoder(
serverHeader: configuration.serverName,
dateCache: .eventLoop(self.eventLoop)
)
handlers.append(serverResEncoder)

// add NIO -> HTTP request decoder
let serverReqDecoder = HTTPServerRequestDecoder(
application: application
)
handlers.append(serverReqDecoder)
// add server request -> response delegate
let handler = HTTPServerHandler(responder: responder)

Expand Down
28 changes: 9 additions & 19 deletions Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Expand Up @@ -13,22 +13,7 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let request = self.unwrapInboundIn(data)
self.responder.respond(to: request).whenComplete { response in
if case .stream(let stream) = request.bodyStorage, !stream.isClosed {
// If streaming request body has not been closed yet,
// drain it before sending the response.
stream.read { (result, promise) in
switch result {
case .buffer: break
case .end:
self.serialize(response, for: request, context: context)
case .error(let error):
self.serialize(.failure(error), for: request, context: context)
}
promise?.succeed(())
}
} else {
self.serialize(response, for: request, context: context)
}
self.serialize(response, for: request, context: context)
}
}

Expand All @@ -51,9 +36,14 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
default:
response.headers.add(name: .connection, value: request.isKeepAlive ? "keep-alive" : "close")
let done = context.write(self.wrapOutboundOut(response))
if !request.isKeepAlive {
done.whenComplete { _ in
context.close(mode: .output, promise: nil)
done.whenComplete { result in
switch result {
case .success:
if !request.isKeepAlive {
context.close(mode: .output, promise: nil)
}
case .failure(let error):
self.errorCaught(context: context, error: error)
}
}
}
Expand Down
55 changes: 48 additions & 7 deletions Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift
@@ -1,7 +1,7 @@
import NIO
import NIOHTTP1

final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHandler {
final class HTTPServerRequestDecoder: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = HTTPServerRequestPart
typealias InboundOut = Request
typealias OutboundIn = Never
Expand All @@ -11,6 +11,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
case awaitingBody(Request)
case awaitingEnd(Request, ByteBuffer)
case streamingBody(Request.BodyStream)
case skipping
}

var requestState: RequestState
Expand Down Expand Up @@ -79,6 +80,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
self.bodyStreamState.didReadBytes(buffer),
stream: stream
)
case .skipping: break
}
case .end(let tailHeaders):
assert(tailHeaders == nil, "Tail headers are not supported.")
Expand All @@ -95,6 +97,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
self.bodyStreamState.didEnd(),
stream: stream
)
case .skipping: break
}
self.requestState = .ready
}
Expand Down Expand Up @@ -127,6 +130,20 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
context.fireErrorCaught(error)
}

func channelInactive(context: ChannelHandlerContext) {
switch self.requestState {
case .streamingBody(let stream):
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didEnd(),
stream: stream
)
default:
break
}
context.fireChannelInactive()
}

func handleBodyStreamStateResult(
context: ChannelHandlerContext,
_ result: HTTPBodyStreamState.Result,
Expand All @@ -143,13 +160,13 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
self.bodyStreamState.didError(error),
stream: stream
)
case .success:
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didWrite(),
stream: stream
)
case .success: break
}
self.handleBodyStreamStateResult(
context: context,
self.bodyStreamState.didWrite(),
stream: stream
)
}
case .close(let maybeError):
if let error = maybeError {
Expand All @@ -162,6 +179,30 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand
context.read()
}
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if event is HTTPServerResponseEncoder.ResponseEndSentEvent {
switch self.requestState {
case .streamingBody(let bodyStream):
// Response ended during request stream.
if !bodyStream.isBeingRead {
self.logger.trace("Response already sent, draining unhandled request stream.")
bodyStream.read { _, promise in
promise?.succeed(())
}
}
case .awaitingEnd, .awaitingBody:
// Response ended before request started streaming.
self.logger.trace("Response already sent, skipping request body.")
self.requestState = .skipping
case .ready:
// Response ended after request had been read.
break
default:
fatalError("Unexpected request state: \(self.requestState)")
}
}
}
}

struct HTTPBodyStreamState: CustomStringConvertible {
Expand Down
60 changes: 55 additions & 5 deletions Sources/Vapor/HTTP/Server/HTTPServerResponseEncoder.swift
Expand Up @@ -8,6 +8,8 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
/// Optional server header.
private let serverHeader: String?
private let dateCache: RFC1123DateCache

struct ResponseEndSentEvent { }

init(serverHeader: String?, dateCache: RFC1123DateCache) {
self.serverHeader = serverHeader
Expand All @@ -30,14 +32,17 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
status: response.status,
headers: response.headers
))), promise: nil)


if response.status == .noContent || response.forHeadRequest {
// don't send bodies for 204 (no content) responses
// or HEAD requests
context.fireUserInboundEventTriggered(ResponseEndSentEvent())
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: promise)
} else {
switch response.body.storage {
case .none:
context.fireUserInboundEventTriggered(ResponseEndSentEvent())
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: promise)
case .buffer(let buffer):
self.writeAndflush(buffer: buffer, context: context, promise: promise)
Expand All @@ -58,7 +63,12 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
buffer.writeDispatchData(data)
self.writeAndflush(buffer: buffer, context: context, promise: promise)
case .stream(let stream):
let channelStream = ChannelResponseBodyStream(context: context, handler: self, promise: promise)
let channelStream = ChannelResponseBodyStream(
context: context,
handler: self,
promise: promise,
count: stream.count == -1 ? nil : stream.count
)
stream.callback(channelStream)
}
}
Expand All @@ -67,31 +77,71 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH
/// Writes a `ByteBuffer` to the context.
private func writeAndflush(buffer: ByteBuffer, context: ChannelHandlerContext, promise: EventLoopPromise<Void>?) {
if buffer.readableBytes > 0 {
_ = context.write(wrapOutboundOut(.body(.byteBuffer(buffer))))
context.write(wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil)
}
context.fireUserInboundEventTriggered(ResponseEndSentEvent())
context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: promise)
}
}

private struct ChannelResponseBodyStream: BodyStreamWriter {
private final class ChannelResponseBodyStream: BodyStreamWriter {
let context: ChannelHandlerContext
let handler: HTTPServerResponseEncoder
let promise: EventLoopPromise<Void>?
let count: Int?
var currentCount: Int
var isComplete: Bool

var eventLoop: EventLoop {
return self.context.eventLoop
}

enum Error: Swift.Error {
case tooManyBytes
case notEnoughBytes
}

init(
context: ChannelHandlerContext,
handler: HTTPServerResponseEncoder,
promise: EventLoopPromise<Void>?,
count: Int?
) {
self.context = context
self.handler = handler
self.promise = promise
self.count = count
self.currentCount = 0
self.isComplete = false
}

func write(_ result: BodyStreamResult, promise: EventLoopPromise<Void>?) {
switch result {
case .buffer(let buffer):
self.context.writeAndFlush(self.handler.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise)
self.currentCount += buffer.readableBytes
if let count = self.count, self.currentCount > count {
self.promise?.fail(Error.tooManyBytes)
promise?.fail(Error.notEnoughBytes)
}
case .end:
self.promise?.succeed(())
self.isComplete = true
if let count = self.count, self.currentCount != count {
self.promise?.fail(Error.notEnoughBytes)
promise?.fail(Error.notEnoughBytes)
}
self.context.fireUserInboundEventTriggered(HTTPServerResponseEncoder.ResponseEndSentEvent())
self.context.writeAndFlush(self.handler.wrapOutboundOut(.end(nil)), promise: promise)
self.promise?.succeed(())
case .error(let error):
self.promise?.fail(error)
self.isComplete = true
self.context.fireUserInboundEventTriggered(HTTPServerResponseEncoder.ResponseEndSentEvent())
self.context.writeAndFlush(self.handler.wrapOutboundOut(.end(nil)), promise: promise)
self.promise?.fail(error)
}
}

deinit {
assert(self.isComplete, "Response body stream writer deinitialized before .end or .error was sent.")
}
}
3 changes: 2 additions & 1 deletion Sources/Vapor/Request/Request+Body.swift
Expand Up @@ -30,7 +30,8 @@ extension Request {
case .collected(let buffer):
_ = handler(.buffer(buffer))
_ = handler(.end)
case .none: break
case .none:
_ = handler(.end)
}
}

Expand Down
4 changes: 4 additions & 0 deletions Sources/Vapor/Request/Request+BodyStream.swift
Expand Up @@ -7,6 +7,10 @@ extension Request {

let eventLoop: EventLoop

var isBeingRead: Bool {
self.handler != nil
}

init(on eventLoop: EventLoop) {
self.eventLoop = eventLoop
self.isClosed = false
Expand Down
6 changes: 5 additions & 1 deletion Sources/Vapor/Response/Response+Body.swift
Expand Up @@ -37,7 +37,7 @@ extension Response {
}

/// The size of the HTTP body's data.
/// `nil` is a stream.
/// `-1` is a chunked stream.
public var count: Int {
switch self.storage {
case .data(let data): return data.count
Expand Down Expand Up @@ -147,6 +147,10 @@ extension Response {
public init(stream: @escaping (BodyStreamWriter) -> (), count: Int) {
self.storage = .stream(.init(count: count, callback: stream))
}

public init(stream: @escaping (BodyStreamWriter) -> ()) {
self.init(stream: stream, count: -1)
}

/// `ExpressibleByStringLiteral` conformance.
public init(stringLiteral value: String) {
Expand Down
14 changes: 11 additions & 3 deletions Sources/Vapor/Response/Response.swift
Expand Up @@ -162,9 +162,17 @@ public final class Response: CustomStringConvertible {
extension HTTPHeaders {
mutating func updateContentLength(_ contentLength: Int) {
let count = contentLength.description
self.remove(name: .transferEncoding)
if count != self[.contentLength].first {
self.replaceOrAdd(name: .contentLength, value: count)
switch contentLength {
case -1:
self.remove(name: .contentLength)
if "chunked" != self.first(name: .transferEncoding) {
self.add(name: .transferEncoding, value: "chunked")
}
default:
self.remove(name: .transferEncoding)
if count != self.first(name: .contentLength) {
self.replaceOrAdd(name: .contentLength, value: count)
}
}
}
}
6 changes: 1 addition & 5 deletions Tests/VaporTests/ClientTests.swift
Expand Up @@ -71,11 +71,7 @@ final class ClientTests: XCTestCase {
}

func testBoilerplateClient() throws {
let app = Application(.init(
name: "xctest",
arguments: ["vapor", "serve", "-b", "localhost:8080", "--log", "trace"]
))
try LoggingSystem.bootstrap(from: &app.environment)
let app = Application(.testing)
defer { app.shutdown() }

app.get("foo") { req -> EventLoopFuture<String> in
Expand Down

0 comments on commit 2f9be8b

Please sign in to comment.