From 2077d500218ebfdc9395f12a71d0d97af4656c2e Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Sun, 25 Sep 2022 13:44:10 +0200 Subject: [PATCH] Add `map_request` and friends (#1408) * Add `map_request` and friends * finish it * changelog ref * Apply suggestions from code review Co-authored-by: Jonas Platte * address review feedback * Apply suggestions from code review Co-authored-by: Jonas Platte Co-authored-by: Jonas Platte --- axum/CHANGELOG.md | 4 + axum/src/middleware/from_fn.rs | 11 +- axum/src/middleware/map_request.rs | 464 +++++++++++++++++++++++++++++ axum/src/middleware/mod.rs | 6 + 4 files changed, 483 insertions(+), 2 deletions(-) create mode 100644 axum/src/middleware/map_request.rs diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index bc9ca0e0fb..51a6622e61 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **added:** Add `middleware::from_extractor_with_state` and `middleware::from_extractor_with_state_arc` ([#1396]) - **added:** Add `DefaultBodyLimit::max` for changing the default body limit ([#1397]) +- **added:** Add `map_request`, `map_request_with_state`, and + `map_request_with_state_arc` for transforming the request with an async + function ([#1408]) - **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400]) [#1371]: https://github.com/tokio-rs/axum/pull/1371 @@ -26,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#1396]: https://github.com/tokio-rs/axum/pull/1396 [#1397]: https://github.com/tokio-rs/axum/pull/1397 [#1400]: https://github.com/tokio-rs/axum/pull/1400 +[#1408]: https://github.com/tokio-rs/axum/pull/1408 # 0.6.0-rc.2 (10. September, 2022) diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index 23d25acf62..184941d94a 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -77,6 +77,9 @@ use tower_service::Service; /// async fn my_middleware( /// TypedHeader(auth): TypedHeader>, /// Query(query_params): Query>, +/// // you can add more extractors here but the last +/// // extractor must implement `FromRequest` which +/// // `Request` does /// req: Request, /// next: Next, /// ) -> Response { @@ -117,7 +120,9 @@ pub fn from_fn(f: F) -> FromFnLayer { /// /// async fn my_middleware( /// State(state): State, -/// // you can add more extractors here... +/// // you can add more extractors here but the last +/// // extractor must implement `FromRequest` which +/// // `Request` does /// req: Request, /// next: Next, /// ) -> Response { @@ -140,6 +145,8 @@ pub fn from_fn_with_state(state: S, f: F) -> FromFnLayer { /// Create a middleware from an async function with the given [`Arc`]'ed state. /// +/// See [`from_fn_with_state`] for an example. +/// /// See [`State`](crate::extract::State) for more details about accessing state. pub fn from_fn_with_state_arc(state: Arc, f: F) -> FromFnLayer { FromFnLayer { @@ -396,7 +403,7 @@ mod tests { } async fn handle(headers: HeaderMap) -> String { - (&headers["x-axum-test"]).to_str().unwrap().to_owned() + headers["x-axum-test"].to_str().unwrap().to_owned() } let app = Router::new() diff --git a/axum/src/middleware/map_request.rs b/axum/src/middleware/map_request.rs new file mode 100644 index 0000000000..54d31913e6 --- /dev/null +++ b/axum/src/middleware/map_request.rs @@ -0,0 +1,464 @@ +use crate::response::{IntoResponse, Response}; +use axum_core::extract::{FromRequest, FromRequestParts}; +use futures_util::future::BoxFuture; +use http::Request; +use std::{ + any::type_name, + convert::Infallible, + fmt, + future::Future, + marker::PhantomData, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Create a middleware from an async function that transforms a request. +/// +/// This differs from [`tower::util::MapRequest`] in that it allows you to easily run axum-specific +/// extractors. +/// +/// # Example +/// +/// ``` +/// use axum::{ +/// Router, +/// routing::get, +/// middleware::map_request, +/// http::Request, +/// }; +/// +/// async fn set_header(mut request: Request) -> Request { +/// request.headers_mut().insert("x-foo", "foo".parse().unwrap()); +/// request +/// } +/// +/// async fn handler(request: Request) { +/// // `request` will have an `x-foo` header +/// } +/// +/// let app = Router::new() +/// .route("/", get(handler)) +/// .layer(map_request(set_header)); +/// # let _: Router = app; +/// ``` +/// +/// # Rejection the request +/// +/// The function given to `map_request` is allowed to also return a `Result` which can be used to +/// reject the request and return a response immediately, without calling the remaining +/// middleware. +/// +/// Specifically the valid return types are: +/// +/// - `Request` +/// - `Request, E> where E: IntoResponse` +/// +/// ``` +/// use axum::{ +/// Router, +/// http::{Request, StatusCode}, +/// routing::get, +/// middleware::map_request, +/// }; +/// +/// async fn auth(request: Request) -> Result, StatusCode> { +/// let auth_header = request.headers() +/// .get(http::header::AUTHORIZATION) +/// .and_then(|header| header.to_str().ok()); +/// +/// match auth_header { +/// Some(auth_header) if token_is_valid(auth_header) => Ok(request), +/// _ => Err(StatusCode::UNAUTHORIZED), +/// } +/// } +/// +/// fn token_is_valid(token: &str) -> bool { +/// // ... +/// # false +/// } +/// +/// let app = Router::new() +/// .route("/", get(|| async { /* ... */ })) +/// .route_layer(map_request(auth)); +/// # let app: Router = app; +/// ``` +/// +/// # Running extractors +/// +/// ``` +/// use axum::{ +/// Router, +/// routing::get, +/// middleware::map_request, +/// extract::Path, +/// http::Request, +/// }; +/// use std::collections::HashMap; +/// +/// async fn log_path_params( +/// Path(path_params): Path>, +/// request: Request, +/// ) -> Request { +/// tracing::debug!(?path_params); +/// request +/// } +/// +/// let app = Router::new() +/// .route("/", get(|| async { /* ... */ })) +/// .layer(map_request(log_path_params)); +/// # let _: Router = app; +/// ``` +/// +/// Note that to access state you must use either [`map_request_with_state`] or [`map_request_with_state_arc`]. +pub fn map_request(f: F) -> MapRequestLayer { + map_request_with_state((), f) +} + +/// Create a middleware from an async function that transforms a request, with the given state. +/// +/// See [`State`](crate::extract::State) for more details about accessing state. +/// +/// # Example +/// +/// ```rust +/// use axum::{ +/// Router, +/// http::{Request, StatusCode}, +/// routing::get, +/// response::IntoResponse, +/// middleware::map_request_with_state, +/// extract::State, +/// }; +/// +/// #[derive(Clone)] +/// struct AppState { /* ... */ } +/// +/// async fn my_middleware( +/// State(state): State, +/// // you can add more extractors here but the last +/// // extractor must implement `FromRequest` which +/// // `Request` does +/// request: Request, +/// ) -> Request { +/// // do something with `state` and `request`... +/// request +/// } +/// +/// let state = AppState { /* ... */ }; +/// +/// let app = Router::with_state(state.clone()) +/// .route("/", get(|| async { /* ... */ })) +/// .route_layer(map_request_with_state(state, my_middleware)); +/// # let app: Router<_> = app; +/// ``` +pub fn map_request_with_state(state: S, f: F) -> MapRequestLayer { + map_request_with_state_arc(Arc::new(state), f) +} + +/// Create a middleware from an async function that transforms a request, with the given [`Arc`]'ed +/// state. +/// +/// See [`map_request_with_state`] for an example. +/// +/// See [`State`](crate::extract::State) for more details about accessing state. +pub fn map_request_with_state_arc(state: Arc, f: F) -> MapRequestLayer { + MapRequestLayer { + f, + state, + _extractor: PhantomData, + } +} + +/// A [`tower::Layer`] from an async function that transforms a request. +/// +/// Created with [`map_request`]. See that function for more details. +pub struct MapRequestLayer { + f: F, + state: Arc, + _extractor: PhantomData T>, +} + +impl Clone for MapRequestLayer +where + F: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + state: Arc::clone(&self.state), + _extractor: self._extractor, + } + } +} + +impl Layer for MapRequestLayer +where + F: Clone, +{ + type Service = MapRequest; + + fn layer(&self, inner: I) -> Self::Service { + MapRequest { + f: self.f.clone(), + state: Arc::clone(&self.state), + inner, + _extractor: PhantomData, + } + } +} + +impl fmt::Debug for MapRequestLayer +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapRequestLayer") + // Write out the type name, without quoting it as `&type_name::()` would + .field("f", &format_args!("{}", type_name::())) + .field("state", &self.state) + .finish() + } +} + +/// A middleware created from an async function that transforms a request. +/// +/// Created with [`map_request`]. See that function for more details. +pub struct MapRequest { + f: F, + inner: I, + state: Arc, + _extractor: PhantomData T>, +} + +impl Clone for MapRequest +where + F: Clone, + I: Clone, +{ + fn clone(&self) -> Self { + Self { + f: self.f.clone(), + inner: self.inner.clone(), + state: Arc::clone(&self.state), + _extractor: self._extractor, + } + } +} + +macro_rules! impl_service { + ( + [$($ty:ident),*], $last:ident + ) => { + #[allow(non_snake_case, unused_mut)] + impl Service> for MapRequest + where + F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static, + $( $ty: FromRequestParts + Send, )* + $last: FromRequest + Send, + Fut: Future + Send + 'static, + Fut::Output: IntoMapRequestResult + Send + 'static, + I: Service, Error = Infallible> + + Clone + + Send + + 'static, + I::Response: IntoResponse, + I::Future: Send + 'static, + B: Send + 'static, + S: Send + Sync + 'static, + { + type Response = Response; + type Error = Infallible; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let not_ready_inner = self.inner.clone(); + let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner); + + let mut f = self.f.clone(); + let state = Arc::clone(&self.state); + + let future = Box::pin(async move { + let (mut parts, body) = req.into_parts(); + + $( + let $ty = match $ty::from_request_parts(&mut parts, &state).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + )* + + let req = Request::from_parts(parts, body); + + let $last = match $last::from_request(req, &state).await { + Ok(value) => value, + Err(rejection) => return rejection.into_response(), + }; + + match f($($ty,)* $last).await.into_map_request_result() { + Ok(req) => { + ready_inner.call(req).await.into_response() + } + Err(res) => { + res + } + } + }); + + ResponseFuture { + inner: future + } + } + } + }; +} + +impl_service!([], T1); +impl_service!([T1], T2); +impl_service!([T1, T2], T3); +impl_service!([T1, T2, T3], T4); +impl_service!([T1, T2, T3, T4], T5); +impl_service!([T1, T2, T3, T4, T5], T6); +impl_service!([T1, T2, T3, T4, T5, T6], T7); +impl_service!([T1, T2, T3, T4, T5, T6, T7], T8); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8], T9); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9], T10); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10], T11); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11], T12); +impl_service!([T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12], T13); +impl_service!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13], + T14 +); +impl_service!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14], + T15 +); +impl_service!( + [T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15], + T16 +); + +impl fmt::Debug for MapRequest +where + S: fmt::Debug, + I: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MapRequest") + .field("f", &format_args!("{}", type_name::())) + .field("inner", &self.inner) + .field("state", &self.state) + .finish() + } +} + +/// Response future for [`MapRequest`]. +pub struct ResponseFuture { + inner: BoxFuture<'static, Response>, +} + +impl Future for ResponseFuture { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.inner.as_mut().poll(cx).map(Ok) + } +} + +impl fmt::Debug for ResponseFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ResponseFuture").finish() + } +} + +mod private { + use crate::{http::Request, response::IntoResponse}; + + pub trait Sealed {} + impl Sealed for Result, E> where E: IntoResponse {} + impl Sealed for Request {} +} + +/// Trait implemented by types that can be returned from [`map_request`], +/// [`map_request_with_state`], and [`map_request_with_state_arc`]. +/// +/// This trait is sealed such that it cannot be implemented outside this crate. +pub trait IntoMapRequestResult: private::Sealed { + /// Perform the conversion. + fn into_map_request_result(self) -> Result, Response>; +} + +impl IntoMapRequestResult for Result, E> +where + E: IntoResponse, +{ + fn into_map_request_result(self) -> Result, Response> { + self.map_err(IntoResponse::into_response) + } +} + +impl IntoMapRequestResult for Request { + fn into_map_request_result(self) -> Result, Response> { + Ok(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{routing::get, test_helpers::TestClient, Router}; + use http::{HeaderMap, StatusCode}; + + #[tokio::test] + async fn works() { + async fn add_header(mut req: Request) -> Request { + req.headers_mut().insert("x-foo", "foo".parse().unwrap()); + req + } + + async fn handler(headers: HeaderMap) -> Response { + headers["x-foo"] + .to_str() + .unwrap() + .to_owned() + .into_response() + } + + let app = Router::new() + .route("/", get(handler)) + .layer(map_request(add_header)); + let client = TestClient::new(app); + + let res = client.get("/").send().await; + + assert_eq!(res.text().await, "foo"); + } + + #[tokio::test] + async fn works_for_short_circutting() { + async fn add_header(_req: Request) -> Result, (StatusCode, &'static str)> { + Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong")) + } + + async fn handler(_headers: HeaderMap) -> Response { + unreachable!() + } + + let app = Router::new() + .route("/", get(handler)) + .layer(map_request(add_header)); + let client = TestClient::new(app); + + let res = client.get("/").send().await; + + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!(res.text().await, "something went wrong"); + } +} diff --git a/axum/src/middleware/mod.rs b/axum/src/middleware/mod.rs index 6dde14894e..9e14825c1e 100644 --- a/axum/src/middleware/mod.rs +++ b/axum/src/middleware/mod.rs @@ -4,6 +4,7 @@ mod from_extractor; mod from_fn; +mod map_request; pub use self::from_extractor::{ from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor, @@ -12,6 +13,10 @@ pub use self::from_extractor::{ pub use self::from_fn::{ from_fn, from_fn_with_state, from_fn_with_state_arc, FromFn, FromFnLayer, Next, }; +pub use self::map_request::{ + map_request, map_request_with_state, map_request_with_state_arc, IntoMapRequestResult, + MapRequest, MapRequestLayer, +}; pub use crate::extension::AddExtension; pub mod future { @@ -19,4 +24,5 @@ pub mod future { pub use super::from_extractor::ResponseFuture as FromExtractorResponseFuture; pub use super::from_fn::ResponseFuture as FromFnResponseFuture; + pub use super::map_request::ResponseFuture as MapRequestResponseFuture; }