Skip to content

Commit

Permalink
Merge pull request #1213 from Osspial/assoc_type_derive
Browse files Browse the repository at this point in the history
Support #[derive(Serialize, Deserialize)] when using associated types
  • Loading branch information
dtolnay committed Apr 13, 2018
2 parents 5efb22e + 629bf7b commit a6e94e7
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 9 deletions.
27 changes: 24 additions & 3 deletions serde_derive/src/bound.rs
Expand Up @@ -50,18 +50,39 @@ pub fn with_where_predicates(
generics
}

pub fn with_where_predicates_from_fields<F>(
pub fn with_where_predicates_from_fields<F, W>(
cont: &Container,
generics: &syn::Generics,
trait_bound: &syn::Path,
from_field: F,
gen_bound_where: W,
) -> syn::Generics
where
F: Fn(&attr::Field) -> Option<&[syn::WherePredicate]>,
W: Fn(&attr::Field) -> bool,
{
let predicates = cont.data
.all_fields()
.flat_map(|field| from_field(&field.attrs))
.flat_map(|predicates| predicates.to_vec());
.flat_map(|field| {
let field_ty = field.ty;
let matching_generic = |t: &syn::PathSegment, g: &syn::GenericParam| match *g {
syn::GenericParam::Type(ref generic_ty)
if generic_ty.ident == t.ident => true,
_ => false
};

let mut field_bound: Option<syn::WherePredicate> = None;
if let syn::Type::Path(ref ty_path) = *field_ty {
field_bound = match (gen_bound_where(&field.attrs), ty_path.path.segments.first()) {
(true, Some(syn::punctuated::Pair::Punctuated(ref t, _))) =>
if generics.params.iter().any(|g| matching_generic(t, g)) {
Some(parse_quote!(#field_ty: #trait_bound))
} else {None},
(_, _) => None
};
}
field_bound.into_iter().chain(from_field(&field.attrs).into_iter().flat_map(|predicates| predicates.to_vec()))
});

let mut generics = generics.clone();
generics.make_where_clause()
Expand Down
13 changes: 10 additions & 3 deletions serde_derive/src/de.rs
Expand Up @@ -124,7 +124,15 @@ impl Parameters {
fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generics {
let generics = bound::without_defaults(cont.generics);

let generics = bound::with_where_predicates_from_fields(cont, &generics, attr::Field::de_bound);
let delife = borrowed.de_lifetime();
let de_bound = parse_quote!(_serde::Deserialize<#delife>);
let generics = bound::with_where_predicates_from_fields(
cont,
&generics,
&de_bound,
attr::Field::de_bound,
|field| field.deserialize_with().is_none() && !field.skip_deserializing()
);

match cont.attrs.de_bound() {
Some(predicates) => bound::with_where_predicates(&generics, predicates),
Expand All @@ -136,12 +144,11 @@ fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generi
attr::Default::None | attr::Default::Path(_) => generics,
};

let delife = borrowed.de_lifetime();
let generics = bound::with_bound(
cont,
&generics,
needs_deserialize_bound,
&parse_quote!(_serde::Deserialize<#delife>),
&de_bound,
);

bound::with_bound(
Expand Down
12 changes: 9 additions & 3 deletions serde_derive/src/ser.rs
Expand Up @@ -130,16 +130,22 @@ impl Parameters {
fn build_generics(cont: &Container) -> syn::Generics {
let generics = bound::without_defaults(cont.generics);

let generics =
bound::with_where_predicates_from_fields(cont, &generics, attr::Field::ser_bound);
let trait_bound = parse_quote!(_serde::Serialize);
let generics = bound::with_where_predicates_from_fields(
cont,
&generics,
&trait_bound,
attr::Field::ser_bound,
|field| field.serialize_with().is_none() && !field.skip_serializing()
);

match cont.attrs.ser_bound() {
Some(predicates) => bound::with_where_predicates(&generics, predicates),
None => bound::with_bound(
cont,
&generics,
needs_serialize_bound,
&parse_quote!(_serde::Serialize),
&trait_bound
),
}
}
Expand Down
24 changes: 24 additions & 0 deletions test_suite/tests/test_gen.rs
Expand Up @@ -539,6 +539,30 @@ fn test_gen() {
array: [u8; 256],
}
assert_ser::<BigArray>();

trait AssocSerde {
type Assoc;
}

struct NoSerdeImpl;
impl AssocSerde for NoSerdeImpl {
type Assoc = u32;
}

#[derive(Serialize, Deserialize)]
struct AssocDerive<T: AssocSerde> {
assoc: T::Assoc
}

assert::<AssocDerive<NoSerdeImpl>>();

#[derive(Serialize, Deserialize)]
struct AssocDeriveMulti<S, T: AssocSerde> {
s: S,
assoc: T::Assoc,
}

assert::<AssocDeriveMulti<i32, NoSerdeImpl>>();
}

//////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit a6e94e7

Please sign in to comment.