diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index c669438aaf..9f653ad2b4 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied <(T1, T2, T3, T4, T5, T6) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts> - and 26 others + and 25 others = note: required because of the requirements on the impl of `FromRequest<(), Body, axum_core::extract::private::ViaParts>` for `bool` note: required by a bound in `__axum_macros_check_handler_0_from_request_check` --> tests/debug_handler/fail/argument_not_extractor.rs:3:1 diff --git a/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.stderr b/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.stderr index 87f3ab36d1..03803249f3 100644 --- a/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.stderr +++ b/axum-macros/tests/debug_handler/fail/doesnt_implement_from_request_parts.stderr @@ -13,6 +13,6 @@ error[E0277]: the trait bound `String: FromRequestParts<()>` is not satisfied <(T1, T2, T3, T4, T5, T6) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts> - and 26 others + and 25 others = help: see issue #48214 = note: this error originates in the attribute macro `debug_handler` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr index 89a5ed55ad..ce8fd05a67 100644 --- a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr +++ b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr @@ -13,7 +13,7 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied (Response<()>, T1, T2, R) (Response<()>, T1, T2, T3, R) (Response<()>, T1, T2, T3, T4, R) - and 122 others + and 118 others note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` --> tests/debug_handler/fail/wrong_return_type.rs:4:23 | diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr index 2a2c804016..9d92f40a49 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr @@ -13,4 +13,4 @@ error[E0277]: the trait bound `String: FromRequestParts` is not satisfied <(T1, T2, T3, T4, T5, T6) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7) as FromRequestParts> <(T1, T2, T3, T4, T5, T6, T7, T8) as FromRequestParts> - and 27 others + and 26 others diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index a1323b47f3..bc9ca0e0fb 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -18,12 +18,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Add `middleware::from_extractor_with_state` and `middleware::from_extractor_with_state_arc` ([#1396]) - **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) +- **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400]) [#1371]: https://github.com/tokio-rs/axum/pull/1371 [#1387]: https://github.com/tokio-rs/axum/pull/1387 [#1389]: https://github.com/tokio-rs/axum/pull/1389 [#1396]: https://github.com/tokio-rs/axum/pull/1396 [#1397]: https://github.com/tokio-rs/axum/pull/1397 +[#1400]: https://github.com/tokio-rs/axum/pull/1400 # 0.6.0-rc.2 (10. September, 2022) diff --git a/axum/src/extract/content_length_limit.rs b/axum/src/extract/content_length_limit.rs deleted file mode 100644 index f5775842b8..0000000000 --- a/axum/src/extract/content_length_limit.rs +++ /dev/null @@ -1,274 +0,0 @@ -use super::{rejection::*, FromRequest}; -use async_trait::async_trait; -use axum_core::{extract::FromRequestParts, response::IntoResponse}; -use http::{request::Parts, Method, Request}; -use http_body::Limited; -use std::ops::Deref; - -/// Extractor that will reject requests with a body larger than some size. -/// -/// `GET`, `HEAD`, and `OPTIONS` requests are rejected if they have a `Content-Length` header, -/// otherwise they're accepted without the body being checked. -/// -/// Note: `ContentLengthLimit` can wrap types that extract the body (for example, [`Form`] or [`Json`]) -/// if that is the case, the inner type will consume the request's body, which means the -/// `ContentLengthLimit` must come *last* if the handler uses several extractors. See -/// ["the order of extractors"][order-of-extractors] -/// -/// [order-of-extractors]: crate::extract#the-order-of-extractors -/// [`Form`]: crate::form::Form -/// [`Json`]: crate::json::Json -/// -/// # Example -/// -/// ```rust,no_run -/// use axum::{ -/// extract::ContentLengthLimit, -/// routing::post, -/// Router, -/// }; -/// -/// async fn handler(body: ContentLengthLimit) { -/// // ... -/// } -/// -/// let app = Router::new().route("/", post(handler)); -/// # async { -/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); -/// # }; -/// ``` -#[derive(Debug, Clone)] -pub struct ContentLengthLimit(pub T); - -#[async_trait] -impl FromRequest for ContentLengthLimit -where - T: FromRequest + FromRequest, Rejection = R>, - R: IntoResponse + Send, - B: Send + 'static, - S: Send + Sync, -{ - type Rejection = ContentLengthLimitRejection; - - async fn from_request(req: Request, state: &S) -> Result { - let (parts, body) = req.into_parts(); - - let value = if let Some(err) = validate::(&parts).err() { - match err { - RequestValidationError::LengthRequiredStream => { - // `Limited` supports limiting streams, so use that instead since this is a - // streaming request - let body = Limited::new(body, N as usize); - let req = Request::from_parts(parts, body); - T::from_request(req, state) - .await - .map_err(ContentLengthLimitRejection::Inner)? - } - other => return Err(other.into()), - } - } else { - let req = Request::from_parts(parts, body); - T::from_request(req, state) - .await - .map_err(ContentLengthLimitRejection::Inner)? - }; - - Ok(Self(value)) - } -} - -#[async_trait] -impl FromRequestParts for ContentLengthLimit -where - T: FromRequestParts, - T::Rejection: IntoResponse, - S: Send + Sync, -{ - type Rejection = ContentLengthLimitRejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - validate::(parts)?; - - let value = T::from_request_parts(parts, state) - .await - .map_err(ContentLengthLimitRejection::Inner)?; - - Ok(Self(value)) - } -} - -fn validate(parts: &Parts) -> Result<(), RequestValidationError> { - let content_length = parts - .headers - .get(http::header::CONTENT_LENGTH) - .and_then(|value| value.to_str().ok()?.parse::().ok()); - - match (content_length, &parts.method) { - (content_length, &(Method::GET | Method::HEAD | Method::OPTIONS)) => { - if content_length.is_some() { - return Err(RequestValidationError::ContentLengthNotAllowed); - } else if parts - .headers - .get(http::header::TRANSFER_ENCODING) - .map_or(false, |value| value.as_bytes() == b"chunked") - { - return Err(RequestValidationError::LengthRequiredChunkedHeadOrGet); - } - } - (Some(content_length), _) if content_length > N => { - return Err(RequestValidationError::PayloadTooLarge); - } - (None, _) => { - return Err(RequestValidationError::LengthRequiredStream); - } - _ => {} - } - - Ok(()) -} - -impl Deref for ContentLengthLimit { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -/// Similar to `ContentLengthLimitRejection` but more fine grained in that we can tell the -/// difference between `LengthRequiredStream` and `LengthRequiredChunkedHeadOrGet` -enum RequestValidationError { - PayloadTooLarge, - LengthRequiredStream, - LengthRequiredChunkedHeadOrGet, - ContentLengthNotAllowed, -} - -impl From for ContentLengthLimitRejection { - fn from(inner: RequestValidationError) -> Self { - match inner { - RequestValidationError::PayloadTooLarge => Self::PayloadTooLarge(PayloadTooLarge), - RequestValidationError::LengthRequiredStream - | RequestValidationError::LengthRequiredChunkedHeadOrGet => { - Self::LengthRequired(LengthRequired) - } - RequestValidationError::ContentLengthNotAllowed => { - Self::ContentLengthNotAllowed(ContentLengthNotAllowed) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - body::Bytes, - routing::{get, post}, - test_helpers::*, - Router, - }; - use http::StatusCode; - use serde::Deserialize; - - #[tokio::test] - async fn body_with_length_limit() { - use std::iter::repeat; - - #[derive(Debug, Deserialize)] - #[allow(dead_code)] - struct Input { - foo: String, - } - - const LIMIT: u64 = 8; - - let app = Router::new().route( - "/", - post(|_body: ContentLengthLimit| async {}), - ); - - let client = TestClient::new(app); - let res = client - .post("/") - .body(repeat(0_u8).take((LIMIT - 1) as usize).collect::>()) - .send() - .await; - assert_eq!(res.status(), StatusCode::OK); - - let res = client - .post("/") - .body(repeat(0_u8).take(LIMIT as usize).collect::>()) - .send() - .await; - assert_eq!(res.status(), StatusCode::OK); - - let res = client - .post("/") - .body(repeat(0_u8).take((LIMIT + 1) as usize).collect::>()) - .send() - .await; - assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); - - let chunk = repeat(0_u8).take(LIMIT as usize).collect::(); - let res = client - .post("/") - .body(reqwest::Body::wrap_stream(futures_util::stream::iter( - vec![Ok::<_, std::io::Error>(chunk)], - ))) - .send() - .await; - assert_eq!(res.status(), StatusCode::OK); - - let chunk = repeat(0_u8).take((LIMIT + 1) as usize).collect::(); - let res = client - .post("/") - .body(reqwest::Body::wrap_stream(futures_util::stream::iter( - vec![Ok::<_, std::io::Error>(chunk)], - ))) - .send() - .await; - assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); - } - - #[tokio::test] - async fn get_request_without_content_length_is_accepted() { - let app = Router::new().route("/", get(|_body: ContentLengthLimit| async {})); - - let client = TestClient::new(app); - - let res = client.get("/").send().await; - assert_eq!(res.status(), StatusCode::OK); - } - - #[tokio::test] - async fn get_request_with_content_length_is_rejected() { - let app = Router::new().route("/", get(|_body: ContentLengthLimit| async {})); - - let client = TestClient::new(app); - - let res = client - .get("/") - .header("content-length", 3) - .body("foo") - .send() - .await; - assert_eq!(res.status(), StatusCode::BAD_REQUEST); - } - - #[tokio::test] - async fn get_request_with_chunked_encoding_is_rejected() { - let app = Router::new().route("/", get(|_body: ContentLengthLimit| async {})); - - let client = TestClient::new(app); - - let res = client - .get("/") - .header("transfer-encoding", "chunked") - .body("3\r\nfoo\r\n0\r\n\r\n") - .send() - .await; - - assert_eq!(res.status(), StatusCode::LENGTH_REQUIRED); - } -} diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index c6e8bf53fc..698f8b175b 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -9,7 +9,6 @@ pub mod rejection; #[cfg(feature = "ws")] pub mod ws; -mod content_length_limit; mod host; mod raw_query; mod request_parts; @@ -25,7 +24,6 @@ pub use axum_macros::{FromRequest, FromRequestParts}; #[allow(deprecated)] pub use self::{ connect_info::ConnectInfo, - content_length_limit::ContentLengthLimit, host::Host, path::Path, raw_query::RawQuery, diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 9d2a47c92d..65813abdbd 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -49,7 +49,8 @@ use std::{ /// ``` /// /// For security reasons it's recommended to combine this with -/// [`ContentLengthLimit`](super::ContentLengthLimit) to limit the size of the request payload. +/// [`RequestBodyLimitLayer`](tower_http::limit::RequestBodyLimitLayer) +/// to limit the size of the request payload. #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] #[derive(Debug)] pub struct Multipart { diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index d6f11c31fa..01f8d84918 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -47,30 +47,6 @@ define_rejection! { pub struct MissingExtension(Error); } -define_rejection! { - #[status = PAYLOAD_TOO_LARGE] - #[body = "Request payload is too large"] - /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if - /// the request body is too large. - pub struct PayloadTooLarge; -} - -define_rejection! { - #[status = LENGTH_REQUIRED] - #[body = "Content length header is required"] - /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if - /// the request is missing the `Content-Length` header or it is invalid. - pub struct LengthRequired; -} - -define_rejection! { - #[status = BAD_REQUEST] - #[body = "`GET`, `HEAD`, `OPTIONS` requests are not allowed to have a `Content-Length` header"] - /// Rejection type for [`ContentLengthLimit`](super::ContentLengthLimit) if - /// the request is `GET`, `HEAD`, or `OPTIONS` and has a `Content-Length` header. - pub struct ContentLengthNotAllowed; -} - define_rejection! { #[status = INTERNAL_SERVER_ERROR] #[body = "No paths parameters found for matched route"] @@ -216,64 +192,5 @@ composite_rejection! { } } -/// Rejection used for [`ContentLengthLimit`](super::ContentLengthLimit). -/// -/// Contains one variant for each way the -/// [`ContentLengthLimit`](super::ContentLengthLimit) extractor can fail. -#[derive(Debug)] -#[non_exhaustive] -pub enum ContentLengthLimitRejection { - #[allow(missing_docs)] - PayloadTooLarge(PayloadTooLarge), - #[allow(missing_docs)] - LengthRequired(LengthRequired), - #[allow(missing_docs)] - ContentLengthNotAllowed(ContentLengthNotAllowed), - #[allow(missing_docs)] - Inner(T), -} - -impl IntoResponse for ContentLengthLimitRejection -where - T: IntoResponse, -{ - fn into_response(self) -> Response { - match self { - Self::PayloadTooLarge(inner) => inner.into_response(), - Self::LengthRequired(inner) => inner.into_response(), - Self::ContentLengthNotAllowed(inner) => inner.into_response(), - Self::Inner(inner) => inner.into_response(), - } - } -} - -impl std::fmt::Display for ContentLengthLimitRejection -where - T: std::fmt::Display, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::PayloadTooLarge(inner) => inner.fmt(f), - Self::LengthRequired(inner) => inner.fmt(f), - Self::ContentLengthNotAllowed(inner) => inner.fmt(f), - Self::Inner(inner) => inner.fmt(f), - } - } -} - -impl std::error::Error for ContentLengthLimitRejection -where - T: std::error::Error + 'static, -{ - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - Self::PayloadTooLarge(inner) => Some(inner), - Self::LengthRequired(inner) => Some(inner), - Self::ContentLengthNotAllowed(inner) => Some(inner), - Self::Inner(inner) => Some(inner), - } - } -} - #[cfg(feature = "headers")] pub use crate::typed_header::{TypedHeaderRejection, TypedHeaderRejectionReason}; diff --git a/examples/key-value-store/Cargo.toml b/examples/key-value-store/Cargo.toml index e4c9063cbf..73fb191b5d 100644 --- a/examples/key-value-store/Cargo.toml +++ b/examples/key-value-store/Cargo.toml @@ -8,6 +8,12 @@ publish = false axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tower = { version = "0.4", features = ["util", "timeout", "load-shed", "limit"] } -tower-http = { version = "0.3.0", features = ["add-extension", "auth", "compression-full", "trace"] } +tower-http = { version = "0.3.0", features = [ + "add-extension", + "auth", + "compression-full", + "limit", + "trace", +] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 4f46af7430..1d33b36447 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -9,7 +9,7 @@ use axum::{ body::Bytes, error_handling::HandleErrorLayer, - extract::{ContentLengthLimit, Path, State}, + extract::{DefaultBodyLimit, Path, State}, handler::Handler, http::StatusCode, response::IntoResponse, @@ -25,7 +25,8 @@ use std::{ }; use tower::{BoxError, ServiceBuilder}; use tower_http::{ - auth::RequireAuthorizationLayer, compression::CompressionLayer, trace::TraceLayer, + auth::RequireAuthorizationLayer, compression::CompressionLayer, limit::RequestBodyLimitLayer, + trace::TraceLayer, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -48,7 +49,12 @@ async fn main() { // Add compression to `kv_get` get(kv_get.layer(CompressionLayer::new())) // But don't compress `kv_set` - .post(kv_set), + .post_service( + ServiceBuilder::new() + .layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new(1024 * 5_000 /* ~5mb */)) + .service(kv_set.with_state(Arc::clone(&shared_state))), + ), ) .route("/keys", get(list_keys)) // Nest our admin routes under `/admin` @@ -94,11 +100,7 @@ async fn kv_get( } } -async fn kv_set( - Path(key): Path, - State(state): State, - ContentLengthLimit(bytes): ContentLengthLimit, // ~5mb -) { +async fn kv_set(Path(key): Path, State(state): State, bytes: Bytes) { state.write().unwrap().db.insert(key, bytes); } diff --git a/examples/multipart-form/Cargo.toml b/examples/multipart-form/Cargo.toml index fc8f971a86..392c52c46f 100644 --- a/examples/multipart-form/Cargo.toml +++ b/examples/multipart-form/Cargo.toml @@ -7,6 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["multipart"] } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.3.0", features = ["trace"] } +tower-http = { version = "0.3.0", features = ["limit", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/multipart-form/src/main.rs b/examples/multipart-form/src/main.rs index 7c6c141088..bcddc60a6d 100644 --- a/examples/multipart-form/src/main.rs +++ b/examples/multipart-form/src/main.rs @@ -5,12 +5,13 @@ //! ``` use axum::{ - extract::{ContentLengthLimit, Multipart}, + extract::{DefaultBodyLimit, Multipart}, response::Html, routing::get, Router, }; use std::net::SocketAddr; +use tower_http::limit::RequestBodyLimitLayer; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; #[tokio::main] @@ -26,6 +27,10 @@ async fn main() { // build our application with some routes let app = Router::new() .route("/", get(show_form).post(accept_form)) + .layer(DefaultBodyLimit::disable()) + .layer(RequestBodyLimitLayer::new( + 250 * 1024 * 1024, /* 250mb */ + )) .layer(tower_http::trace::TraceLayer::new_for_http()); // run it with hyper @@ -58,14 +63,7 @@ async fn show_form() -> Html<&'static str> { ) } -async fn accept_form( - ContentLengthLimit(mut multipart): ContentLengthLimit< - Multipart, - { - 250 * 1024 * 1024 /* 250mb */ - }, - >, -) { +async fn accept_form(mut multipart: Multipart) { while let Some(field) = multipart.next_field().await.unwrap() { let name = field.name().unwrap().to_string(); let file_name = field.file_name().unwrap().to_string();