diff --git a/tracing-attributes/src/expand.rs b/tracing-attributes/src/expand.rs index d4ef29fc7a..e98554a8b0 100644 --- a/tracing-attributes/src/expand.rs +++ b/tracing-attributes/src/expand.rs @@ -2,10 +2,11 @@ use std::iter; use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; +use syn::visit_mut::VisitMut; use syn::{ punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg, Ident, Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType, - Path, Signature, Stmt, Token, TypePath, + Path, ReturnType, Signature, Stmt, Token, Type, TypePath, }; use crate::{ @@ -18,7 +19,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>( input: MaybeItemFnRef<'a, B>, args: InstrumentArgs, instrumented_function_name: &str, - self_type: Option<&syn::TypePath>, + self_type: Option<&TypePath>, ) -> proc_macro2::TokenStream { // these are needed ahead of time, as ItemFn contains the function body _and_ // isn't representable inside a quote!/quote_spanned! macro @@ -31,7 +32,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>( } = input; let Signature { - output: return_type, + output, inputs: params, unsafety, asyncness, @@ -49,8 +50,34 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>( let warnings = args.warnings(); + let block = if let ReturnType::Type(_, return_type) = &output { + let return_type = erase_impl_trait(return_type); + // Install a fake return statement as the first thing in the function + // body, so that we eagerly infer that the return type is what we + // declared in the async fn signature. + let fake_return_edge = quote_spanned! {return_type.span()=> + #[allow(unreachable_code)] + if false { + let __tracing_attr_fake_return: #return_type = panic!(); + return __tracing_attr_fake_return; + } + }; + quote! { + { + #fake_return_edge + #block + } + } + } else { + quote! { + { + let _: () = #block; + } + } + }; + let body = gen_block( - block, + &block, params, asyncness.is_some(), args, @@ -60,7 +87,7 @@ pub(crate) fn gen_function<'a, B: ToTokens + 'a>( quote!( #(#attrs) * - #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type + #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #output #where_clause { #warnings @@ -76,7 +103,7 @@ fn gen_block( async_context: bool, mut args: InstrumentArgs, instrumented_function_name: &str, - self_type: Option<&syn::TypePath>, + self_type: Option<&TypePath>, ) -> proc_macro2::TokenStream { // generate the span's name let span_name = args @@ -393,11 +420,11 @@ impl RecordType { "Wrapping", ]; - /// Parse `RecordType` from [syn::Type] by looking up + /// Parse `RecordType` from [Type] by looking up /// the [RecordType::TYPES_FOR_VALUE] array. - fn parse_from_ty(ty: &syn::Type) -> Self { + fn parse_from_ty(ty: &Type) -> Self { match ty { - syn::Type::Path(syn::TypePath { path, .. }) + Type::Path(TypePath { path, .. }) if path .segments .iter() @@ -410,9 +437,7 @@ impl RecordType { { RecordType::Value } - syn::Type::Reference(syn::TypeReference { elem, .. }) => { - RecordType::parse_from_ty(&*elem) - } + Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(&*elem), _ => RecordType::Debug, } } @@ -471,7 +496,7 @@ pub(crate) struct AsyncInfo<'block> { // statement that must be patched source_stmt: &'block Stmt, kind: AsyncKind<'block>, - self_type: Option, + self_type: Option, input: &'block ItemFn, } @@ -606,11 +631,11 @@ impl<'block> AsyncInfo<'block> { if ident == "_self" { let mut ty = *ty.ty.clone(); // extract the inner type if the argument is "&self" or "&mut self" - if let syn::Type::Reference(syn::TypeReference { elem, .. }) = ty { + if let Type::Reference(syn::TypeReference { elem, .. }) = ty { ty = *elem; } - if let syn::Type::Path(tp) = ty { + if let Type::Path(tp) = ty { self_type = Some(tp); break; } @@ -722,7 +747,7 @@ struct IdentAndTypesRenamer<'a> { idents: Vec<(Ident, Ident)>, } -impl<'a> syn::visit_mut::VisitMut for IdentAndTypesRenamer<'a> { +impl<'a> VisitMut for IdentAndTypesRenamer<'a> { // we deliberately compare strings because we want to ignore the spans // If we apply clippy's lint, the behavior changes #[allow(clippy::cmp_owned)] @@ -734,11 +759,11 @@ impl<'a> syn::visit_mut::VisitMut for IdentAndTypesRenamer<'a> { } } - fn visit_type_mut(&mut self, ty: &mut syn::Type) { + fn visit_type_mut(&mut self, ty: &mut Type) { for (type_name, new_type) in &self.types { - if let syn::Type::Path(TypePath { path, .. }) = ty { + if let Type::Path(TypePath { path, .. }) = ty { if path_to_string(path) == *type_name { - *ty = syn::Type::Path(new_type.clone()); + *ty = Type::Path(new_type.clone()); } } } @@ -751,10 +776,33 @@ struct AsyncTraitBlockReplacer<'a> { patched_block: Block, } -impl<'a> syn::visit_mut::VisitMut for AsyncTraitBlockReplacer<'a> { +impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> { fn visit_block_mut(&mut self, i: &mut Block) { if i == self.block { *i = self.patched_block.clone(); } } } + +// Replaces any `impl Trait` with `_` so it can be used as the type in +// a `let` statement's LHS. +struct ImplTraitEraser; + +impl VisitMut for ImplTraitEraser { + fn visit_type_mut(&mut self, t: &mut Type) { + if let Type::ImplTrait(..) = t { + *t = syn::TypeInfer { + underscore_token: Token![_](t.span()), + } + .into(); + } else { + syn::visit_mut::visit_type_mut(self, t); + } + } +} + +fn erase_impl_trait(ty: &Type) -> Type { + let mut ty = ty.clone(); + ImplTraitEraser.visit_type_mut(&mut ty); + ty +}