Skip to content

Commit

Permalink
Allow Routers to inherit state
Browse files Browse the repository at this point in the history
  • Loading branch information
jplatte committed Sep 17, 2022
1 parent f9b5e85 commit 3e68a16
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 71 deletions.
2 changes: 1 addition & 1 deletion axum-extra/src/routing/spa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl<B, T, F> SpaRouter<B, T, F> {

impl<B, F, T> From<SpaRouter<B, T, F>> for Router<(), B>
where
F: Clone + Send + 'static,
F: Clone + Send + Sync + 'static,
HandleError<Route<B, io::Error>, F, T>: Service<Request<B>, Error = Infallible>,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
<HandleError<Route<B, io::Error>, F, T> as Service<Request<B>>>::Future: Send,
Expand Down
122 changes: 122 additions & 0 deletions axum/src/handler/boxed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use std::{convert::Infallible, sync::Arc};

use super::Handler;
use crate::routing::Route;

pub(crate) struct BoxedHandler<S, B, E = Infallible>(Box<dyn ErasedHandler<S, B, E>>);

impl<S, B> BoxedHandler<S, B>
where
S: Send + Sync + 'static,
B: Send + 'static,
{
pub(crate) fn new<H, T>(handler: H) -> Self
where
H: Handler<T, S, B>,
T: 'static,
{
Self(Box::new(MakeErasedHandler {
handler,
into_route: |handler, state| Route::new(Handler::with_state_arc(handler, state)),
}))
}
}

impl<S, B, E> BoxedHandler<S, B, E> {
pub(crate) fn map<F, B2, E2>(self, f: F) -> BoxedHandler<S, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
B2: 'static,
E2: 'static,
{
BoxedHandler(Box::new(Map {
handler: self.0,
layer: Box::new(f),
}))
}

pub(crate) fn into_route(self, state: Arc<S>) -> Route<B, E> {
self.0.into_route(state)
}
}

impl<S, B, E> Clone for BoxedHandler<S, B, E> {
fn clone(&self) -> Self {
Self(self.0.clone_box())
}
}

trait ErasedHandler<S, B, E = Infallible>: Send {
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B, E>>;
fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B, E>;
}

struct MakeErasedHandler<H, S, B> {
handler: H,
into_route: fn(H, Arc<S>) -> Route<B>,
}

impl<H, S, B> ErasedHandler<S, B> for MakeErasedHandler<H, S, B>
where
H: Clone + Send + 'static,
S: 'static,
B: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B>> {
Box::new(self.clone())
}

fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B> {
(self.into_route)(self.handler, state)
}
}

impl<H: Clone, S, B> Clone for MakeErasedHandler<H, S, B> {
fn clone(&self) -> Self {
Self {
handler: self.handler.clone(),
into_route: self.into_route,
}
}
}

struct Map<S, B, E, B2, E2> {
handler: Box<dyn ErasedHandler<S, B, E>>,
layer: Box<dyn LayerFn<B, E, B2, E2>>,
}

impl<S, B, E, B2, E2> ErasedHandler<S, B2, E2> for Map<S, B, E, B2, E2>
where
S: 'static,
B: 'static,
E: 'static,
B2: 'static,
E2: 'static,
{
fn clone_box(&self) -> Box<dyn ErasedHandler<S, B2, E2>> {
Box::new(Self {
handler: self.handler.clone_box(),
layer: self.layer.clone_box(),
})
}

fn into_route(self: Box<Self>, state: Arc<S>) -> Route<B2, E2> {
(self.layer)(self.handler.into_route(state))
}
}

trait LayerFn<B, E, B2, E2>: FnOnce(Route<B, E>) -> Route<B2, E2> + Send {
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>>;
}

impl<F, B, E, B2, E2> LayerFn<B, E, B2, E2> for F
where
F: FnOnce(Route<B, E>) -> Route<B2, E2> + Clone + Send + 'static,
{
fn clone_box(&self) -> Box<dyn LayerFn<B, E, B2, E2>> {
Box::new(self.clone())
}
}
5 changes: 4 additions & 1 deletion axum/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,15 @@ use tower::ServiceExt;
use tower_layer::Layer;
use tower_service::Service;

mod boxed;
pub mod future;
mod into_service;
mod into_service_state_in_extension;
mod with_state;

pub(crate) use self::into_service_state_in_extension::IntoServiceStateInExtension;
pub(crate) use self::{
boxed::BoxedHandler, into_service_state_in_extension::IntoServiceStateInExtension,
};
pub use self::{into_service::IntoService, with_state::WithState};

/// Trait for async functions that can be used to handle requests.
Expand Down
88 changes: 57 additions & 31 deletions axum/src/routing/method_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@ use crate::{
body::{Body, Bytes, HttpBody},
error_handling::{HandleError, HandleErrorLayer},
extract::connect_info::IntoMakeServiceWithConnectInfo,
handler::{Handler, IntoServiceStateInExtension},
handler::{BoxedHandler, Handler, IntoServiceStateInExtension},
http::{Method, Request, StatusCode},
response::Response,
routing::{future::RouteFuture, Fallback, MethodFilter, Route},
util::try_downcast,
};
use axum_core::response::IntoResponse;
use bytes::BytesMut;
use std::{
convert::Infallible,
fmt,
marker::PhantomData,
sync::Arc,
task::{Context, Poll},
};
Expand Down Expand Up @@ -521,9 +521,8 @@ pub struct MethodRouter<S = (), B = Body, E = Infallible> {
post: Option<Route<B, E>>,
put: Option<Route<B, E>>,
trace: Option<Route<B, E>>,
fallback: Fallback<B, E>,
fallback: Fallback<S, B, E>,
allow_header: AllowHeader,
_marker: PhantomData<fn() -> S>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -720,7 +719,6 @@ where
trace: None,
allow_header: AllowHeader::None,
fallback: Fallback::Default(fallback),
_marker: PhantomData,
}
}

Expand All @@ -741,7 +739,15 @@ where
}
}

pub(crate) fn downcast_state<S2>(self) -> MethodRouter<S2, B, E> {
pub(crate) fn downcast_state<S2>(
self,
state_for_fallback: Option<Arc<S>>,
) -> MethodRouter<S2, B, E>
where
E: 'static,
S: 'static,
S2: 'static,
{
MethodRouter {
get: self.get,
head: self.head,
Expand All @@ -751,9 +757,20 @@ where
post: self.post,
put: self.put,
trace: self.trace,
fallback: self.fallback,
fallback: match self.fallback {
Fallback::Default(route) => Fallback::Default(route),
Fallback::Service(route) => Fallback::Service(route),
Fallback::BoxedHandler(handler) => match state_for_fallback {
Some(state) => Fallback::Service(handler.into_route(state)),
None => Fallback::BoxedHandler(
try_downcast::<BoxedHandler<S2, B, E>, BoxedHandler<S, B, E>>(handler)
.unwrap_or_else(|_| {
panic!("we should have panicked earlier if state types don't match")
}),
),
},
},
allow_header: self.allow_header,
_marker: PhantomData,
}
}

Expand Down Expand Up @@ -823,31 +840,35 @@ where
}

#[doc = include_str!("../docs/method_routing/layer.md")]
pub fn layer<L, NewReqBody, NewError>(self, layer: L) -> MethodRouter<S, NewReqBody, NewError>
pub fn layer<L, NewReqBody: 'static, NewError: 'static>(
self,
layer: L,
) -> MethodRouter<S, NewReqBody, NewError>
where
L: Layer<Route<B, E>>,
L: Layer<Route<B, E>> + Clone + Send + 'static,
L::Service: Service<Request<NewReqBody>, Error = NewError> + Clone + Send + 'static,
<L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
<L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
E: 'static,
S: 'static,
{
let layer_fn = |svc| {
let layer_fn = move |svc| {
let svc = layer.layer(svc);
let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
Route::new(svc)
};

MethodRouter {
get: self.get.map(layer_fn),
head: self.head.map(layer_fn),
delete: self.delete.map(layer_fn),
options: self.options.map(layer_fn),
patch: self.patch.map(layer_fn),
post: self.post.map(layer_fn),
put: self.put.map(layer_fn),
trace: self.trace.map(layer_fn),
get: self.get.map(layer_fn.clone()),
head: self.head.map(layer_fn.clone()),
delete: self.delete.map(layer_fn.clone()),
options: self.options.map(layer_fn.clone()),
patch: self.patch.map(layer_fn.clone()),
post: self.post.map(layer_fn.clone()),
put: self.put.map(layer_fn.clone()),
trace: self.trace.map(layer_fn.clone()),
fallback: self.fallback.map(layer_fn),
allow_header: self.allow_header,
_marker: self._marker,
}
}

Expand Down Expand Up @@ -924,15 +945,15 @@ where
}

#[track_caller]
fn merge_fallback<B, E>(
fallback: Fallback<B, E>,
fallback_other: Fallback<B, E>,
) -> Fallback<B, E> {
fn merge_fallback<S, B, E>(
fallback: Fallback<S, B, E>,
fallback_other: Fallback<S, B, E>,
) -> Fallback<S, B, E> {
match (fallback, fallback_other) {
(pick @ Fallback::Default(_), Fallback::Default(_)) => pick,
(Fallback::Default(_), pick @ Fallback::Service(_)) => pick,
(pick @ Fallback::Service(_), Fallback::Default(_)) => pick,
(Fallback::Service(_), Fallback::Service(_)) => {
(Fallback::Default(_), pick) => pick,
(pick, Fallback::Default(_)) => pick,
_ => {
panic!("Cannot merge two `MethodRouter`s that both have a fallback")
}
}
Expand Down Expand Up @@ -965,13 +986,14 @@ where
/// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, B, Infallible>
where
F: Clone + Send + 'static,
F: Clone + Send + Sync + 'static,
HandleError<Route<B, E>, F, T>: Service<Request<B>, Error = Infallible>,
<HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Future: Send,
<HandleError<Route<B, E>, F, T> as Service<Request<B>>>::Response: IntoResponse + Send,
T: 'static,
E: 'static,
B: 'static,
S: 'static,
{
self.layer(HandleErrorLayer::new(f))
}
Expand Down Expand Up @@ -1149,7 +1171,6 @@ impl<S, B, E> Clone for MethodRouter<S, B, E> {
trace: self.trace.clone(),
fallback: self.fallback.clone(),
allow_header: self.allow_header.clone(),
_marker: self._marker,
}
}
}
Expand Down Expand Up @@ -1224,7 +1245,7 @@ where

impl<S, B, E> Service<Request<B>> for WithState<S, B, E>
where
B: HttpBody,
B: HttpBody + Send,
S: Send + Sync + 'static,
{
type Response = Response;
Expand Down Expand Up @@ -1270,7 +1291,6 @@ where
trace,
fallback,
allow_header,
_marker: _,
},
} = self;

Expand All @@ -1291,6 +1311,12 @@ where
.strip_body(method == Method::HEAD),
Fallback::Service(fallback) => RouteFuture::from_future(fallback.oneshot_inner(req))
.strip_body(method == Method::HEAD),
Fallback::BoxedHandler(fallback) => RouteFuture::from_future(
fallback
.clone()
.into_route(Arc::clone(state))
.oneshot_inner(req),
),
};

match allow_header {
Expand Down

0 comments on commit 3e68a16

Please sign in to comment.