Skip to content

Commit

Permalink
Add #[derive(FromRef)] (#1430)
Browse files Browse the repository at this point in the history
* add `#[derive(FromRef)]`

* tests

* don't support skipping fields

probably wouldn't work at all since the whole state likely needs `Clone`

* UI tests

* changelog

* changelog link

* revert hello-world example, used for testing

* Re-export `#[derive(FromRef)]`

* Don't need to return `Result`

* use `collect` instead of quoting the iterator

* Mention it in axum's changelog
  • Loading branch information
davidpdrsn committed Oct 10, 2022
1 parent 1681ecf commit 9c0a89c
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 2 deletions.
3 changes: 3 additions & 0 deletions axum-core/src/extract/from_ref.rs
Expand Up @@ -5,7 +5,10 @@
///
/// See [`State`] for more details on how library authors should use this trait.
///
/// This trait can be derived using `#[derive(axum_macros::FromRef)]`.
///
/// [`State`]: https://docs.rs/axum/0.6/axum/extract/struct.State.html
/// [`#[derive(axum_macros::FromRef)]`]: https://docs.rs/axum-macros/latest/axum_macros/derive.FromRef.html
// NOTE: This trait is defined in axum-core, even though it is mainly used with `State` which is
// defined in axum. That allows crate authors to use it when implementing extractors.
pub trait FromRef<T> {
Expand Down
4 changes: 3 additions & 1 deletion axum-macros/CHANGELOG.md
Expand Up @@ -7,7 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

# Unreleased

- None
- **added:** Add `#[derive(FromRef)]` ([#1430])

[#1430]: https://github.com/tokio-rs/axum/pull/1430

# 0.3.0-rc.1 (23. August, 2022)

Expand Down
39 changes: 39 additions & 0 deletions axum-macros/src/from_ref.rs
@@ -0,0 +1,39 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote_spanned;
use syn::{spanned::Spanned, Field, ItemStruct};

pub(crate) fn expand(item: ItemStruct) -> TokenStream {
item.fields
.iter()
.enumerate()
.map(|(idx, field)| expand_field(&item.ident, idx, field))
.collect()
}

fn expand_field(state: &Ident, idx: usize, field: &Field) -> TokenStream {
let field_ty = &field.ty;
let span = field.ty.span();

let body = if let Some(field_ident) = &field.ident {
quote_spanned! {span=> state.#field_ident.clone() }
} else {
let idx = syn::Index {
index: idx as _,
span: field.span(),
};
quote_spanned! {span=> state.#idx.clone() }
};

quote_spanned! {span=>
impl ::axum::extract::FromRef<#state> for #field_ty {
fn from_ref(state: &#state) -> Self {
#body
}
}
}
}

#[test]
fn ui() {
crate::run_ui_tests("from_ref");
}
42 changes: 42 additions & 0 deletions axum-macros/src/lib.rs
Expand Up @@ -49,6 +49,7 @@ use syn::{parse::Parse, Type};

mod attr_parsing;
mod debug_handler;
mod from_ref;
mod from_request;
mod typed_path;
mod with_position;
Expand Down Expand Up @@ -573,6 +574,47 @@ pub fn derive_typed_path(input: TokenStream) -> TokenStream {
expand_with(input, typed_path::expand)
}

/// Derive an implementation of [`FromRef`] for each field in a struct.
///
/// # Example
///
/// ```
/// use axum_macros::FromRef;
/// use axum::{Router, routing::get, extract::State};
///
/// #
/// # type AuthToken = String;
/// # type DatabasePool = ();
/// #
/// // This will implement `FromRef` for each field in the struct.
/// #[derive(FromRef, Clone)]
/// struct AppState {
/// auth_token: AuthToken,
/// database_pool: DatabasePool,
/// }
///
/// // So those types can be extracted via `State`
/// async fn handler(State(auth_token): State<AuthToken>) {}
///
/// async fn other_handler(State(database_pool): State<DatabasePool>) {}
///
/// # let auth_token = Default::default();
/// # let database_pool = Default::default();
/// let state = AppState {
/// auth_token,
/// database_pool,
/// };
///
/// let app = Router::with_state(state).route("/", get(handler).post(other_handler));
/// # let _: Router<AppState> = app;
/// ```
///
/// [`FromRef`]: https://docs.rs/axum/latest/axum/extract/trait.FromRef.html
#[proc_macro_derive(FromRef, attributes(from_ref))]
pub fn derive_from_ref(item: TokenStream) -> TokenStream {
expand_with(item, |item| Ok(from_ref::expand(item)))
}

fn expand_with<F, I, K>(input: TokenStream, f: F) -> TokenStream
where
F: FnOnce(I) -> syn::Result<K>,
Expand Down
19 changes: 19 additions & 0 deletions axum-macros/tests/from_ref/pass/basic.rs
@@ -0,0 +1,19 @@
use axum_macros::FromRef;
use axum::{Router, routing::get, extract::State};

// This will implement `FromRef` for each field in the struct.
#[derive(Clone, FromRef)]
struct AppState {
auth_token: String,
}

// So those types can be extracted via `State`
async fn handler(_: State<String>) {}

fn main() {
let state = AppState {
auth_token: Default::default(),
};

let _: Router<AppState> = Router::with_state(state).route("/", get(handler));
}
1 change: 1 addition & 0 deletions axum/CHANGELOG.md
Expand Up @@ -43,6 +43,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **changed:** The default body limit now applies to the `Multipart` extractor ([#1420])
- **added:** String and binary `From` impls have been added to `extract::ws::Message`
to be more inline with `tungstenite` ([#1421])
- **added:** Add `#[derive(axum::extract::FromRef)]` ([#1430])
- **added:** `FromRequest` and `FromRequestParts` derive macro re-exports from
[`axum-macros`] behind the `macros` feature ([#1352])

Expand Down
2 changes: 1 addition & 1 deletion axum/src/extract/mod.rs
Expand Up @@ -19,7 +19,7 @@ mod state;
pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequestParts};

#[cfg(feature = "macros")]
pub use axum_macros::{FromRequest, FromRequestParts};
pub use axum_macros::{FromRef, FromRequest, FromRequestParts};

#[doc(inline)]
#[allow(deprecated)]
Expand Down

0 comments on commit 9c0a89c

Please sign in to comment.