Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add map_response and friends #1414

Merged
merged 2 commits into from Sep 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Expand Up @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **added:** Add `map_request`, `map_request_with_state`, and
`map_request_with_state_arc` for transforming the request with an async
function ([#1408])
- **added:** Add `map_response`, `map_response_with_state`, and
`map_response_with_state_arc` for transforming the response with an async
function
- **breaking:** `ContentLengthLimit` has been removed. `Use DefaultBodyLimit` instead ([#1400])
- **changed:** `Router` no longer implements `Service`, call `.into_service()`
on it to obtain a `RouterService` that does ([#1368])
Expand Down
341 changes: 341 additions & 0 deletions axum/src/middleware/map_response.rs
@@ -0,0 +1,341 @@
use crate::response::{IntoResponse, Response};
use axum_core::extract::FromRequestParts;
use futures_util::future::BoxFuture;
use http::Request;
use std::{
any::type_name,
convert::Infallible,
fmt,
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower_layer::Layer;
use tower_service::Service;

/// Create a middleware from an async function that transforms a response.
///
/// This differs from [`tower::util::MapResponse`] in that it allows you to easily run axum-specific
/// extractors.
///
/// # Example
///
/// ```
/// use axum::{
/// Router,
/// routing::get,
/// middleware::map_response,
/// response::Response,
/// };
///
/// async fn set_header<B>(mut response: Response<B>) -> Response<B> {
/// response.headers_mut().insert("x-foo", "foo".parse().unwrap());
/// response
/// }
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .layer(map_response(set_header));
/// # let _: Router = app;
/// ```
///
/// # Running extractors
///
/// ```
/// use axum::{
/// Router,
/// routing::get,
/// middleware::map_response,
/// extract::Path,
/// response::Response,
/// };
/// use std::collections::HashMap;
///
/// async fn log_path_params<B>(
/// Path(path_params): Path<HashMap<String, String>>,
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
/// response: Response<B>,
/// ) -> Response<B> {
davidpdrsn marked this conversation as resolved.
Show resolved Hide resolved
/// tracing::debug!(?path_params);
/// response
/// }
///
/// let app = Router::new()
/// .route("/", get(|| async { /* ... */ }))
/// .layer(map_response(log_path_params));
/// # let _: Router = app;
/// ```
///
/// Note that to access state you must use either [`map_response_with_state`] or [`map_response_with_state_arc`].
pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
map_response_with_state((), f)
}

/// Create a middleware from an async function that transforms a response, with the given state.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
///
/// # Example
///
/// ```rust
/// use axum::{
/// Router,
/// http::StatusCode,
/// routing::get,
/// response::Response,
/// middleware::map_response_with_state,
/// extract::State,
/// };
///
/// #[derive(Clone)]
/// struct AppState { /* ... */ }
///
/// async fn my_middleware<B>(
/// State(state): State<AppState>,
/// // you can add more extractors here but they must
/// // all implement `FromRequestParts`
/// // `FromRequest` is not allowed
/// response: Response<B>,
/// ) -> Response<B> {
/// // do something with `state` and `response`...
/// response
/// }
///
/// let state = AppState { /* ... */ };
///
/// let app = Router::with_state(state.clone())
/// .route("/", get(|| async { /* ... */ }))
/// .route_layer(map_response_with_state(state, my_middleware));
/// # let app: Router<_> = app;
/// ```
pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
map_response_with_state_arc(Arc::new(state), f)
}

/// Create a middleware from an async function that transforms a response, with the given [`Arc`]'ed
/// state.
///
/// See [`map_response_with_state`] for an example.
///
/// See [`State`](crate::extract::State) for more details about accessing state.
pub fn map_response_with_state_arc<F, S, T>(state: Arc<S>, f: F) -> MapResponseLayer<F, S, T> {
MapResponseLayer {
f,
state,
_extractor: PhantomData,
}
}

/// A [`tower::Layer`] from an async function that transforms a response.
///
/// Created with [`map_response`]. See that function for more details.
pub struct MapResponseLayer<F, S, T> {
f: F,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>,
}

impl<F, S, T> Clone for MapResponseLayer<F, S, T>
where
F: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor,
}
}
}

impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
where
F: Clone,
{
type Service = MapResponse<F, S, I, T>;

fn layer(&self, inner: I) -> Self::Service {
MapResponse {
f: self.f.clone(),
state: Arc::clone(&self.state),
inner,
_extractor: PhantomData,
}
}
}

impl<F, S, T> fmt::Debug for MapResponseLayer<F, S, T>
where
S: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MapResponseLayer")
// Write out the type name, without quoting it as `&type_name::<F>()` would
.field("f", &format_args!("{}", type_name::<F>()))
.field("state", &self.state)
.finish()
}
}

/// A middleware created from an async function that transforms a response.
///
/// Created with [`map_response`]. See that function for more details.
pub struct MapResponse<F, S, I, T> {
f: F,
inner: I,
state: Arc<S>,
_extractor: PhantomData<fn() -> T>,
}

impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
where
F: Clone,
I: Clone,
{
fn clone(&self) -> Self {
Self {
f: self.f.clone(),
inner: self.inner.clone(),
state: Arc::clone(&self.state),
_extractor: self._extractor,
}
}
}

macro_rules! impl_service {
(
$($ty:ident),*
) => {
#[allow(non_snake_case, unused_mut)]
impl<F, Fut, S, I, B, ResBody, $($ty,)*> Service<Request<B>> for MapResponse<F, S, I, ($($ty,)*)>
where
F: FnMut($($ty,)* Response<ResBody>) -> Fut + Clone + Send + 'static,
$( $ty: FromRequestParts<S> + Send, )*
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we require FromRequestParts and not FromRequest. Not sure there would be a way to recover the request afterwards so we can pass it to the inner service 🤔

If users need the request body in a map_response I think a full from_fn might be more appropriate.

Copy link
Member

@jplatte jplatte Sep 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is exactly right IMO.

Fut: Future + Send + 'static,
Fut::Output: IntoResponse + Send + 'static,
I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
+ Clone
+ Send
+ 'static,
I::Future: Send + 'static,
B: Send + 'static,
ResBody: Send + 'static,
S: Send + Sync + 'static,
{
type Response = Response;
type Error = Infallible;
type Future = ResponseFuture;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}


fn call(&mut self, req: Request<B>) -> Self::Future {
let not_ready_inner = self.inner.clone();
let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

let mut f = self.f.clone();
let _state = Arc::clone(&self.state);

let future = Box::pin(async move {
let (mut parts, body) = req.into_parts();

$(
let $ty = match $ty::from_request_parts(&mut parts, &_state).await {
Ok(value) => value,
Err(rejection) => return rejection.into_response(),
};
)*

let req = Request::from_parts(parts, body);

match ready_inner.call(req).await {
Ok(res) => {
f($($ty,)* res).await.into_response()
}
Err(err) => match err {}
}
});

ResponseFuture {
inner: future
}
}
}
};
}

impl_service!();
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);

impl<F, S, I, T> fmt::Debug for MapResponse<F, S, I, T>
where
S: fmt::Debug,
I: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MapResponse")
.field("f", &format_args!("{}", type_name::<F>()))
.field("inner", &self.inner)
.field("state", &self.state)
.finish()
}
}

/// Response future for [`MapResponse`].
pub struct ResponseFuture {
inner: BoxFuture<'static, Response>,
}

impl Future for ResponseFuture {
type Output = Result<Response, Infallible>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.inner.as_mut().poll(cx).map(Ok)
}
}

impl fmt::Debug for ResponseFuture {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseFuture").finish()
}
}

#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::*;
use crate::{test_helpers::TestClient, Router};

#[tokio::test]
async fn works() {
async fn add_header<B>(mut res: Response<B>) -> Response<B> {
res.headers_mut().insert("x-foo", "foo".parse().unwrap());
res
}

let app = Router::new().layer(map_response(add_header));
let client = TestClient::new(app);

let res = client.get("/").send().await;

assert_eq!(res.headers()["x-foo"], "foo");
}
}
6 changes: 6 additions & 0 deletions axum/src/middleware/mod.rs
Expand Up @@ -5,6 +5,7 @@
mod from_extractor;
mod from_fn;
mod map_request;
mod map_response;

pub use self::from_extractor::{
from_extractor, from_extractor_with_state, from_extractor_with_state_arc, FromExtractor,
Expand All @@ -17,6 +18,10 @@ pub use self::map_request::{
map_request, map_request_with_state, map_request_with_state_arc, IntoMapRequestResult,
MapRequest, MapRequestLayer,
};
pub use self::map_response::{
map_response, map_response_with_state, map_response_with_state_arc, MapResponse,
MapResponseLayer,
};
pub use crate::extension::AddExtension;

pub mod future {
Expand All @@ -25,4 +30,5 @@ pub mod future {
pub use super::from_extractor::ResponseFuture as FromExtractorResponseFuture;
pub use super::from_fn::ResponseFuture as FromFnResponseFuture;
pub use super::map_request::ResponseFuture as MapRequestResponseFuture;
pub use super::map_response::ResponseFuture as MapResponseResponseFuture;
}