diff --git a/axum-core/src/extract/from_ref.rs b/axum-core/src/extract/from_ref.rs index c0124140e5..cee9d3377a 100644 --- a/axum-core/src/extract/from_ref.rs +++ b/axum-core/src/extract/from_ref.rs @@ -5,7 +5,10 @@ /// /// See [`State`] for more details on how library authors should use this trait. /// +/// This trait can be derived using `#[derive(axum_macros::FromRef)]`. +/// /// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html +/// [`#[derive(axum_macros::FromRef)]`]: https://docs.rs/axum-macros/latest/axum_macros/derive.FromRef.html // NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is // defined in axum. That allows crate authors to use it when implementing extractors. pub trait FromRef { diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 47fa8b1e0d..5a5789c7cd 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None +- **added:** Add `#[derive(FromRef)]` ([#1430]) + +[#1430]: https://github.com/tokio-rs/axum/pull/1430 # 0.3.0-rc.1 (23. August, 2022) diff --git a/axum-macros/src/from_ref.rs b/axum-macros/src/from_ref.rs new file mode 100644 index 0000000000..3ebb418015 --- /dev/null +++ b/axum-macros/src/from_ref.rs @@ -0,0 +1,39 @@ +use proc_macro2::{Ident, TokenStream}; +use quote::quote_spanned; +use syn::{spanned::Spanned, Field, ItemStruct}; + +pub(crate) fn expand(item: ItemStruct) -> TokenStream { + item.fields + .iter() + .enumerate() + .map(|(idx, field)| expand_field(&item.ident, idx, field)) + .collect() +} + +fn expand_field(state: &Ident, idx: usize, field: &Field) -> TokenStream { + let field_ty = &field.ty; + let span = field.ty.span(); + + let body = if let Some(field_ident) = &field.ident { + quote_spanned! {span=> state.#field_ident.clone() } + } else { + let idx = syn::Index { + index: idx as _, + span: field.span(), + }; + quote_spanned! {span=> state.#idx.clone() } + }; + + quote_spanned! {span=> + impl ::axum::extract::FromRef<#state> for #field_ty { + fn from_ref(state: &#state) -> Self { + #body + } + } + } +} + +#[test] +fn ui() { + crate::run_ui_tests("from_ref"); +} diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 0644f057a0..3d864cd5b2 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -49,6 +49,7 @@ use syn::{parse::Parse, Type}; mod attr_parsing; mod debug_handler; +mod from_ref; mod from_request; mod typed_path; mod with_position; @@ -573,6 +574,47 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream { expand_with(input, typed_path::expand) } +/// Derive an implementation of [`FromRef`] for each field in a struct. +/// +/// # Example +/// +/// ``` +/// use axum_macros::FromRef; +/// use axum::{Router, routing::get, extract::State}; +/// +/// # +/// # type AuthToken = String; +/// # type DatabasePool = (); +/// # +/// // This will implement `FromRef` for each field in the struct. +/// #[derive(FromRef, Clone)] +/// struct AppState { +/// auth_token: AuthToken, +/// database_pool: DatabasePool, +/// } +/// +/// // So those types can be extracted via `State` +/// async fn handler(State(auth_token): State) {} +/// +/// async fn other_handler(State(database_pool): State) {} +/// +/// # let auth_token = Default::default(); +/// # let database_pool = Default::default(); +/// let state = AppState { +/// auth_token, +/// database_pool, +/// }; +/// +/// let app = Router::with_state(state).route("/", get(handler).post(other_handler)); +/// # let _: Router = app; +/// ``` +/// +/// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html +#[proc_macro_derive(FromRef, attributes(from_ref))] +pub fn derive_from_ref(item: TokenStream) -> TokenStream { + expand_with(item, |item| Ok(from_ref::expand(item))) +} + fn expand_with(input: TokenStream, f: F) -> TokenStream where F: FnOnce(I) -> syn::Result, diff --git a/axum-macros/tests/from_ref/pass/basic.rs b/axum-macros/tests/from_ref/pass/basic.rs new file mode 100644 index 0000000000..4a6631d308 --- /dev/null +++ b/axum-macros/tests/from_ref/pass/basic.rs @@ -0,0 +1,19 @@ +use axum_macros::FromRef; +use axum::{Router, routing::get, extract::State}; + +// This will implement `FromRef` for each field in the struct. +#[derive(Clone, FromRef)] +struct AppState { + auth_token: String, +} + +// So those types can be extracted via `State` +async fn handler(_: State) {} + +fn main() { + let state = AppState { + auth_token: Default::default(), + }; + + let _: Router = Router::with_state(state).route("/", get(handler)); +} diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 8119024079..3401c05ccb 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **changed:** The default body limit now applies to the `Multipart` extractor ([#1420]) - **added:** String and binary `From` impls have been added to `extract::ws::Message` to be more inline with `tungstenite` ([#1421]) +- **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430]) - **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from [`axum-macros`] behind the `macros` feature ([#1352]) diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 9dde11f518..a11921df41 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -19,7 +19,7 @@ mod state; pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts}; #[cfg(feature = "macros")] -pub use axum_macros::{FromRequest, FromRequestParts}; +pub use axum_macros::{FromRef, FromRequest, FromRequestParts}; #[doc(inline)] #[allow(deprecated)]