From b94248191e66418de7a4dab34a9e0636eb83845e Mon Sep 17 00:00:00 2001 From: Jonas Platte Date: Wed, 28 Sep 2022 22:20:47 +0200 Subject: [PATCH] Add RequestExt::{with_limited_body, into_limited_body} (#1420) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Move RequestExt and RequestPartsExt into axum-core * Add RequestExt::into_limited_body … and use it for Bytes extraction. * Add RequestExt::with_limited_body … and use it for Multipart extraction. Co-authored-by: David Pedersen --- {axum => axum-core}/src/ext_traits/mod.rs | 23 ++++++++++- {axum => axum-core}/src/ext_traits/request.rs | 39 +++++++++++++++++-- .../src/ext_traits/request_parts.rs | 13 +++++-- axum-core/src/extract/mod.rs | 1 + axum-core/src/extract/request_parts.rs | 30 ++++---------- axum-core/src/lib.rs | 3 ++ axum/CHANGELOG.md | 2 + axum/src/extract/multipart.rs | 12 +++--- axum/src/lib.rs | 8 ++-- .../{ext_traits/service.rs => service_ext.rs} | 0 10 files changed, 87 insertions(+), 44 deletions(-) rename {axum => axum-core}/src/ext_traits/mod.rs (52%) rename {axum => axum-core}/src/ext_traits/request.rs (79%) rename {axum => axum-core}/src/ext_traits/request_parts.rs (89%) rename axum/src/{ext_traits/service.rs => service_ext.rs} (100%) diff --git a/axum/src/ext_traits/mod.rs b/axum-core/src/ext_traits/mod.rs similarity index 52% rename from axum/src/ext_traits/mod.rs rename to axum-core/src/ext_traits/mod.rs index ca039f6ca9..02595fbeac 100644 --- a/axum/src/ext_traits/mod.rs +++ b/axum-core/src/ext_traits/mod.rs @@ -1,15 +1,34 @@ pub(crate) mod request; pub(crate) mod request_parts; -pub(crate) mod service; #[cfg(test)] mod tests { use std::convert::Infallible; + use crate::extract::{FromRef, FromRequestParts}; use async_trait::async_trait; - use axum_core::extract::{FromRef, FromRequestParts}; use http::request::Parts; + #[derive(Debug, Default, Clone, Copy)] + pub(crate) struct State(pub(crate) S); + + #[async_trait] + impl FromRequestParts for State + where + InnerState: FromRef, + OuterState: Send + Sync, + { + type Rejection = Infallible; + + async fn from_request_parts( + _parts: &mut Parts, + state: &OuterState, + ) -> Result { + let inner_state = InnerState::from_ref(state); + Ok(Self(inner_state)) + } + } + // some extractor that requires the state, such as `SignedCookieJar` pub(crate) struct RequiresState(pub(crate) String); diff --git a/axum/src/ext_traits/request.rs b/axum-core/src/ext_traits/request.rs similarity index 79% rename from axum/src/ext_traits/request.rs rename to axum-core/src/ext_traits/request.rs index f4fad5b8be..05ca77ea2a 100644 --- a/axum/src/ext_traits/request.rs +++ b/axum-core/src/ext_traits/request.rs @@ -1,6 +1,7 @@ -use axum_core::extract::{FromRequest, FromRequestParts}; +use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts}; use futures_util::future::BoxFuture; use http::Request; +use http_body::Limited; mod sealed { pub trait Sealed {} @@ -48,6 +49,16 @@ pub trait RequestExt: sealed::Sealed + Sized { where E: FromRequestParts + 'static, S: Send + Sync; + + /// Apply the [default body limit](crate::extract::DefaultBodyLimit). + /// + /// If it is disabled, return the request as-is in `Err`. + fn with_limited_body(self) -> Result>, Request>; + + /// Consumes the request, returning the body wrapped in [`Limited`] if a + /// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the + /// default limit is disabled. + fn into_limited_body(self) -> Result, B>; } impl RequestExt for Request @@ -105,14 +116,36 @@ where result }) } + + fn with_limited_body(self) -> Result>, Request> { + // update docs in `axum-core/src/extract/default_body_limit.rs` and + // `axum/src/docs/extract.md` if this changes + const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb + + match self.extensions().get::().copied() { + Some(DefaultBodyLimitKind::Disable) => Err(self), + Some(DefaultBodyLimitKind::Limit(limit)) => { + Ok(self.map(|b| http_body::Limited::new(b, limit))) + } + None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))), + } + } + + fn into_limited_body(self) -> Result, B> { + self.with_limited_body() + .map(Request::into_body) + .map_err(Request::into_body) + } } #[cfg(test)] mod tests { use super::*; - use crate::{ext_traits::tests::RequiresState, extract::State}; + use crate::{ + ext_traits::tests::{RequiresState, State}, + extract::FromRef, + }; use async_trait::async_trait; - use axum_core::extract::FromRef; use http::Method; use hyper::Body; diff --git a/axum/src/ext_traits/request_parts.rs b/axum-core/src/ext_traits/request_parts.rs similarity index 89% rename from axum/src/ext_traits/request_parts.rs rename to axum-core/src/ext_traits/request_parts.rs index c35f7c7445..7e136d58f1 100644 --- a/axum/src/ext_traits/request_parts.rs +++ b/axum-core/src/ext_traits/request_parts.rs @@ -1,4 +1,4 @@ -use axum_core::extract::FromRequestParts; +use crate::extract::FromRequestParts; use futures_util::future::BoxFuture; use http::request::Parts; @@ -53,9 +53,11 @@ mod tests { use std::convert::Infallible; use super::*; - use crate::{ext_traits::tests::RequiresState, extract::State}; + use crate::{ + ext_traits::tests::{RequiresState, State}, + extract::FromRef, + }; use async_trait::async_trait; - use axum_core::extract::FromRef; use http::{Method, Request}; #[tokio::test] @@ -73,7 +75,10 @@ mod tests { let state = "state".to_owned(); - let State(extracted_state): State = parts.extract_with_state(&state).await.unwrap(); + let State(extracted_state): State = parts + .extract_with_state::, String>(&state) + .await + .unwrap(); assert_eq!(extracted_state, state); } diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index 96b62d2849..83b7e6fee5 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -16,6 +16,7 @@ mod from_ref; mod request_parts; mod tuple; +pub(crate) use self::default_body_limit::DefaultBodyLimitKind; pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef}; mod private { diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index fc151a2da0..05d7d7277b 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,7 +1,5 @@ -use super::{ - default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, FromRequestParts, -}; -use crate::BoxError; +use super::{rejection::*, FromRequest, FromRequestParts}; +use crate::{BoxError, RequestExt}; use async_trait::async_trait; use bytes::Bytes; use http::{request::Parts, HeaderMap, Method, Request, Uri, Version}; @@ -84,27 +82,13 @@ where type Rejection = BytesRejection; async fn from_request(req: Request, _: &S) -> Result { - // update docs in `axum-core/src/extract/default_body_limit.rs` and - // `axum/src/docs/extract.md` if this changes - const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb - - let limit_kind = req.extensions().get::().copied(); - let bytes = match limit_kind { - Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(req.into_body()) + let bytes = match req.into_limited_body() { + Ok(limited_body) => crate::body::to_bytes(limited_body) + .await + .map_err(FailedToBufferBody::from_err)?, + Err(unlimited_body) => crate::body::to_bytes(unlimited_body) .await .map_err(FailedToBufferBody::from_err)?, - Some(DefaultBodyLimitKind::Limit(limit)) => { - let body = http_body::Limited::new(req.into_body(), limit); - crate::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)? - } - None => { - let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT); - crate::body::to_bytes(body) - .await - .map_err(FailedToBufferBody::from_err)? - } }; Ok(bytes) diff --git a/axum-core/src/lib.rs b/axum-core/src/lib.rs index 63e363e1ed..52a382dccc 100644 --- a/axum-core/src/lib.rs +++ b/axum-core/src/lib.rs @@ -52,6 +52,7 @@ pub(crate) mod macros; mod error; +mod ext_traits; pub use self::error::Error; pub mod body; @@ -60,3 +61,5 @@ pub mod response; /// Alias for a type-erased error type. pub type BoxError = Box; + +pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt}; diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 2589459bb8..7db363b81a 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 you likely need to re-enable the `tokio` feature ([#1382]) - **breaking:** `handler::{WithState, IntoService}` are merged into one type, named `HandlerService` ([#1418]) +- **changed:** The default body limit now applies to the `Multipart` extractor ([#1420]) - **added:** String and binary `From` impls have been added to `extract::ws::Message` to be more inline with `tungstenite` ([#1421]) @@ -54,6 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1408]: https://github.com/tokio-rs/axum/pull/1408 [#1414]: https://github.com/tokio-rs/axum/pull/1414 [#1418]: https://github.com/tokio-rs/axum/pull/1418 +[#1420]: https://github.com/tokio-rs/axum/pull/1420 [#1421]: https://github.com/tokio-rs/axum/pull/1421 # 0.6.0-rc.2 (10. September, 2022) diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 65813abdbd..76dd52dec9 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -6,6 +6,7 @@ use super::{BodyStream, FromRequest}; use crate::body::{Bytes, HttpBody}; use crate::BoxError; use async_trait::async_trait; +use axum_core::RequestExt; use futures_util::stream::Stream; use http::header::{HeaderMap, CONTENT_TYPE}; use http::Request; @@ -47,10 +48,6 @@ use std::{ /// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); /// # }; /// ``` -/// -/// For security reasons it's recommended to combine this with -/// [`RequestBodyLimitLayer`](tower_http::limit::RequestBodyLimitLayer) -/// to limit the size of the request payload. #[cfg_attr(docsrs, doc(cfg(feature = "multipart")))] #[derive(Debug)] pub struct Multipart { @@ -69,10 +66,11 @@ where async fn from_request(req: Request, state: &S) -> Result { let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?; - let stream = match BodyStream::from_request(req, state).await { - Ok(stream) => stream, - Err(err) => match err {}, + let stream_result = match req.with_limited_body() { + Ok(limited) => BodyStream::from_request(limited, state).await, + Err(unlimited) => BodyStream::from_request(unlimited, state).await, }; + let stream = stream_result.unwrap_or_else(|err| match err {}); let multipart = multer::Multipart::new(stream, boundary); Ok(Self { inner: multipart }) } diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 97aa56b9c2..837701e416 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -434,12 +434,12 @@ #[macro_use] pub(crate) mod macros; -mod ext_traits; mod extension; #[cfg(feature = "form")] mod form; #[cfg(feature = "json")] mod json; +mod service_ext; #[cfg(feature = "headers")] mod typed_header; mod util; @@ -483,11 +483,9 @@ pub use self::typed_header::TypedHeader; pub use self::form::Form; #[doc(inline)] -pub use axum_core::{BoxError, Error}; +pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt}; #[cfg(feature = "macros")] pub use axum_macros::debug_handler; -pub use self::ext_traits::{ - request::RequestExt, request_parts::RequestPartsExt, service::ServiceExt, -}; +pub use self::service_ext::ServiceExt; diff --git a/axum/src/ext_traits/service.rs b/axum/src/service_ext.rs similarity index 100% rename from axum/src/ext_traits/service.rs rename to axum/src/service_ext.rs