diff --git a/Sources/Vapor/Routing/Routes.swift b/Sources/Vapor/Routing/Routes.swift index 8db15565d6..63bd978009 100644 --- a/Sources/Vapor/Routing/Routes.swift +++ b/Sources/Vapor/Routing/Routes.swift @@ -1,14 +1,18 @@ public final class Routes: RoutesBuilder, CustomStringConvertible { public var all: [Route] + + /// Default value used by `HTTPBodyStreamStrategy.collect` when `maxSize` is `nil`. + public var defaultMaxBodySize: ByteCount public var description: String { return self.all.description } - + public init() { self.all = [] + self.defaultMaxBodySize = "16kb" } - + public func add(_ route: Route) { self.all.append(route) } diff --git a/Sources/Vapor/Routing/RoutesBuilder+Group.swift b/Sources/Vapor/Routing/RoutesBuilder+Group.swift index 0ca8939eb7..90f06d457c 100644 --- a/Sources/Vapor/Routing/RoutesBuilder+Group.swift +++ b/Sources/Vapor/Routing/RoutesBuilder+Group.swift @@ -40,7 +40,7 @@ private final class HTTPRoutesGroup: RoutesBuilder { /// Additional components. let path: [PathComponent] - + /// Creates a new `PathGroup`. init(root: RoutesBuilder, path: [PathComponent]) { self.root = root diff --git a/Sources/Vapor/Routing/RoutesBuilder+Method.swift b/Sources/Vapor/Routing/RoutesBuilder+Method.swift index b2d8b77958..958e9c94df 100644 --- a/Sources/Vapor/Routing/RoutesBuilder+Method.swift +++ b/Sources/Vapor/Routing/RoutesBuilder+Method.swift @@ -1,12 +1,12 @@ /// Determines how an incoming HTTP request's body is collected. public enum HTTPBodyStreamStrategy { - /// The HTTP request's body will be collected into memory before the route handler is - /// called. The max size will determine how much data can be collected. The default is - /// `1 << 14`. + /// The HTTP request's body will be collected into memory up to a maximum size + /// before the route handler is called. The application's configured default max body + /// size will be used unless otherwise specified. /// - /// See `collect(maxSize:)` to set a lower max body size. + /// See `collect(maxSize:)` to specify a custom max collection size. public static var collect: HTTPBodyStreamStrategy { - return .collect(maxSize: 1 << 14) + return .collect(maxSize: nil) } /// The HTTP request's body will not be collected first before the route handler is called @@ -16,9 +16,11 @@ public enum HTTPBodyStreamStrategy { /// The HTTP request's body will be collected into memory before the route handler is /// called. /// - /// If a `maxSize` is supplied, the request body size in bytes will be limited. Requests - /// exceeding that size will result in an error. - case collect(maxSize: Int?) + /// `maxSize` Limits the maximum amount of memory in bytes that will be used to + /// collect a streaming body. Streaming requests exceeding that size will result in an error. + /// Passing `nil` results in the application's default max body size being used. This + /// parameter does not affect non-streaming requests. + case collect(maxSize: ByteCount?) } extension RoutesBuilder { @@ -147,8 +149,10 @@ extension RoutesBuilder { { let responder = BasicResponder { request in if case .collect(let max) = body, request.body.data == nil { - return request.body.collect(max: max).flatMapThrowing { _ in - return try closure(request) + return request.body.collect( + max: max?.value ?? request.application.routes.defaultMaxBodySize.value + ).flatMapThrowing { _ in + try closure(request) }.encodeResponse(for: request) } else { return try closure(request) diff --git a/Sources/Vapor/Utilities/ByteCount.swift b/Sources/Vapor/Utilities/ByteCount.swift new file mode 100644 index 0000000000..b18e8cda76 --- /dev/null +++ b/Sources/Vapor/Utilities/ByteCount.swift @@ -0,0 +1,66 @@ +import Foundation + +/// Represents a number of bytes: +/// +/// let bytes: ByteCount = "1mb" +/// print(bytes.value) // 1048576 +/// +/// let bytes: ByteCount = 1_000_000 +/// print(bytes.value) // 1000000 + +/// let bytes: ByteCount = "2kb" +/// print(bytes.value) // 2048 +public struct ByteCount: Equatable { + /// The value in Bytes + public let value: Int + + public init(value: Int) { + self.value = value + } +} + +extension ByteCount: ExpressibleByIntegerLiteral { + /// Initializes the `ByteCount` with the raw byte count + /// - Parameter value: The number of bytes + public init(integerLiteral value: Int) { + self.value = value + } +} + +extension ByteCount: ExpressibleByStringLiteral { + /// Initializes the `ByteCount` via a descriptive string. Available suffixes are: + /// `kb`, `mb`, `gb`, `tb` + /// - Parameter value: The string value (`1mb`) + public init(stringLiteral value: String) { + // Short path if it's an int wrapped in a string + if let intValue = Int(value) { + self.value = intValue + return + } + + let validSuffixes = [ + "kb": 10, + "mb": 20, + "gb": 30, + "tb": 40 + ] + + let cleanValue = value.lowercased().trimmingCharacters(in: .whitespaces).replacingOccurrences(of: " ", with: "") + for suffix in validSuffixes { + guard cleanValue.hasSuffix(suffix.key) else { continue } + guard let stringIntValue = cleanValue.components(separatedBy: suffix.key).first else { + fatalError("Invalid string format") + } + + guard let intValue = Int(stringIntValue) else { + fatalError("Invalid int value: \(stringIntValue)") + } + + self.value = intValue << suffix.value + return + } + + // Assert failure here because all cases are handled in the above loop + fatalError("Could not parse byte count string: \(value)") + } +} diff --git a/Tests/VaporTests/RouteTests.swift b/Tests/VaporTests/RouteTests.swift index 7457bd0fbe..d83de880f1 100644 --- a/Tests/VaporTests/RouteTests.swift +++ b/Tests/VaporTests/RouteTests.swift @@ -268,4 +268,38 @@ final class RouteTests: XCTestCase { XCTAssertEqual(res.body.string, "bar") } } + + func testConfigurableMaxBodySize() throws { + let app = Application(.testing) + defer { app.shutdown() } + + XCTAssertEqual(app.routes.defaultMaxBodySize, 16384) + app.routes.defaultMaxBodySize = 1 + XCTAssertEqual(app.routes.defaultMaxBodySize, 1) + + app.on(.POST, "default") { request in + HTTPStatus.ok + } + app.on(.POST, "1kb", body: .collect(maxSize: "1kb")) { request in + HTTPStatus.ok + } + app.on(.POST, "1mb", body: .collect(maxSize: "1mb")) { request in + HTTPStatus.ok + } + app.on(.POST, "1gb", body: .collect(maxSize: "1gb")) { request in + HTTPStatus.ok + } + + var buffer = ByteBufferAllocator().buffer(capacity: 0) + buffer.writeBytes(Array(repeating: 0, count: 500_000)) + try app.testable(method: .running).test(.POST, "/default", body: buffer) { res in + XCTAssertEqual(res.status, .payloadTooLarge) + }.test(.POST, "/1kb", body: buffer) { res in + XCTAssertEqual(res.status, .payloadTooLarge) + }.test(.POST, "/1mb", body: buffer) { res in + XCTAssertEqual(res.status, .ok) + }.test(.POST, "/1gb", body: buffer) { res in + XCTAssertEqual(res.status, .ok) + } + } } diff --git a/Tests/VaporTests/ServerTests.swift b/Tests/VaporTests/ServerTests.swift index 81846c3ab4..b030ab3262 100644 --- a/Tests/VaporTests/ServerTests.swift +++ b/Tests/VaporTests/ServerTests.swift @@ -95,7 +95,7 @@ final class ServerTests: XCTestCase { let payload = [UInt8].random(count: 1 << 20) - app.on(.POST, "payload", body: .collect(maxSize: nil)) { req -> HTTPStatus in + app.on(.POST, "payload", body: .collect(maxSize: "1gb")) { req -> HTTPStatus in guard let data = req.body.data else { throw Abort(.internalServerError) } diff --git a/Tests/VaporTests/UtilityTests.swift b/Tests/VaporTests/UtilityTests.swift index 03a2d0a7b8..a218cfddd6 100644 --- a/Tests/VaporTests/UtilityTests.swift +++ b/Tests/VaporTests/UtilityTests.swift @@ -21,4 +21,24 @@ final class UtilityTests: XCTestCase { XCTAssertEqual(data.base32EncodedString(), "AEBAGBA") XCTAssertEqual(Data(base32Encoded: "AEBAGBA"), data) } + + func testByteCount() throws { + let twoKbUpper: ByteCount = "2 KB" + XCTAssertEqual(twoKbUpper.value, 2_048) + + let twoKb: ByteCount = "2kb" + XCTAssertEqual(twoKb.value, 2_048) + + let oneMb: ByteCount = "1mb" + XCTAssertEqual(oneMb.value, 1_048_576) + + let oneGb: ByteCount = "1gb" + XCTAssertEqual(oneGb.value, 1_073_741_824) + + let oneTb: ByteCount = "1tb" + XCTAssertEqual(oneTb.value, 1_099_511_627_776) + + let intBytes: ByteCount = 1_000_000 + XCTAssertEqual(intBytes.value, 1_000_000) + } }