diff --git a/Sources/Vapor/HTTP/Server/HTTPServer.swift b/Sources/Vapor/HTTP/Server/HTTPServer.swift index b498aab030..aec5bfac79 100644 --- a/Sources/Vapor/HTTP/Server/HTTPServer.swift +++ b/Sources/Vapor/HTTP/Server/HTTPServer.swift @@ -336,7 +336,7 @@ final class HTTPServerErrorHandler: ChannelInboundHandler { } } -private extension ChannelPipeline { +extension ChannelPipeline { func addVaporHTTP2Handlers( application: Application, responder: Responder, @@ -414,19 +414,19 @@ private extension ChannelPipeline { case .disabled: break } - - // add NIO -> HTTP request decoder - let serverReqDecoder = HTTPServerRequestDecoder( - application: application - ) - handlers.append(serverReqDecoder) - + // add NIO -> HTTP response encoder let serverResEncoder = HTTPServerResponseEncoder( serverHeader: configuration.serverName, dateCache: .eventLoop(self.eventLoop) ) handlers.append(serverResEncoder) + + // add NIO -> HTTP request decoder + let serverReqDecoder = HTTPServerRequestDecoder( + application: application + ) + handlers.append(serverReqDecoder) // add server request -> response delegate let handler = HTTPServerHandler(responder: responder) diff --git a/Sources/Vapor/HTTP/Server/HTTPServerHandler.swift b/Sources/Vapor/HTTP/Server/HTTPServerHandler.swift index 7638888e43..0da6ec0959 100644 --- a/Sources/Vapor/HTTP/Server/HTTPServerHandler.swift +++ b/Sources/Vapor/HTTP/Server/HTTPServerHandler.swift @@ -13,22 +13,7 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler { func channelRead(context: ChannelHandlerContext, data: NIOAny) { let request = self.unwrapInboundIn(data) self.responder.respond(to: request).whenComplete { response in - if case .stream(let stream) = request.bodyStorage, !stream.isClosed { - // If streaming request body has not been closed yet, - // drain it before sending the response. - stream.read { (result, promise) in - switch result { - case .buffer: break - case .end: - self.serialize(response, for: request, context: context) - case .error(let error): - self.serialize(.failure(error), for: request, context: context) - } - promise?.succeed(()) - } - } else { - self.serialize(response, for: request, context: context) - } + self.serialize(response, for: request, context: context) } } @@ -51,9 +36,14 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler { default: response.headers.add(name: .connection, value: request.isKeepAlive ? "keep-alive" : "close") let done = context.write(self.wrapOutboundOut(response)) - if !request.isKeepAlive { - done.whenComplete { _ in - context.close(mode: .output, promise: nil) + done.whenComplete { result in + switch result { + case .success: + if !request.isKeepAlive { + context.close(mode: .output, promise: nil) + } + case .failure(let error): + self.errorCaught(context: context, error: error) } } } diff --git a/Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift b/Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift index 71b9298f7b..3c3a558274 100644 --- a/Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift +++ b/Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift @@ -1,7 +1,7 @@ import NIO import NIOHTTP1 -final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHandler { +final class HTTPServerRequestDecoder: ChannelInboundHandler, RemovableChannelHandler { typealias InboundIn = HTTPServerRequestPart typealias InboundOut = Request typealias OutboundIn = Never @@ -11,6 +11,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand case awaitingBody(Request) case awaitingEnd(Request, ByteBuffer) case streamingBody(Request.BodyStream) + case skipping } var requestState: RequestState @@ -79,6 +80,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand self.bodyStreamState.didReadBytes(buffer), stream: stream ) + case .skipping: break } case .end(let tailHeaders): assert(tailHeaders == nil, "Tail headers are not supported.") @@ -95,6 +97,7 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand self.bodyStreamState.didEnd(), stream: stream ) + case .skipping: break } self.requestState = .ready } @@ -127,6 +130,20 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand context.fireErrorCaught(error) } + func channelInactive(context: ChannelHandlerContext) { + switch self.requestState { + case .streamingBody(let stream): + self.handleBodyStreamStateResult( + context: context, + self.bodyStreamState.didEnd(), + stream: stream + ) + default: + break + } + context.fireChannelInactive() + } + func handleBodyStreamStateResult( context: ChannelHandlerContext, _ result: HTTPBodyStreamState.Result, @@ -143,13 +160,13 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand self.bodyStreamState.didError(error), stream: stream ) - case .success: - self.handleBodyStreamStateResult( - context: context, - self.bodyStreamState.didWrite(), - stream: stream - ) + case .success: break } + self.handleBodyStreamStateResult( + context: context, + self.bodyStreamState.didWrite(), + stream: stream + ) } case .close(let maybeError): if let error = maybeError { @@ -162,6 +179,30 @@ final class HTTPServerRequestDecoder: ChannelDuplexHandler, RemovableChannelHand context.read() } } + + func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + if event is HTTPServerResponseEncoder.ResponseEndSentEvent { + switch self.requestState { + case .streamingBody(let bodyStream): + // Response ended during request stream. + if !bodyStream.isBeingRead { + self.logger.trace("Response already sent, draining unhandled request stream.") + bodyStream.read { _, promise in + promise?.succeed(()) + } + } + case .awaitingEnd, .awaitingBody: + // Response ended before request started streaming. + self.logger.trace("Response already sent, skipping request body.") + self.requestState = .skipping + case .ready: + // Response ended after request had been read. + break + default: + fatalError("Unexpected request state: \(self.requestState)") + } + } + } } struct HTTPBodyStreamState: CustomStringConvertible { diff --git a/Sources/Vapor/HTTP/Server/HTTPServerResponseEncoder.swift b/Sources/Vapor/HTTP/Server/HTTPServerResponseEncoder.swift index 17f9ae33b9..6098d4602a 100644 --- a/Sources/Vapor/HTTP/Server/HTTPServerResponseEncoder.swift +++ b/Sources/Vapor/HTTP/Server/HTTPServerResponseEncoder.swift @@ -8,6 +8,8 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH /// Optional server header. private let serverHeader: String? private let dateCache: RFC1123DateCache + + struct ResponseEndSentEvent { } init(serverHeader: String?, dateCache: RFC1123DateCache) { self.serverHeader = serverHeader @@ -30,14 +32,17 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH status: response.status, headers: response.headers ))), promise: nil) + if response.status == .noContent || response.forHeadRequest { // don't send bodies for 204 (no content) responses // or HEAD requests + context.fireUserInboundEventTriggered(ResponseEndSentEvent()) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: promise) } else { switch response.body.storage { case .none: + context.fireUserInboundEventTriggered(ResponseEndSentEvent()) context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: promise) case .buffer(let buffer): self.writeAndflush(buffer: buffer, context: context, promise: promise) @@ -58,7 +63,12 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH buffer.writeDispatchData(data) self.writeAndflush(buffer: buffer, context: context, promise: promise) case .stream(let stream): - let channelStream = ChannelResponseBodyStream(context: context, handler: self, promise: promise) + let channelStream = ChannelResponseBodyStream( + context: context, + handler: self, + promise: promise, + count: stream.count == -1 ? nil : stream.count + ) stream.callback(channelStream) } } @@ -67,31 +77,71 @@ final class HTTPServerResponseEncoder: ChannelOutboundHandler, RemovableChannelH /// Writes a `ByteBuffer` to the context. private func writeAndflush(buffer: ByteBuffer, context: ChannelHandlerContext, promise: EventLoopPromise?) { if buffer.readableBytes > 0 { - _ = context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))) + context.write(wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) } + context.fireUserInboundEventTriggered(ResponseEndSentEvent()) context.writeAndFlush(wrapOutboundOut(.end(nil)), promise: promise) } } -private struct ChannelResponseBodyStream: BodyStreamWriter { +private final class ChannelResponseBodyStream: BodyStreamWriter { let context: ChannelHandlerContext let handler: HTTPServerResponseEncoder let promise: EventLoopPromise? + let count: Int? + var currentCount: Int + var isComplete: Bool var eventLoop: EventLoop { return self.context.eventLoop } + + enum Error: Swift.Error { + case tooManyBytes + case notEnoughBytes + } + + init( + context: ChannelHandlerContext, + handler: HTTPServerResponseEncoder, + promise: EventLoopPromise?, + count: Int? + ) { + self.context = context + self.handler = handler + self.promise = promise + self.count = count + self.currentCount = 0 + self.isComplete = false + } func write(_ result: BodyStreamResult, promise: EventLoopPromise?) { switch result { case .buffer(let buffer): self.context.writeAndFlush(self.handler.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: promise) + self.currentCount += buffer.readableBytes + if let count = self.count, self.currentCount > count { + self.promise?.fail(Error.tooManyBytes) + promise?.fail(Error.notEnoughBytes) + } case .end: - self.promise?.succeed(()) + self.isComplete = true + if let count = self.count, self.currentCount != count { + self.promise?.fail(Error.notEnoughBytes) + promise?.fail(Error.notEnoughBytes) + } + self.context.fireUserInboundEventTriggered(HTTPServerResponseEncoder.ResponseEndSentEvent()) self.context.writeAndFlush(self.handler.wrapOutboundOut(.end(nil)), promise: promise) + self.promise?.succeed(()) case .error(let error): - self.promise?.fail(error) + self.isComplete = true + self.context.fireUserInboundEventTriggered(HTTPServerResponseEncoder.ResponseEndSentEvent()) self.context.writeAndFlush(self.handler.wrapOutboundOut(.end(nil)), promise: promise) + self.promise?.fail(error) } } + + deinit { + assert(self.isComplete, "Response body stream writer deinitialized before .end or .error was sent.") + } } diff --git a/Sources/Vapor/Request/Request+Body.swift b/Sources/Vapor/Request/Request+Body.swift index 7e1fb97073..c2eecf108e 100644 --- a/Sources/Vapor/Request/Request+Body.swift +++ b/Sources/Vapor/Request/Request+Body.swift @@ -30,7 +30,8 @@ extension Request { case .collected(let buffer): _ = handler(.buffer(buffer)) _ = handler(.end) - case .none: break + case .none: + _ = handler(.end) } } diff --git a/Sources/Vapor/Request/Request+BodyStream.swift b/Sources/Vapor/Request/Request+BodyStream.swift index 402fdc933d..f499a7f0cb 100644 --- a/Sources/Vapor/Request/Request+BodyStream.swift +++ b/Sources/Vapor/Request/Request+BodyStream.swift @@ -7,6 +7,10 @@ extension Request { let eventLoop: EventLoop + var isBeingRead: Bool { + self.handler != nil + } + init(on eventLoop: EventLoop) { self.eventLoop = eventLoop self.isClosed = false diff --git a/Sources/Vapor/Response/Response+Body.swift b/Sources/Vapor/Response/Response+Body.swift index 9a6727b186..8308774287 100644 --- a/Sources/Vapor/Response/Response+Body.swift +++ b/Sources/Vapor/Response/Response+Body.swift @@ -37,7 +37,7 @@ extension Response { } /// The size of the HTTP body's data. - /// `nil` is a stream. + /// `-1` is a chunked stream. public var count: Int { switch self.storage { case .data(let data): return data.count @@ -147,6 +147,10 @@ extension Response { public init(stream: @escaping (BodyStreamWriter) -> (), count: Int) { self.storage = .stream(.init(count: count, callback: stream)) } + + public init(stream: @escaping (BodyStreamWriter) -> ()) { + self.init(stream: stream, count: -1) + } /// `ExpressibleByStringLiteral` conformance. public init(stringLiteral value: String) { diff --git a/Sources/Vapor/Response/Response.swift b/Sources/Vapor/Response/Response.swift index 277cfde913..5eae6e6ef3 100644 --- a/Sources/Vapor/Response/Response.swift +++ b/Sources/Vapor/Response/Response.swift @@ -162,9 +162,17 @@ public final class Response: CustomStringConvertible { extension HTTPHeaders { mutating func updateContentLength(_ contentLength: Int) { let count = contentLength.description - self.remove(name: .transferEncoding) - if count != self[.contentLength].first { - self.replaceOrAdd(name: .contentLength, value: count) + switch contentLength { + case -1: + self.remove(name: .contentLength) + if "chunked" != self.first(name: .transferEncoding) { + self.add(name: .transferEncoding, value: "chunked") + } + default: + self.remove(name: .transferEncoding) + if count != self.first(name: .contentLength) { + self.replaceOrAdd(name: .contentLength, value: count) + } } } } diff --git a/Tests/VaporTests/ClientTests.swift b/Tests/VaporTests/ClientTests.swift index 447a64f542..14d782c776 100644 --- a/Tests/VaporTests/ClientTests.swift +++ b/Tests/VaporTests/ClientTests.swift @@ -71,11 +71,7 @@ final class ClientTests: XCTestCase { } func testBoilerplateClient() throws { - let app = Application(.init( - name: "xctest", - arguments: ["vapor", "serve", "-b", "localhost:8080", "--log", "trace"] - )) - try LoggingSystem.bootstrap(from: &app.environment) + let app = Application(.testing) defer { app.shutdown() } app.get("foo") { req -> EventLoopFuture in diff --git a/Tests/VaporTests/PipelineTests.swift b/Tests/VaporTests/PipelineTests.swift new file mode 100644 index 0000000000..2ed33e2680 --- /dev/null +++ b/Tests/VaporTests/PipelineTests.swift @@ -0,0 +1,118 @@ +@testable import Vapor +import XCTest + +final class PipelineTests: XCTestCase { + func testEchoHandlers() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.on(.POST, "echo", body: .stream) { request -> Response in + Response(body: .init(stream: { writer in + request.body.drain { body in + switch body { + case .buffer(let buffer): + return writer.write(.buffer(buffer)) + case .error(let error): + return writer.write(.error(error)) + case .end: + return writer.write(.end) + } + } + })) + } + + let channel = EmbeddedChannel() + try channel.pipeline.addVaporHTTP1Handlers( + application: app, + responder: app.responder, + configuration: app.http.server.configuration + ).wait() + + try channel.writeInbound(ByteBuffer(string: "POST /echo HTTP/1.1\r\ntransfer-encoding: chunked\r\n\r\n1\r\na\r\n")) + try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string) + + try channel.writeInbound(ByteBuffer(string: "1\r\nb\r\n")) + let chunk = try channel.readOutbound(as: ByteBuffer.self)?.string + XCTAssertContains(chunk, "HTTP/1.1 200 OK") + XCTAssertContains(chunk, "connection: keep-alive") + XCTAssertContains(chunk, "transfer-encoding: chunked") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "1\r\n") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "a") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "\r\n") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "1\r\n") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "b") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "\r\n") + try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string) + + try channel.writeInbound(ByteBuffer(string: "1\r\nc\r\n")) + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "1\r\n") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "c") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "\r\n") + try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string) + + try channel.writeInbound(ByteBuffer(string: "0\r\n\r\n")) + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "0\r\n\r\n") + try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string) + } + + func testEOFFraming() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.on(.POST, "echo", body: .stream) { request -> Response in + Response(body: .init(stream: { writer in + request.body.drain { body in + switch body { + case .buffer(let buffer): + return writer.write(.buffer(buffer)) + case .error(let error): + return writer.write(.error(error)) + case .end: + return writer.write(.end) + } + } + })) + } + + let channel = EmbeddedChannel() + try channel.pipeline.addVaporHTTP1Handlers( + application: app, + responder: app.responder, + configuration: app.http.server.configuration + ).wait() + + try channel.writeInbound(ByteBuffer(string: "POST /echo HTTP/1.1\r\n\r\n")) + try XCTAssertContains(channel.readOutbound(as: ByteBuffer.self)?.string, "HTTP/1.1 200 OK") + } + + func testBadStreamLength() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.on(.POST, "echo", body: .stream) { request -> Response in + Response(body: .init(stream: { writer in + writer.write(.buffer(.init(string: "a")), promise: nil) + writer.write(.end, promise: nil) + }, count: 2)) + } + + let channel = EmbeddedChannel() + try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait() + try channel.pipeline.addVaporHTTP1Handlers( + application: app, + responder: app.responder, + configuration: app.http.server.configuration + ).wait() + + XCTAssertEqual(channel.isActive, true) + try channel.writeInbound(ByteBuffer(string: "POST /echo HTTP/1.1\r\n\r\n")) + XCTAssertEqual(channel.isActive, false) + try XCTAssertContains(channel.readOutbound(as: ByteBuffer.self)?.string, "HTTP/1.1 200 OK") + try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "a") + try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string) + } + + override class func setUp() { + XCTAssert(isLoggingConfigured) + } +} diff --git a/Tests/VaporTests/ServerTests.swift b/Tests/VaporTests/ServerTests.swift index a8e39a647b..0c807bc54a 100644 --- a/Tests/VaporTests/ServerTests.swift +++ b/Tests/VaporTests/ServerTests.swift @@ -1,5 +1,6 @@ import Vapor import XCTest +import protocol AsyncHTTPClient.HTTPClientResponseDelegate final class ServerTests: XCTestCase { func testPortOverride() throws { @@ -201,6 +202,122 @@ final class ServerTests: XCTestCase { XCTAssertEqual(res.status, .ok) }) } + + func testEchoServer() throws { + let app = Application(.testing) + defer { app.shutdown() } + + final class Context { + var server: [String] + var client: [String] + init() { + self.server = [] + self.client = [] + } + } + let context = Context() + + app.on(.POST, "echo", body: .stream) { request -> Response in + Response(body: .init(stream: { writer in + request.body.drain { body in + switch body { + case .buffer(let buffer): + context.server.append(buffer.string) + return writer.write(.buffer(buffer)) + case .error(let error): + return writer.write(.error(error)) + case .end: + return writer.write(.end) + } + } + })) + } + + let port = 1337 + app.http.server.configuration.port = port + try app.start() + + let request = try HTTPClient.Request( + url: "http://localhost:\(port)/echo", + method: .POST, + headers: [ + "transfer-encoding": "chunked" + ], + body: .stream(length: nil, { stream in + stream.write(.byteBuffer(.init(string: "foo"))).flatMap { + stream.write(.byteBuffer(.init(string: "bar"))) + }.flatMap { + stream.write(.byteBuffer(.init(string: "baz"))) + } + }) + ) + + final class ResponseDelegate: HTTPClientResponseDelegate { + typealias Response = HTTPClient.Response + + let context: Context + init(context: Context) { + self.context = context + } + + func didReceiveBodyPart( + task: HTTPClient.Task, + _ buffer: ByteBuffer + ) -> EventLoopFuture { + self.context.client.append(buffer.string) + return task.eventLoop.makeSucceededFuture(()) + } + + func didFinishRequest(task: HTTPClient.Task) throws -> HTTPClient.Response { + .init(host: "", status: .ok, headers: [:], body: nil) + } + } + let response = ResponseDelegate(context: context) + _ = try app.http.client.shared.execute( + request: request, + delegate: response + ).wait() + + XCTAssertEqual(context.server, ["foo", "bar", "baz"]) + XCTAssertEqual(context.client, ["foo", "bar", "baz"]) + } + + func testSkipStreaming() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.on(.POST, "echo", body: .stream) { request in + "hello, world" + } + + let port = 1337 + app.http.server.configuration.port = port + try app.start() + + let request = try HTTPClient.Request( + url: "http://localhost:\(port)/echo", + method: .POST, + headers: [ + "transfer-encoding": "chunked" + ], + body: .stream(length: nil, { stream in + stream.write(.byteBuffer(.init(string: "foo"))).flatMap { + stream.write(.byteBuffer(.init(string: "bar"))) + }.flatMap { + stream.write(.byteBuffer(.init(string: "baz"))) + } + }) + ) + + let a = try app.http.client.shared.execute(request: request).wait() + XCTAssertEqual(a.status, .ok) + let b = try app.http.client.shared.execute(request: request).wait() + XCTAssertEqual(b.status, .ok) + } + + override class func setUp() { + XCTAssertTrue(isLoggingConfigured) + } } extension Application.Servers.Provider { diff --git a/Tests/VaporTests/Utilities/TestLogging.swift b/Tests/VaporTests/Utilities/TestLogging.swift new file mode 100644 index 0000000000..823c8139fb --- /dev/null +++ b/Tests/VaporTests/Utilities/TestLogging.swift @@ -0,0 +1,15 @@ +import Foundation +import Logging + +let isLoggingConfigured: Bool = { + LoggingSystem.bootstrap { label in + var handler = StreamLogHandler.standardOutput(label: label) + handler.logLevel = env("LOG_LEVEL").flatMap { Logger.Level(rawValue: $0) } ?? .debug + return handler + } + return true +}() + +func env(_ name: String) -> String? { + ProcessInfo.processInfo.environment[name] +}