extern crate proc_macro; use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::*; mod container_attributes; mod field_attributes; use container_attributes::ContainerAttributes; use field_attributes::{determine_field_constructor, FieldConstructor}; static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary"; static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary"; #[proc_macro_derive(Arbitrary, attributes(arbitrary))] pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = syn::parse_macro_input!(tokens as syn::DeriveInput); expand_derive_arbitrary(input) .unwrap_or_else(syn::Error::into_compile_error) .into() } fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result { let container_attrs = ContainerAttributes::from_derive_input(&input)?; let (lifetime_without_bounds, lifetime_with_bounds) = build_arbitrary_lifetime(input.generics.clone()); let recursive_count = syn::Ident::new( &format!("RECURSIVE_COUNT_{}", input.ident), Span::call_site(), ); let arbitrary_method = gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?; let size_hint_method = gen_size_hint_method(&input)?; let name = input.ident; // Apply user-supplied bounds or automatic `T: ArbitraryBounds`. let generics = apply_trait_bounds( input.generics, lifetime_without_bounds.clone(), &container_attrs, )?; // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90) let mut generics_with_lifetime = generics.clone(); generics_with_lifetime .params .push(GenericParam::Lifetime(lifetime_with_bounds)); let (impl_generics, _, _) = generics_with_lifetime.split_for_impl(); // Build TypeGenerics and WhereClause without a lifetime let (_, ty_generics, where_clause) = generics.split_for_impl(); Ok(quote! { const _: () = { std::thread_local! { #[allow(non_upper_case_globals)] static #recursive_count: std::cell::Cell = std::cell::Cell::new(0); } #[automatically_derived] impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause { #arbitrary_method #size_hint_method } }; }) } // Returns: (lifetime without bounds, lifetime with bounds) // Example: ("'arbitrary", "'arbitrary: 'a + 'b") fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) { let lifetime_without_bounds = LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site())); let mut lifetime_with_bounds = lifetime_without_bounds.clone(); for param in generics.params.iter() { if let GenericParam::Lifetime(lifetime_def) = param { lifetime_with_bounds .bounds .push(lifetime_def.lifetime.clone()); } } (lifetime_without_bounds, lifetime_with_bounds) } fn apply_trait_bounds( mut generics: Generics, lifetime: LifetimeParam, container_attrs: &ContainerAttributes, ) -> Result { // If user-supplied bounds exist, apply them to their matching type parameters. if let Some(config_bounds) = &container_attrs.bounds { let mut config_bounds_applied = 0; for param in generics.params.iter_mut() { if let GenericParam::Type(type_param) = param { if let Some(replacement) = config_bounds .iter() .flatten() .find(|p| p.ident == type_param.ident) { *type_param = replacement.clone(); config_bounds_applied += 1; } else { // If no user-supplied bounds exist for this type, delete the original bounds. // This mimics serde. type_param.bounds = Default::default(); type_param.default = None; } } } let config_bounds_supplied = config_bounds .iter() .map(|bounds| bounds.len()) .sum::(); if config_bounds_applied != config_bounds_supplied { return Err(Error::new( Span::call_site(), format!( "invalid `{}` attribute. too many bounds, only {} out of {} are applicable", ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied, ), )); } Ok(generics) } else { // Otherwise, inject a `T: Arbitrary` bound for every parameter. Ok(add_trait_bounds(generics, lifetime)) } } // Add a bound `T: Arbitrary` to every type parameter T. fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics { for param in generics.params.iter_mut() { if let GenericParam::Type(type_param) = param { type_param .bounds .push(parse_quote!(arbitrary::Arbitrary<#lifetime>)); } } generics } fn with_recursive_count_guard( recursive_count: &syn::Ident, expr: impl quote::ToTokens, ) -> impl quote::ToTokens { quote! { let guard_against_recursion = u.is_empty(); if guard_against_recursion { #recursive_count.with(|count| { if count.get() > 0 { return Err(arbitrary::Error::NotEnoughData); } count.set(count.get() + 1); Ok(()) })?; } let result = (|| { #expr })(); if guard_against_recursion { #recursive_count.with(|count| { count.set(count.get() - 1); }); } result } } fn gen_arbitrary_method( input: &DeriveInput, lifetime: LifetimeParam, recursive_count: &syn::Ident, ) -> Result { fn arbitrary_structlike( fields: &Fields, ident: &syn::Ident, lifetime: LifetimeParam, recursive_count: &syn::Ident, ) -> Result { let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?; let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) }); let arbitrary_take_rest = construct_take_rest(fields)?; let take_rest_body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) }); Ok(quote! { fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { #body } fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { #take_rest_body } }) } let ident = &input.ident; let output = match &input.data { Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?, Data::Union(data) => arbitrary_structlike( &Fields::Named(data.fields.clone()), ident, lifetime, recursive_count, )?, Data::Enum(data) => { let variants: Vec = data .variants .iter() .enumerate() .map(|(i, variant)| { let idx = i as u64; let variant_name = &variant.ident; construct(&variant.fields, |_, field| gen_constructor_for_field(field)) .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) }) .collect::>()?; let variants_take_rest: Vec = data .variants .iter() .enumerate() .map(|(i, variant)| { let idx = i as u64; let variant_name = &variant.ident; construct_take_rest(&variant.fields) .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) }) .collect::>()?; let count = data.variants.len() as u64; let arbitrary = with_recursive_count_guard( recursive_count, quote! { // Use a multiply + shift to generate a ranged random number // with slight bias. For details, see: // https://lemire.me/blog/2016/06/30/fast-random-shuffling Ok(match (u64::from(::arbitrary(u)?) * #count) >> 32 { #(#variants,)* _ => unreachable!() }) }, ); let arbitrary_take_rest = with_recursive_count_guard( recursive_count, quote! { // Use a multiply + shift to generate a ranged random number // with slight bias. For details, see: // https://lemire.me/blog/2016/06/30/fast-random-shuffling Ok(match (u64::from(::arbitrary(&mut u)?) * #count) >> 32 { #(#variants_take_rest,)* _ => unreachable!() }) }, ); quote! { fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { #arbitrary } fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { #arbitrary_take_rest } } } }; Ok(output) } fn construct( fields: &Fields, ctor: impl Fn(usize, &Field) -> Result, ) -> Result { let output = match fields { Fields::Named(names) => { let names: Vec = names .named .iter() .enumerate() .map(|(i, f)| { let name = f.ident.as_ref().unwrap(); ctor(i, f).map(|ctor| quote! { #name: #ctor }) }) .collect::>()?; quote! { { #(#names,)* } } } Fields::Unnamed(names) => { let names: Vec = names .unnamed .iter() .enumerate() .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor })) .collect::>()?; quote! { ( #(#names),* ) } } Fields::Unit => quote!(), }; Ok(output) } fn construct_take_rest(fields: &Fields) -> Result { construct(fields, |idx, field| { determine_field_constructor(field).map(|field_constructor| match field_constructor { FieldConstructor::Default => quote!(Default::default()), FieldConstructor::Arbitrary => { if idx + 1 == fields.len() { quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? } } else { quote! { arbitrary::Arbitrary::arbitrary(&mut u)? } } } FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?), FieldConstructor::Value(value) => quote!(#value), }) }) } fn gen_size_hint_method(input: &DeriveInput) -> Result { let size_hint_fields = |fields: &Fields| { fields .iter() .map(|f| { let ty = &f.ty; determine_field_constructor(f).map(|field_constructor| { match field_constructor { FieldConstructor::Default | FieldConstructor::Value(_) => { quote!((0, Some(0))) } FieldConstructor::Arbitrary => { quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) } } // Note that in this case it's hard to determine what size_hint must be, so size_of::() is // just an educated guess, although it's gonna be inaccurate for dynamically // allocated types (Vec, HashMap, etc.). FieldConstructor::With(_) => { quote! { (::core::mem::size_of::<#ty>(), None) } } } }) }) .collect::>>() .map(|hints| { quote! { arbitrary::size_hint::and_all(&[ #( #hints ),* ]) } }) }; let size_hint_structlike = |fields: &Fields| { size_hint_fields(fields).map(|hint| { quote! { #[inline] fn size_hint(depth: usize) -> (usize, Option) { arbitrary::size_hint::recursion_guard(depth, |depth| #hint) } } }) }; match &input.data { Data::Struct(data) => size_hint_structlike(&data.fields), Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())), Data::Enum(data) => data .variants .iter() .map(|v| size_hint_fields(&v.fields)) .collect::>>() .map(|variants| { quote! { #[inline] fn size_hint(depth: usize) -> (usize, Option) { arbitrary::size_hint::and( ::size_hint(depth), arbitrary::size_hint::recursion_guard(depth, |depth| { arbitrary::size_hint::or_all(&[ #( #variants ),* ]) }), ) } } }), } } fn gen_constructor_for_field(field: &Field) -> Result { let ctor = match determine_field_constructor(field)? { FieldConstructor::Default => quote!(Default::default()), FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?), FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?), FieldConstructor::Value(value) => quote!(#value), }; Ok(ctor) }