Skip to content

Commit

Permalink
Merge pull request #50 from dtolnay/transparent
Browse files Browse the repository at this point in the history
Add transparent attribute for delegating Error impl to one field
  • Loading branch information
dtolnay committed Dec 1, 2019
2 parents 038b8d5 + ac61f40 commit 62e8e66
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 42 deletions.
2 changes: 2 additions & 0 deletions impl/src/ast.rs
Expand Up @@ -79,6 +79,8 @@ impl<'a> Enum<'a> {
}
if let Some(display) = &mut variant.attrs.display {
display.expand_shorthand(&variant.fields);
} else if variant.attrs.transparent.is_none() {
variant.attrs.transparent = attrs.transparent;
}
Ok(variant)
})
Expand Down
38 changes: 27 additions & 11 deletions impl/src/attr.rs
Expand Up @@ -12,6 +12,7 @@ pub struct Attrs<'a> {
pub source: Option<&'a Attribute>,
pub backtrace: Option<&'a Attribute>,
pub from: Option<&'a Attribute>,
pub transparent: Option<&'a Attribute>,
}

#[derive(Clone)]
Expand All @@ -29,18 +30,12 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {
source: None,
backtrace: None,
from: None,
transparent: None,
};

for attr in input {
if attr.path.is_ident("error") {
let display = parse_display(attr)?;
if attrs.display.is_some() {
return Err(Error::new_spanned(
attr,
"only one #[error(...)] attribute is allowed",
));
}
attrs.display = Some(display);
parse_error_attribute(&mut attrs, attr)?;
} else if attr.path.is_ident("source") {
require_empty_attribute(attr)?;
if attrs.source.is_some() {
Expand Down Expand Up @@ -68,15 +63,36 @@ pub fn get(input: &[Attribute]) -> Result<Attrs> {
Ok(attrs)
}

fn parse_display(attr: &Attribute) -> Result<Display> {
fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> {
syn::custom_keyword!(transparent);

attr.parse_args_with(|input: ParseStream| {
Ok(Display {
if input.parse::<Option<transparent>>()?.is_some() {
if attrs.transparent.is_some() {
return Err(Error::new_spanned(
attr,
"duplicate #[error(transparent)] attribute",
));
}
attrs.transparent = Some(attr);
return Ok(());
}

let display = Display {
original: attr,
fmt: input.parse()?,
args: parse_token_expr(input, false)?,
was_shorthand: false,
has_bonus_display: false,
})
};
if attrs.display.is_some() {
return Err(Error::new_spanned(
attr,
"only one #[error(...)] attribute is allowed",
));
}
attrs.display = Some(display);
Ok(())
})
}

Expand Down
90 changes: 63 additions & 27 deletions impl/src/expand.rs
@@ -1,7 +1,6 @@
use crate::ast::{Enum, Field, Input, Struct};
use crate::valid;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned};
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::spanned::Spanned;
use syn::{DeriveInput, Member, PathArguments, Result, Type};

Expand All @@ -18,18 +17,30 @@ fn impl_struct(input: Struct) -> TokenStream {
let ty = &input.ident;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

let source_method = input.source_field().map(|source_field| {
let source_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
Some(quote! {
std::error::Error::source(self.#only_field.as_dyn_error())
})
} else if let Some(source_field) = input.source_field() {
let source = &source_field.member;
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
None
};
let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error());
Some(quote! {
std::option::Option::Some(#dyn_error)
})
} else {
None
};
let source_method = source_body.map(|body| {
quote! {
fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> {
use thiserror::private::AsDynError;
std::option::Option::Some(#dyn_error)
#body
}
}
});
Expand Down Expand Up @@ -76,7 +87,12 @@ fn impl_struct(input: Struct) -> TokenStream {
}
});

let display_impl = input.attrs.display.as_ref().map(|display| {
let display_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
Some(quote! {
std::fmt::Display::fmt(&self.#only_field, __formatter)
})
} else if let Some(display) = &input.attrs.display {
let use_as_display = if display.has_bonus_display {
Some(quote! {
#[allow(unused_imports)]
Expand All @@ -86,13 +102,20 @@ fn impl_struct(input: Struct) -> TokenStream {
None
};
let pat = fields_pat(&input.fields);
Some(quote! {
#use_as_display
#[allow(unused_variables)]
let Self #pat = self;
#display
})
} else {
None
};
let display_impl = display_body.map(|body| {
quote! {
impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause {
fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
#use_as_display
#[allow(unused_variables)]
let Self #pat = self;
#display
#body
}
}
}
Expand Down Expand Up @@ -128,22 +151,27 @@ fn impl_enum(input: Enum) -> TokenStream {
let source_method = if input.has_source() {
let arms = input.variants.iter().map(|variant| {
let ident = &variant.ident;
match variant.source_field() {
Some(source_field) => {
let source = &source_field.member;
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
None
};
let dyn_error = quote_spanned!(source.span()=> source #asref.as_dyn_error());
quote! {
#ty::#ident {#source: source, ..} => std::option::Option::Some(#dyn_error),
}
if variant.attrs.transparent.is_some() {
let only_field = &variant.fields[0].member;
let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
quote! {
#ty::#ident {#only_field: transparent} => #source,
}
} else if let Some(source_field) = variant.source_field() {
let source = &source_field.member;
let asref = if type_is_option(source_field.ty) {
Some(quote_spanned!(source.span()=> .as_ref()?))
} else {
None
};
let dyn_error = quote_spanned!(source.span()=> source #asref.as_dyn_error());
quote! {
#ty::#ident {#source: source, ..} => std::option::Option::Some(#dyn_error),
}
None => quote! {
} else {
quote! {
#ty::#ident {..} => std::option::Option::None,
},
}
}
});
Some(quote! {
Expand Down Expand Up @@ -228,8 +256,7 @@ fn impl_enum(input: Enum) -> TokenStream {
v.attrs
.display
.as_ref()
.expect(valid::CHECKED)
.has_bonus_display
.map_or(false, |display| display.has_bonus_display)
}) {
Some(quote! {
#[allow(unused_imports)]
Expand All @@ -244,7 +271,16 @@ fn impl_enum(input: Enum) -> TokenStream {
None
};
let arms = input.variants.iter().map(|variant| {
let display = variant.attrs.display.as_ref().expect(valid::CHECKED);
let display = match &variant.attrs.display {
Some(display) => display.to_token_stream(),
None => {
let only_field = match &variant.fields[0].member {
Member::Named(ident) => ident.clone(),
Member::Unnamed(index) => format_ident!("_{}", index),
};
quote!(std::fmt::Display::fmt(#only_field, __formatter))
}
};
let ident = &variant.ident;
let pat = fields_pat(&variant.fields);
quote! {
Expand Down Expand Up @@ -297,7 +333,7 @@ fn fields_pat(fields: &[Field]) -> TokenStream {
Some(Member::Named(_)) => quote!({ #(#members),* }),
Some(Member::Unnamed(_)) => {
let vars = members.map(|member| match member {
Member::Unnamed(member) => format_ident!("_{}", member.index),
Member::Unnamed(member) => format_ident!("_{}", member),
Member::Named(_) => unreachable!(),
});
quote!((#(#vars),*))
Expand Down
7 changes: 6 additions & 1 deletion impl/src/prop.rs
Expand Up @@ -19,7 +19,7 @@ impl Enum<'_> {
pub(crate) fn has_source(&self) -> bool {
self.variants
.iter()
.any(|variant| variant.source_field().is_some())
.any(|variant| variant.source_field().is_some() || variant.attrs.transparent.is_some())
}

pub(crate) fn has_backtrace(&self) -> bool {
Expand All @@ -30,10 +30,15 @@ impl Enum<'_> {

pub(crate) fn has_display(&self) -> bool {
self.attrs.display.is_some()
|| self.attrs.transparent.is_some()
|| self
.variants
.iter()
.any(|variant| variant.attrs.display.is_some())
|| self
.variants
.iter()
.all(|variant| variant.attrs.transparent.is_some())
}
}

Expand Down
41 changes: 38 additions & 3 deletions impl/src/valid.rs
Expand Up @@ -4,8 +4,6 @@ use quote::ToTokens;
use std::collections::BTreeSet as Set;
use syn::{Error, Member, Result};

pub(crate) const CHECKED: &str = "checked in validation";

impl Input<'_> {
pub(crate) fn validate(&self) -> Result<()> {
match self {
Expand All @@ -18,6 +16,20 @@ impl Input<'_> {
impl Struct<'_> {
fn validate(&self) -> Result<()> {
check_non_field_attrs(&self.attrs)?;
if let Some(transparent) = self.attrs.transparent {
if self.fields.len() != 1 {
return Err(Error::new_spanned(
transparent,
"#[error(transparent)] requires exactly one field",
));
}
if let Some(source) = self.fields.iter().filter_map(|f| f.attrs.source).next() {
return Err(Error::new_spanned(
source,
"transparent error struct can't contain #[source]",
));
}
}
check_field_attrs(&self.fields)?;
for field in &self.fields {
field.validate()?;
Expand All @@ -32,7 +44,8 @@ impl Enum<'_> {
let has_display = self.has_display();
for variant in &self.variants {
variant.validate()?;
if has_display && variant.attrs.display.is_none() {
if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none()
{
return Err(Error::new_spanned(
variant.original,
"missing #[error(\"...\")] display attribute",
Expand All @@ -58,6 +71,20 @@ impl Enum<'_> {
impl Variant<'_> {
fn validate(&self) -> Result<()> {
check_non_field_attrs(&self.attrs)?;
if self.attrs.transparent.is_some() {
if self.fields.len() != 1 {
return Err(Error::new_spanned(
self.original,
"#[error(transparent)] requires exactly one field",
));
}
if let Some(source) = self.fields.iter().filter_map(|f| f.attrs.source).next() {
return Err(Error::new_spanned(
source,
"transparent variant can't contain #[source]",
));
}
}
check_field_attrs(&self.fields)?;
for field in &self.fields {
field.validate()?;
Expand Down Expand Up @@ -97,6 +124,14 @@ fn check_non_field_attrs(attrs: &Attrs) -> Result<()> {
"not expected here; the #[backtrace] attribute belongs on a specific field",
));
}
if let Some(display) = &attrs.display {
if attrs.transparent.is_some() {
return Err(Error::new_spanned(
display.original,
"cannot have both #[error(transparent)] and a display attribute",
));
}
}
Ok(())
}

Expand Down
57 changes: 57 additions & 0 deletions tests/test_transparent.rs
@@ -0,0 +1,57 @@
use anyhow::anyhow;
use std::error::Error as _;
use std::io;
use thiserror::Error;

#[test]
fn test_transparent_struct() {
#[derive(Error, Debug)]
#[error(transparent)]
struct Error(ErrorKind);

#[derive(Error, Debug)]
enum ErrorKind {
#[error("E0")]
E0,
#[error("E1")]
E1(#[from] io::Error),
}

let error = Error(ErrorKind::E0);
assert_eq!("E0", error.to_string());
assert!(error.source().is_none());

let io = io::Error::new(io::ErrorKind::Other, "oh no!");
let error = Error(ErrorKind::from(io));
assert_eq!("E1", error.to_string());
error.source().unwrap().downcast_ref::<io::Error>().unwrap();
}

#[test]
fn test_transparent_enum() {
#[derive(Error, Debug)]
enum Error {
#[error("this failed")]
This,
#[error(transparent)]
Other(anyhow::Error),
}

let error = Error::This;
assert_eq!("this failed", error.to_string());

let error = Error::Other(anyhow!("inner").context("outer"));
assert_eq!("outer", error.to_string());
assert_eq!("inner", error.source().unwrap().to_string());
}

#[test]
fn test_anyhow() {
#[derive(Error, Debug)]
#[error(transparent)]
struct Any(#[from] anyhow::Error);

let error = Any::from(anyhow!("inner").context("outer"));
assert_eq!("outer", error.to_string());
assert_eq!("inner", error.source().unwrap().to_string());
}

0 comments on commit 62e8e66

Please sign in to comment.