Skip to content

Commit

Permalink
Allow serde types to be Decode/Encoded (#434)
Browse files Browse the repository at this point in the history
* Added support for #[bincode(serde)] attribute on fields, added SerdeToBincode helper struct

* Switch to using Compat/BorrowCompat

* Moved all the serde features and functions to its own module

* Fix broken link

* Added support for the bincode(with_serde) attribute in enum variants

* Updated the main documentation on serde, fixed an example not compiling under certain feature flag combinations

* Added #[serde(flatten)] to the list of problematic attributes

* Added better error reporting on invalid attributes
  • Loading branch information
VictorKoenders committed Nov 9, 2021
1 parent a2e7e0e commit cb16c7f
Show file tree
Hide file tree
Showing 15 changed files with 411 additions and 94 deletions.
40 changes: 30 additions & 10 deletions derive/src/derive_enum.rs
@@ -1,5 +1,5 @@
use crate::generate::{FnSelfArg, Generator, StreamBuilder};
use crate::parse::{EnumVariant, Fields};
use crate::parse::{EnumVariant, FieldAttribute, Fields};
use crate::prelude::*;
use crate::Result;

Expand Down Expand Up @@ -80,11 +80,19 @@ impl DeriveEnum {
body.punct(';');
// If we have any fields, encode them all one by one
for field_name in variant.fields.names() {
body.push_parsed(format!(
"bincode::enc::Encode::encode({}, &mut encoder)?;",
field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
))
.unwrap();
if field_name.has_field_attribute(FieldAttribute::WithSerde) {
body.push_parsed(format!(
"bincode::enc::Encode::encode(&bincode::serde::Compat({}), &mut encoder)?;",
field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
))
.unwrap();
} else {
body.push_parsed(format!(
"bincode::enc::Encode::encode({}, &mut encoder)?;",
field_name.to_string_with_prefix(TUPLE_FIELD_PREFIX),
))
.unwrap();
}
}
});
match_body.punct(',');
Expand Down Expand Up @@ -209,9 +217,15 @@ impl DeriveEnum {
variant_body.ident(field.unwrap_ident().clone());
}
variant_body.punct(':');
variant_body
.push_parsed("bincode::Decode::decode(&mut decoder)?,")
.unwrap();
if field.has_field_attribute(FieldAttribute::WithSerde) {
variant_body
.push_parsed("<bincode::serde::Compat<_> as bincode::Decode>::decode(&mut decoder)?.0,")
.unwrap();
} else {
variant_body
.push_parsed("bincode::Decode::decode(&mut decoder)?,")
.unwrap();
}
}
});
});
Expand Down Expand Up @@ -269,7 +283,13 @@ impl DeriveEnum {
variant_body.ident(field.unwrap_ident().clone());
}
variant_body.punct(':');
variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap();
if field.has_field_attribute(FieldAttribute::WithSerde) {
variant_body
.push_parsed("<bincode::serde::BorrowCompat<_> as bincode::BorrowDecode>::borrow_decode(&mut decoder)?.0,")
.unwrap();
} else {
variant_body.push_parsed("bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,").unwrap();
}
}
});
});
Expand Down
65 changes: 46 additions & 19 deletions derive/src/derive_struct.rs
@@ -1,5 +1,5 @@
use crate::generate::Generator;
use crate::parse::Fields;
use crate::parse::{FieldAttribute, Fields};
use crate::prelude::Delimiter;
use crate::Result;

Expand All @@ -21,12 +21,21 @@ impl DeriveStruct {
.with_return_type("core::result::Result<(), bincode::error::EncodeError>")
.body(|fn_body| {
for field in fields.names() {
fn_body
.push_parsed(format!(
"bincode::enc::Encode::encode(&self.{}, &mut encoder)?;",
field.to_string()
))
.unwrap();
if field.has_field_attribute(FieldAttribute::WithSerde) {
fn_body
.push_parsed(format!(
"bincode::Encode::encode(&bincode::serde::Compat(&self.{}), &mut encoder)?;",
field.to_string()
))
.unwrap();
} else {
fn_body
.push_parsed(format!(
"bincode::enc::Encode::encode(&self.{}, &mut encoder)?;",
field.to_string()
))
.unwrap();
}
}
fn_body.push_parsed("Ok(())").unwrap();
})
Expand Down Expand Up @@ -59,12 +68,21 @@ impl DeriveStruct {
// ...
// }
for field in fields.names() {
struct_body
.push_parsed(format!(
"{}: bincode::Decode::decode(&mut decoder)?,",
field.to_string()
))
.unwrap();
if field.has_field_attribute(FieldAttribute::WithSerde) {
struct_body
.push_parsed(format!(
"{}: (<bincode::serde::Compat<_> as bincode::Decode>::decode(&mut decoder)?).0,",
field.to_string()
))
.unwrap();
} else {
struct_body
.push_parsed(format!(
"{}: bincode::Decode::decode(&mut decoder)?,",
field.to_string()
))
.unwrap();
}
}
});
});
Expand Down Expand Up @@ -92,12 +110,21 @@ impl DeriveStruct {
ok_group.ident_str("Self");
ok_group.group(Delimiter::Brace, |struct_body| {
for field in fields.names() {
struct_body
.push_parsed(format!(
"{}: bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,",
field.to_string()
))
.unwrap();
if field.has_field_attribute(FieldAttribute::WithSerde) {
struct_body
.push_parsed(format!(
"{}: (<bincode::serde::BorrowCompat<_> as bincode::de::BorrowDecode>::borrow_decode(&mut decoder)?).0,",
field.to_string()
))
.unwrap();
} else {
struct_body
.push_parsed(format!(
"{}: bincode::de::BorrowDecode::borrow_decode(&mut decoder)?,",
field.to_string()
))
.unwrap();
}
}
});
});
Expand Down
2 changes: 1 addition & 1 deletion derive/src/error.rs
Expand Up @@ -9,7 +9,7 @@ pub enum Error {
}

impl Error {
pub fn wrong_token<T>(token: Option<&TokenTree>, expected: &'static str) -> Result<T, Self> {
pub fn wrong_token<T>(token: Option<&TokenTree>, expected: &str) -> Result<T, Self> {
Err(Self::InvalidRustSyntax {
span: token.map(|t| t.span()).unwrap_or_else(Span::call_site),
expected: format!("{}, got {:?}", expected, token),
Expand Down
13 changes: 7 additions & 6 deletions derive/src/lib.rs
Expand Up @@ -16,11 +16,12 @@ pub(crate) mod prelude {
}

use error::Error;
use parse::AttributeLocation;
use prelude::TokenStream;

type Result<T = ()> = std::result::Result<T, Error>;

#[proc_macro_derive(Encode)]
#[proc_macro_derive(Encode, attributes(bincode))]
pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
#[allow(clippy::useless_conversion)]
derive_encode_inner(input.into())
Expand All @@ -31,7 +32,7 @@ pub fn derive_encode(input: proc_macro::TokenStream) -> proc_macro::TokenStream
fn derive_encode_inner(input: TokenStream) -> Result<TokenStream> {
let source = &mut input.into_iter().peekable();

let _attributes = parse::Attribute::try_take(source)?;
let _attributes = parse::Attribute::try_take(AttributeLocation::Container, source)?;
let _visibility = parse::Visibility::try_take(source)?;
let (datatype, name) = parse::DataType::take(source)?;
let generics = parse::Generics::try_take(source)?;
Expand Down Expand Up @@ -61,7 +62,7 @@ fn derive_encode_inner(input: TokenStream) -> Result<TokenStream> {
Ok(stream)
}

#[proc_macro_derive(Decode)]
#[proc_macro_derive(Decode, attributes(bincode))]
pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
#[allow(clippy::useless_conversion)]
derive_decode_inner(input.into())
Expand All @@ -72,7 +73,7 @@ pub fn derive_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream
fn derive_decode_inner(input: TokenStream) -> Result<TokenStream> {
let source = &mut input.into_iter().peekable();

let _attributes = parse::Attribute::try_take(source)?;
let _attributes = parse::Attribute::try_take(AttributeLocation::Container, source)?;
let _visibility = parse::Visibility::try_take(source)?;
let (datatype, name) = parse::DataType::take(source)?;
let generics = parse::Generics::try_take(source)?;
Expand Down Expand Up @@ -102,7 +103,7 @@ fn derive_decode_inner(input: TokenStream) -> Result<TokenStream> {
Ok(stream)
}

#[proc_macro_derive(BorrowDecode)]
#[proc_macro_derive(BorrowDecode, attributes(bincode))]
pub fn derive_brrow_decode(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
#[allow(clippy::useless_conversion)]
derive_borrow_decode_inner(input.into())
Expand All @@ -113,7 +114,7 @@ pub fn derive_brrow_decode(input: proc_macro::TokenStream) -> proc_macro::TokenS
fn derive_borrow_decode_inner(input: TokenStream) -> Result<TokenStream> {
let source = &mut input.into_iter().peekable();

let _attributes = parse::Attribute::try_take(source)?;
let _attributes = parse::Attribute::try_take(AttributeLocation::Container, source)?;
let _visibility = parse::Visibility::try_take(source)?;
let (datatype, name) = parse::DataType::take(source)?;
let generics = parse::Generics::try_take(source)?;
Expand Down
87 changes: 73 additions & 14 deletions derive/src/parse/attributes.rs
@@ -1,28 +1,62 @@
use super::{assume_group, assume_punct};
use crate::parse::consume_punct_if;
use super::{assume_group, assume_ident, assume_punct};
use crate::parse::{consume_punct_if, ident_eq};
use crate::prelude::{Delimiter, Group, Punct, TokenTree};
use crate::{Error, Result};
use std::iter::Peekable;

#[derive(Debug)]
pub struct Attribute {
// we don't use these fields yet
#[allow(dead_code)]
punct: Punct,
#[allow(dead_code)]
tokens: Option<Group>,
pub enum Attribute {
Field(FieldAttribute),
Unknown { punct: Punct, tokens: Option<Group> },
}
#[derive(Debug, PartialEq)]
pub enum FieldAttribute {
/// The field is a serde type and should implement Encode/Decode through a wrapper
WithSerde,
}

#[derive(PartialEq, Eq, Debug, Hash, Copy, Clone)]
pub enum AttributeLocation {
Container,
Variant,
Field,
}

impl Attribute {
pub fn try_take(input: &mut Peekable<impl Iterator<Item = TokenTree>>) -> Result<Vec<Self>> {
pub fn try_take(
loc: AttributeLocation,
input: &mut Peekable<impl Iterator<Item = TokenTree>>,
) -> Result<Vec<Self>> {
let mut result = Vec::new();

while let Some(punct) = consume_punct_if(input, '#') {
match input.peek() {
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
result.push(Attribute {
let group = assume_group(input.next());
let stream = &mut group.stream().into_iter().peekable();
if let Some(TokenTree::Ident(attribute_ident)) = stream.peek() {
if super::ident_eq(attribute_ident, "bincode") {
assume_ident(stream.next());
match stream.next() {
Some(TokenTree::Group(group)) => {
result.push(Self::parse_bincode_attribute(
loc,
&mut group.stream().into_iter().peekable(),
)?);
}
token => {
return Error::wrong_token(
token.as_ref(),
"Bracketed group of attributes",
)
}
}
continue;
}
}
result.push(Attribute::Unknown {
punct,
tokens: Some(assume_group(input.next())),
tokens: Some(group),
});
}
Some(TokenTree::Group(g)) => {
Expand All @@ -34,7 +68,7 @@ impl Attribute {
Some(TokenTree::Punct(p)) if p.as_char() == '#' => {
// sometimes with empty lines of doc comments, we get two #'s in a row
// add an empty attributes and continue to the next loop
result.push(Attribute {
result.push(Attribute::Unknown {
punct: assume_punct(input.next(), '#'),
tokens: None,
})
Expand All @@ -44,21 +78,46 @@ impl Attribute {
}
Ok(result)
}

fn parse_bincode_attribute(
loc: AttributeLocation,
stream: &mut Peekable<impl Iterator<Item = TokenTree>>,
) -> Result<Self> {
match (stream.next(), loc) {
(Some(TokenTree::Ident(ident)), AttributeLocation::Field)
if ident_eq(&ident, "with_serde") =>
{
Ok(Self::Field(FieldAttribute::WithSerde))
}
(token @ Some(TokenTree::Ident(_)), AttributeLocation::Field) => {
Error::wrong_token(token.as_ref(), "one of: `with_serde`")
}
(token @ Some(TokenTree::Ident(_)), loc) => Error::wrong_token(
token.as_ref(),
&format!("{:?} attributes not supported", loc),
),
(token, _) => Error::wrong_token(token.as_ref(), "ident"),
}
}
}

#[test]
fn test_attributes_try_take() {
use crate::token_stream;

let stream = &mut token_stream("struct Foo;");
assert!(Attribute::try_take(stream).unwrap().is_empty());
assert!(Attribute::try_take(AttributeLocation::Container, stream)
.unwrap()
.is_empty());
match stream.next().unwrap() {
TokenTree::Ident(i) => assert_eq!(i, "struct"),
x => panic!("Expected ident, found {:?}", x),
}

let stream = &mut token_stream("#[cfg(test)] struct Foo;");
assert!(!Attribute::try_take(stream).unwrap().is_empty());
assert!(!Attribute::try_take(AttributeLocation::Container, stream)
.unwrap()
.is_empty());
match stream.next().unwrap() {
TokenTree::Ident(i) => assert_eq!(i, "struct"),
x => panic!("Expected ident, found {:?}", x),
Expand Down

0 comments on commit cb16c7f

Please sign in to comment.