Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the context selector struct name to be specified #210

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
101 changes: 69 additions & 32 deletions snafu-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct VariantInfo {

enum ContextSelectorKind {
Context {
selector_name: syn::Ident,
source_field: Option<SourceField>,
user_fields: Vec<Field>,
},
Expand Down Expand Up @@ -723,34 +724,48 @@ fn parse_snafu_enum(
let (visibility, errs) = visibilities.finish();
errors.extend(errs);

let (is_context, errs) = contexts.finish();
let (context_arg, errs) = contexts.finish();
errors.extend(errs);

let source_field = source.map(|(val, _tts)| val);

let selector_kind = if is_context.unwrap_or(true) {
ContextSelectorKind::Context {
source_field,
user_fields,
}
} else {
errors.extend(
user_fields.into_iter().map(|Field { original, .. }| {
syn::Error::new_spanned(
original,
"Context selectors without context must not have context fields",
)
})
);

let source_field = source_field.ok_or_else(|| {
vec![syn::Error::new(
variant_span,
"Context selectors without context must have a source field",
)]
})?;
// if no context argument is specified, that's the same as specifying context(true)
let context_arg = context_arg.unwrap_or(ContextAttributeArgument::Enabled(true));

ContextSelectorKind::NoContext { source_field }
let selector_kind = match context_arg {
ContextAttributeArgument::Enabled(enabled) => {
if enabled {
ContextSelectorKind::Context {
selector_name: name.clone(),
source_field,
user_fields,
}
} else {
errors.extend(
user_fields.into_iter().map(|Field { original, .. }| {
syn::Error::new_spanned(
original,
"Context selectors without context must not have context fields",
)
})
);

let source_field = source_field.ok_or_else(|| {
vec![syn::Error::new(
variant_span,
"Context selectors without context must have a source field",
)]
})?;
ContextSelectorKind::NoContext { source_field }
}
},
ContextAttributeArgument::CustomName(selector_name) => {
ContextSelectorKind::Context {
selector_name,
source_field,
user_fields,
}
}
};

Ok(VariantInfo {
Expand Down Expand Up @@ -1013,10 +1028,15 @@ enum SnafuAttribute {
Visibility(proc_macro2::TokenStream, UserInput),
Source(proc_macro2::TokenStream, Vec<Source>),
Backtrace(proc_macro2::TokenStream, bool),
Context(proc_macro2::TokenStream, bool),
Context(proc_macro2::TokenStream, ContextAttributeArgument),
DocComment(proc_macro2::TokenStream, String),
}

enum ContextAttributeArgument {
Enabled(bool),
CustomName(syn::Ident),
}

impl syn::parse::Parse for SnafuAttribute {
fn parse(input: syn::parse::ParseStream) -> SynResult<Self> {
use syn::token::{Comma, Paren};
Expand Down Expand Up @@ -1060,10 +1080,26 @@ impl syn::parse::Parse for SnafuAttribute {
}
} else if name == "context" {
if input.is_empty() {
Ok(SnafuAttribute::Context(input_tts, true))
Ok(SnafuAttribute::Context(
input_tts,
ContextAttributeArgument::Enabled(true),
))
} else {
let v: MyParens<LitBool> = input.parse()?;
Ok(SnafuAttribute::Context(input_tts, v.0.value))
let inside;
parenthesized!(inside in input);
if inside.peek(LitBool) {
let v: LitBool = inside.parse()?;
Ok(SnafuAttribute::Context(
input_tts,
ContextAttributeArgument::Enabled(v.value),
))
} else {
let v: Ident = inside.parse()?;
Ok(SnafuAttribute::Context(
input_tts,
ContextAttributeArgument::CustomName(v),
))
}
}
} else {
Err(syn::Error::new(
Expand Down Expand Up @@ -1330,6 +1366,7 @@ impl<'a> quote::ToTokens for ContextSelector<'a> {

match selector_kind {
ContextSelectorKind::Context {
selector_name,
user_fields,
source_field,
} => {
Expand All @@ -1344,7 +1381,7 @@ impl<'a> quote::ToTokens for ContextSelector<'a> {
.unwrap_or(&self.0.default_visibility);

let generics_list = quote! { <#(#original_lifetimes,)* #(#generic_names,)* #(#original_generic_types_without_defaults,)*> };
let selector_name = quote! { #variant_name<#(#generic_names,)*> };
let selector_type = quote! { #selector_name<#(#generic_names,)*> };

let names: Vec<_> = user_fields.iter().map(|f| f.name.clone()).collect();
let selector_doc = format!(
Expand All @@ -1357,15 +1394,15 @@ impl<'a> quote::ToTokens for ContextSelector<'a> {
quote! {
#[derive(Debug, Copy, Clone)]
#[doc = #selector_doc]
#visibility struct #selector_name;
#visibility struct #selector_type;
}
} else {
let visibilities = iter::repeat(visibility);

quote! {
#[derive(Debug, Copy, Clone)]
#[doc = #selector_doc]
#visibility struct #selector_name {
#visibility struct #selector_type {
#(
#[allow(missing_docs)]
#visibilities #names: #generic_names
Expand All @@ -1387,7 +1424,7 @@ impl<'a> quote::ToTokens for ContextSelector<'a> {

let inherent_impl = if source_field.is_none() {
quote! {
impl<#(#generic_names,)*> #selector_name
impl<#(#generic_names,)*> #selector_type
{
#[doc = "Consume the selector and return a `Result` with the associated error"]
#visibility fn fail<#(#original_generics_without_defaults,)* __T>(self) -> core::result::Result<__T, #parameterized_enum_name>
Expand Down Expand Up @@ -1438,7 +1475,7 @@ impl<'a> quote::ToTokens for ContextSelector<'a> {
}

quote! {
impl#generics_list snafu::IntoError<#parameterized_enum_name> for #selector_name
impl#generics_list snafu::IntoError<#parameterized_enum_name> for #selector_type
where
#parameterized_enum_name: snafu::Error + snafu::ErrorCompat,
#(#where_clauses),*
Expand Down
32 changes: 32 additions & 0 deletions src/guide/attributes.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,38 @@ fn main() {
}
```

## Specifying the generated context selector type's name

Sometimes, you may already have a struct with the same name as one of
the variants of your error enum. You might also have 2 error enums
that share a variant. In those cases, there may be naming collisions
with the context selector type(s) generated by Snafu, which by default
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The proper (non-code) name of the library is "SNAFU".

have the same name as the enum variant.

To solve this, you can specify the name of the context selector that
Snafu should generate, using `#[snafu(context(MySelectorName))]`.

**Example**

```rust
# use snafu::{Snafu, ResultExt};
#
// some struct not related to the error
struct Foo;

#[derive(Debug, Snafu)]
enum Error {
#[snafu(context(FooContext))]
Foo {
source: std::io::Error
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
source: std::io::Error
source: std::io::Error,

},
}

fn read_file() -> Result<Vec<u8>, Error> {
std::fs::read("/some/file/that/doesnt/exist").context(FooContext)
}
```

## Controlling context

Sometimes, an underlying error can only occur in exactly one context
Expand Down
3 changes: 2 additions & 1 deletion src/guide/the_macro.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ struct UserIdInvalid<I> { user_id: I }
Notably:

1. One context selector is created for each enum variant.
1. The name of the selector is the same as the enum variant's name.
1. The name of the selector is the same as the enum variant's name,
unless a different name is specified using the `context` attribute.
1. The `source` and `backtrace` fields have been removed; the
library will automatically handle this for you.
1. Each remaining field's type has been replaced with a generic
Expand Down
59 changes: 59 additions & 0 deletions tests/custom_selector_name.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// This test is the same as basic.rs but with custom context selectors

use snafu::{ResultExt, Snafu};
use std::{
fs, io,
path::{Path, PathBuf},
};

#[derive(Debug, Snafu)]
enum Error {
#[snafu(
context(OpenConfigContext),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seeing it in use, I wonder if we want to give some more "namespacing". For example:

        context(name(OpenConfigContext)),

Or

        context(selector(OpenConfigContext)),

Or

        context(selector_name(OpenConfigContext)),

display = r#"("Could not open config file at {}: {}", filename.display(), source)"#
)]
OpenConfig {
filename: PathBuf,
source: io::Error,
},
#[snafu(
context(SaveConfigContext),
display = r#"("Could not open config file at {}", source)"#
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to use the modern style in these tests:

        display("Could not open config file at {}", source)

)]
SaveConfig { source: io::Error },
#[snafu(
context(InvalidUserContext),
display = r#"("User ID {} is invalid", user_id)"#
)]
Comment on lines +24 to +27
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This style is fine, which is why it compiles, but I just find it ugly for whatever reason. FWIW, I tend to split it into two:

    #[snafu(context(InvalidUserContext))]
    #[snafu(display = r#"("User ID {} is invalid", user_id)"#)]

InvalidUser { user_id: i32 },
#[snafu(context(MissingUserContext), display("No user available"))]
MissingUser,
}

type Result<T, E = Error> = std::result::Result<T, E>;

const CONFIG_FILENAME: &str = "/tmp/config";

fn example(root: impl AsRef<Path>, user_id: Option<i32>) -> Result<()> {
let root = root.as_ref();
let filename = &root.join(CONFIG_FILENAME);

let config = fs::read(filename).context(OpenConfigContext { filename })?;

let _user_id = match user_id {
None => MissingUserContext.fail()?,
Some(user_id) if user_id != 42 => InvalidUserContext { user_id }.fail()?,
Some(user_id) => user_id,
};

fs::write(filename, config).context(SaveConfigContext)?;

Ok(())
}

#[test]
fn implements_error() {
fn check<T: std::error::Error>() {}
check::<Error>();
example("/some/directory/that/does/not/exist", None).unwrap_err();
}