Skip to content

Commit

Permalink
HTTPServerErrorHandler: Improve Error Handling for HTTPParserError (#…
Browse files Browse the repository at this point in the history
…2922)

* fix: HTTPServerErrorHandler catch HTTPParserError
Refs: #2921
* add reference to https://github.com/apple/swift-nio/blob/00341c92770e0a7bebdc5fda783f08765eb3ff56/Sources/NIOHTTP1/HTTPServerProtocolErrorHandler.swift
* set hasUnterminatedResponse in write
* error-handler needs to be before response encoder
* add test for invalid http
* rephrase and improve comments slightly
* remove http1 error handler from http2 pipeline
---------

Co-authored-by: Tim Condon <0xTim@users.noreply.github.com>
Co-authored-by: Gwynne Raskind <gwynne@vapor.codes>
  • Loading branch information
3 people committed Sep 22, 2023
1 parent d79fad4 commit 1f2b44b
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 7 deletions.
79 changes: 73 additions & 6 deletions Sources/Vapor/HTTP/Server/HTTPServer.swift
Expand Up @@ -452,17 +452,80 @@ private final class HTTPServerConnection: Sendable {
}
}

final class HTTPServerErrorHandler: ChannelInboundHandler {
/// A simple channel handler that catches errors emitted by parsing HTTP requests
/// and sends 400 Bad Request responses.
///
/// This channel handler provides the basic behaviour that the majority of simple HTTP
/// servers want. This handler does not suppress the parser errors: it allows them to
/// continue to pass through the pipeline so that other handlers (e.g. logging ones) can
/// deal with the error.
///
/// adapted from: https://github.com/apple/swift-nio/blob/00341c92770e0a7bebdc5fda783f08765eb3ff56/Sources/NIOHTTP1/HTTPServerProtocolErrorHandler.swift
final class HTTP1ServerErrorHandler: ChannelDuplexHandler, RemovableChannelHandler {
typealias InboundIn = Never
typealias InboundOut = Never
typealias OutboundIn = HTTPServerResponsePart
typealias OutboundOut = HTTPServerResponsePart
let logger: Logger
private var hasUnterminatedResponse: Bool = false

init(logger: Logger) {
self.logger = logger
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
self.logger.debug("Unhandled HTTP server error: \(error)")
context.close(mode: .output, promise: nil)
if let error = error as? HTTPParserError {
self.makeHTTPParserErrorResponse(context: context, error: error)
}

// Now pass the error on in case someone else wants to see it.
// In the Vapor ChannelPipeline the connection will eventually
// be closed by the NIOCloseOnErrorHandler
context.fireErrorCaught(error)
}

private func makeHTTPParserErrorResponse(context: ChannelHandlerContext, error: HTTPParserError) {
// Any HTTPParserError is automatically fatal, and we don't actually need (or want) to
// provide that error to the client: we just want to inform them something went wrong
// and then close off the pipeline. However, we can only send an
// HTTP error response if another response hasn't started yet.
//
// A side note here: we cannot block or do any delayed work.
// The channel might be closed right after we return from this function.
if !self.hasUnterminatedResponse {
self.logger.debug("Bad Request - Invalid HTTP: \(error)")
let headers = HTTPHeaders([("Connection", "close"), ("Content-Length", "0")])
let head = HTTPResponseHead(version: .http1_1, status: .badRequest, headers: headers)
context.write(self.wrapOutboundOut(.head(head)), promise: nil)
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
}
}

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let res = self.unwrapOutboundIn(data)
switch res {
case .head(let head) where head.isInformational:
precondition(!self.hasUnterminatedResponse)
case .head:
precondition(!self.hasUnterminatedResponse)
self.hasUnterminatedResponse = true
case .body:
precondition(self.hasUnterminatedResponse)
case .end:
precondition(self.hasUnterminatedResponse)
self.hasUnterminatedResponse = false
}
context.write(data, promise: promise)
}
}

extension HTTPResponseHead {
/// Determines if the head is purely informational. If a head is informational another head will follow this
/// head eventually.
///
/// This is also from SwiftNIO
var isInformational: Bool {
100 <= self.status.code && self.status.code < 200 && self.status.code != 101
}
}

Expand Down Expand Up @@ -496,7 +559,8 @@ extension ChannelPipeline {
handlers.append(handler)

return self.addHandlers(handlers).flatMap {
self.addHandler(HTTPServerErrorHandler(logger: configuration.logger))
// close the connection in case of any errors
self.addHandler(NIOCloseOnErrorHandler())
}
}

Expand Down Expand Up @@ -544,6 +608,9 @@ extension ChannelPipeline {
break
}

let errorHandler = HTTP1ServerErrorHandler(logger: configuration.logger)
handlers.append(errorHandler)

// add NIO -> HTTP response encoder
let serverResEncoder = HTTPServerResponseEncoder(
serverHeader: configuration.serverName,
Expand All @@ -568,9 +635,9 @@ extension ChannelPipeline {
handlers.append(upgrader)
handlers.append(handler)

// wait to add delegate as final step
return self.addHandlers(handlers).flatMap {
self.addHandler(HTTPServerErrorHandler(logger: configuration.logger))
// close the connection in case of any errors
self.addHandler(NIOCloseOnErrorHandler())
}
}
}
Expand Down
30 changes: 29 additions & 1 deletion Tests/VaporTests/PipelineTests.swift
@@ -1,4 +1,5 @@
@testable import Vapor
import enum NIOHTTP1.HTTPParserError
import XCTest
import NIOEmbedded
import NIOCore
Expand Down Expand Up @@ -107,12 +108,39 @@ final class PipelineTests: XCTestCase {
).wait()

XCTAssertEqual(channel.isActive, true)
try channel.writeInbound(ByteBuffer(string: "POST /echo HTTP/1.1\r\n\r\n"))
// throws a notEnoughBytes error which is good
XCTAssertThrowsError(try channel.writeInbound(ByteBuffer(string: "POST /echo HTTP/1.1\r\n\r\n")))
XCTAssertEqual(channel.isActive, false)
try XCTAssertContains(channel.readOutbound(as: ByteBuffer.self)?.string, "HTTP/1.1 200 OK")
try XCTAssertEqual(channel.readOutbound(as: ByteBuffer.self)?.string, "a")
try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string)
}

func testInvalidHttp() throws {
let app = Application(.testing)
defer { app.shutdown() }

let channel = EmbeddedChannel()
try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait()
try channel.pipeline.addVaporHTTP1Handlers(
application: app,
responder: app.responder,
configuration: app.http.server.configuration
).wait()

XCTAssertEqual(channel.isActive, true)
let request = ByteBuffer(string: "POST /echo/þ HTTP/1.1\r\n\r\n")
XCTAssertThrowsError(try channel.writeInbound(request)) { error in
if let error = error as? HTTPParserError {
XCTAssertEqual(error, HTTPParserError.invalidURL)
} else {
XCTFail("Caught error \"\(error)\"")
}
}
XCTAssertEqual(channel.isActive, false)
try XCTAssertContains(channel.readOutbound(as: ByteBuffer.self)?.string, "HTTP/1.1 400 Bad Request")
try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string)
}

override class func setUp() {
XCTAssert(isLoggingConfigured)
Expand Down

0 comments on commit 1f2b44b

Please sign in to comment.