Skip to content

Commit

Permalink
Add RequestExt::{with_limited_body, into_limited_body} (#1420)
Browse files Browse the repository at this point in the history
* Move RequestExt and RequestPartsExt into axum-core

* Add RequestExt::into_limited_body

… and use it for Bytes extraction.

* Add RequestExt::with_limited_body

… and use it for Multipart extraction.

Co-authored-by: David Pedersen <david.pdrsn@gmail.com>
  • Loading branch information
jplatte and davidpdrsn committed Sep 28, 2022
1 parent be54583 commit b942481
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 44 deletions.
23 changes: 21 additions & 2 deletions axum/src/ext_traits/mod.rs → axum-core/src/ext_traits/mod.rs
@@ -1,15 +1,34 @@
pub(crate) mod request;
pub(crate) mod request_parts;
pub(crate) mod service;

#[cfg(test)]
mod tests {
use std::convert::Infallible;

use crate::extract::{FromRef, FromRequestParts};
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;

#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct State<S>(pub(crate) S);

#[async_trait]
impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where
InnerState: FromRef<OuterState>,
OuterState: Send + Sync,
{
type Rejection = Infallible;

async fn from_request_parts(
_parts: &mut Parts,
state: &OuterState,
) -> Result<Self, Self::Rejection> {
let inner_state = InnerState::from_ref(state);
Ok(Self(inner_state))
}
}

// some extractor that requires the state, such as `SignedCookieJar`
pub(crate) struct RequiresState(pub(crate) String);

Expand Down
@@ -1,6 +1,7 @@
use axum_core::extract::{FromRequest, FromRequestParts};
use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts};
use futures_util::future::BoxFuture;
use http::Request;
use http_body::Limited;

mod sealed {
pub trait Sealed<B> {}
Expand Down Expand Up @@ -48,6 +49,16 @@ pub trait RequestExt<B>: sealed::Sealed<B> + Sized {
where
E: FromRequestParts<S> + 'static,
S: Send + Sync;

/// Apply the [default body limit](crate::extract::DefaultBodyLimit).
///
/// If it is disabled, return the request as-is in `Err`.
fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>>;

/// Consumes the request, returning the body wrapped in [`Limited`] if a
/// [default limit](crate::extract::DefaultBodyLimit) is in place, or not wrapped if the
/// default limit is disabled.
fn into_limited_body(self) -> Result<Limited<B>, B>;
}

impl<B> RequestExt<B> for Request<B>
Expand Down Expand Up @@ -105,14 +116,36 @@ where
result
})
}

fn with_limited_body(self) -> Result<Request<Limited<B>>, Request<B>> {
// update docs in `axum-core/src/extract/default_body_limit.rs` and
// `axum/src/docs/extract.md` if this changes
const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb

match self.extensions().get::<DefaultBodyLimitKind>().copied() {
Some(DefaultBodyLimitKind::Disable) => Err(self),
Some(DefaultBodyLimitKind::Limit(limit)) => {
Ok(self.map(|b| http_body::Limited::new(b, limit)))
}
None => Ok(self.map(|b| http_body::Limited::new(b, DEFAULT_LIMIT))),
}
}

fn into_limited_body(self) -> Result<Limited<B>, B> {
self.with_limited_body()
.map(Request::into_body)
.map_err(Request::into_body)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use crate::{
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::Method;
use hyper::Body;

Expand Down
@@ -1,4 +1,4 @@
use axum_core::extract::FromRequestParts;
use crate::extract::FromRequestParts;
use futures_util::future::BoxFuture;
use http::request::Parts;

Expand Down Expand Up @@ -53,9 +53,11 @@ mod tests {
use std::convert::Infallible;

use super::*;
use crate::{ext_traits::tests::RequiresState, extract::State};
use crate::{
ext_traits::tests::{RequiresState, State},
extract::FromRef,
};
use async_trait::async_trait;
use axum_core::extract::FromRef;
use http::{Method, Request};

#[tokio::test]
Expand All @@ -73,7 +75,10 @@ mod tests {

let state = "state".to_owned();

let State(extracted_state): State<String> = parts.extract_with_state(&state).await.unwrap();
let State(extracted_state): State<String> = parts
.extract_with_state::<State<String>, String>(&state)
.await
.unwrap();

assert_eq!(extracted_state, state);
}
Expand Down
1 change: 1 addition & 0 deletions axum-core/src/extract/mod.rs
Expand Up @@ -16,6 +16,7 @@ mod from_ref;
mod request_parts;
mod tuple;

pub(crate) use self::default_body_limit::DefaultBodyLimitKind;
pub use self::{default_body_limit::DefaultBodyLimit, from_ref::FromRef};

mod private {
Expand Down
30 changes: 7 additions & 23 deletions axum-core/src/extract/request_parts.rs
@@ -1,7 +1,5 @@
use super::{
default_body_limit::DefaultBodyLimitKind, rejection::*, FromRequest, FromRequestParts,
};
use crate::BoxError;
use super::{rejection::*, FromRequest, FromRequestParts};
use crate::{BoxError, RequestExt};
use async_trait::async_trait;
use bytes::Bytes;
use http::{request::Parts, HeaderMap, Method, Request, Uri, Version};
Expand Down Expand Up @@ -84,27 +82,13 @@ where
type Rejection = BytesRejection;

async fn from_request(req: Request<B>, _: &S) -> Result<Self, Self::Rejection> {
// update docs in `axum-core/src/extract/default_body_limit.rs` and
// `axum/src/docs/extract.md` if this changes
const DEFAULT_LIMIT: usize = 2_097_152; // 2 mb

let limit_kind = req.extensions().get::<DefaultBodyLimitKind>().copied();
let bytes = match limit_kind {
Some(DefaultBodyLimitKind::Disable) => crate::body::to_bytes(req.into_body())
let bytes = match req.into_limited_body() {
Ok(limited_body) => crate::body::to_bytes(limited_body)
.await
.map_err(FailedToBufferBody::from_err)?,
Err(unlimited_body) => crate::body::to_bytes(unlimited_body)
.await
.map_err(FailedToBufferBody::from_err)?,
Some(DefaultBodyLimitKind::Limit(limit)) => {
let body = http_body::Limited::new(req.into_body(), limit);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
None => {
let body = http_body::Limited::new(req.into_body(), DEFAULT_LIMIT);
crate::body::to_bytes(body)
.await
.map_err(FailedToBufferBody::from_err)?
}
};

Ok(bytes)
Expand Down
3 changes: 3 additions & 0 deletions axum-core/src/lib.rs
Expand Up @@ -52,6 +52,7 @@
pub(crate) mod macros;

mod error;
mod ext_traits;
pub use self::error::Error;

pub mod body;
Expand All @@ -60,3 +61,5 @@ pub mod response;

/// Alias for a type-erased error type.
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;

pub use self::ext_traits::{request::RequestExt, request_parts::RequestPartsExt};
2 changes: 2 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -40,6 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
you likely need to re-enable the `tokio` feature ([#1382])
- **breaking:** `handler::{WithState, IntoService}` are merged into one type,
named `HandlerService` ([#1418])
- **changed:** The default body limit now applies to the `Multipart` extractor ([#1420])
- **added:** String and binary `From` impls have been added to `extract::ws::Message`
to be more inline with `tungstenite` ([#1421])

Expand All @@ -54,6 +55,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1408]: https://github.com/tokio-rs/axum/pull/1408
[#1414]: https://github.com/tokio-rs/axum/pull/1414
[#1418]: https://github.com/tokio-rs/axum/pull/1418
[#1420]: https://github.com/tokio-rs/axum/pull/1420
[#1421]: https://github.com/tokio-rs/axum/pull/1421

# 0.6.0-rc.2 (10. September, 2022)
Expand Down
12 changes: 5 additions & 7 deletions axum/src/extract/multipart.rs
Expand Up @@ -6,6 +6,7 @@ use super::{BodyStream, FromRequest};
use crate::body::{Bytes, HttpBody};
use crate::BoxError;
use async_trait::async_trait;
use axum_core::RequestExt;
use futures_util::stream::Stream;
use http::header::{HeaderMap, CONTENT_TYPE};
use http::Request;
Expand Down Expand Up @@ -47,10 +48,6 @@ use std::{
/// # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// For security reasons it's recommended to combine this with
/// [`RequestBodyLimitLayer`](tower_http::limit::RequestBodyLimitLayer)
/// to limit the size of the request payload.
#[cfg_attr(docsrs, doc(cfg(feature = "multipart")))]
#[derive(Debug)]
pub struct Multipart {
Expand All @@ -69,10 +66,11 @@ where

async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
let boundary = parse_boundary(req.headers()).ok_or(InvalidBoundary)?;
let stream = match BodyStream::from_request(req, state).await {
Ok(stream) => stream,
Err(err) => match err {},
let stream_result = match req.with_limited_body() {
Ok(limited) => BodyStream::from_request(limited, state).await,
Err(unlimited) => BodyStream::from_request(unlimited, state).await,
};
let stream = stream_result.unwrap_or_else(|err| match err {});
let multipart = multer::Multipart::new(stream, boundary);
Ok(Self { inner: multipart })
}
Expand Down
8 changes: 3 additions & 5 deletions axum/src/lib.rs
Expand Up @@ -434,12 +434,12 @@
#[macro_use]
pub(crate) mod macros;

mod ext_traits;
mod extension;
#[cfg(feature = "form")]
mod form;
#[cfg(feature = "json")]
mod json;
mod service_ext;
#[cfg(feature = "headers")]
mod typed_header;
mod util;
Expand Down Expand Up @@ -483,11 +483,9 @@ pub use self::typed_header::TypedHeader;
pub use self::form::Form;

#[doc(inline)]
pub use axum_core::{BoxError, Error};
pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt};

#[cfg(feature = "macros")]
pub use axum_macros::debug_handler;

pub use self::ext_traits::{
request::RequestExt, request_parts::RequestPartsExt, service::ServiceExt,
};
pub use self::service_ext::ServiceExt;
File renamed without changes.

0 comments on commit b942481

Please sign in to comment.