diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
commit | 43a97878ce14b72f0981164f87f2e35e14151312 (patch) | |
tree | 620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/rust/derive_arbitrary/src/lib.rs | |
parent | Initial commit. (diff) | |
download | firefox-43a97878ce14b72f0981164f87f2e35e14151312.tar.xz firefox-43a97878ce14b72f0981164f87f2e35e14151312.zip |
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/derive_arbitrary/src/lib.rs')
-rw-r--r-- | third_party/rust/derive_arbitrary/src/lib.rs | 347 |
1 files changed, 347 insertions, 0 deletions
diff --git a/third_party/rust/derive_arbitrary/src/lib.rs b/third_party/rust/derive_arbitrary/src/lib.rs new file mode 100644 index 0000000000..b7292ab7bb --- /dev/null +++ b/third_party/rust/derive_arbitrary/src/lib.rs @@ -0,0 +1,347 @@ +extern crate proc_macro; + +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::*; + +mod field_attributes; +use field_attributes::{determine_field_constructor, FieldConstructor}; + +static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary"; + +#[proc_macro_derive(Arbitrary, attributes(arbitrary))] +pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = syn::parse_macro_input!(tokens as syn::DeriveInput); + expand_derive_arbitrary(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() +} + +fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> { + let (lifetime_without_bounds, lifetime_with_bounds) = + build_arbitrary_lifetime(input.generics.clone()); + + let recursive_count = syn::Ident::new( + &format!("RECURSIVE_COUNT_{}", input.ident), + Span::call_site(), + ); + + let arbitrary_method = + gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?; + let size_hint_method = gen_size_hint_method(&input)?; + let name = input.ident; + // Add a bound `T: Arbitrary` to every type parameter T. + let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone()); + + // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90) + let mut generics_with_lifetime = generics.clone(); + generics_with_lifetime + .params + .push(GenericParam::Lifetime(lifetime_with_bounds)); + let (impl_generics, _, _) = generics_with_lifetime.split_for_impl(); + + // Build TypeGenerics and WhereClause without a lifetime + let (_, ty_generics, where_clause) = generics.split_for_impl(); + + Ok(quote! { + const _: () = { + std::thread_local! { + #[allow(non_upper_case_globals)] + static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0); + } + + impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause { + #arbitrary_method + #size_hint_method + } + }; + }) +} + +// Returns: (lifetime without bounds, lifetime with bounds) +// Example: ("'arbitrary", "'arbitrary: 'a + 'b") +fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) { + let lifetime_without_bounds = + LifetimeDef::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site())); + let mut lifetime_with_bounds = lifetime_without_bounds.clone(); + + for param in generics.params.iter() { + if let GenericParam::Lifetime(lifetime_def) = param { + lifetime_with_bounds + .bounds + .push(lifetime_def.lifetime.clone()); + } + } + + (lifetime_without_bounds, lifetime_with_bounds) +} + +// Add a bound `T: Arbitrary` to every type parameter T. +fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics { + for param in generics.params.iter_mut() { + if let GenericParam::Type(type_param) = param { + type_param + .bounds + .push(parse_quote!(arbitrary::Arbitrary<#lifetime>)); + } + } + generics +} + +fn with_recursive_count_guard( + recursive_count: &syn::Ident, + expr: impl quote::ToTokens, +) -> impl quote::ToTokens { + quote! { + let guard_against_recursion = u.is_empty(); + if guard_against_recursion { + #recursive_count.with(|count| { + if count.get() > 0 { + return Err(arbitrary::Error::NotEnoughData); + } + count.set(count.get() + 1); + Ok(()) + })?; + } + + let result = (|| { #expr })(); + + if guard_against_recursion { + #recursive_count.with(|count| { + count.set(count.get() - 1); + }); + } + + result + } +} + +fn gen_arbitrary_method( + input: &DeriveInput, + lifetime: LifetimeDef, + recursive_count: &syn::Ident, +) -> Result<TokenStream> { + fn arbitrary_structlike( + fields: &Fields, + ident: &syn::Ident, + lifetime: LifetimeDef, + recursive_count: &syn::Ident, + ) -> Result<TokenStream> { + let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?; + let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) }); + + let arbitrary_take_rest = construct_take_rest(fields)?; + let take_rest_body = + with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) }); + + Ok(quote! { + fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + #body + } + + fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + #take_rest_body + } + }) + } + + let ident = &input.ident; + let output = match &input.data { + Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?, + Data::Union(data) => arbitrary_structlike( + &Fields::Named(data.fields.clone()), + ident, + lifetime, + recursive_count, + )?, + Data::Enum(data) => { + let variants: Vec<TokenStream> = data + .variants + .iter() + .enumerate() + .map(|(i, variant)| { + let idx = i as u64; + let variant_name = &variant.ident; + construct(&variant.fields, |_, field| gen_constructor_for_field(field)) + .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) + }) + .collect::<Result<_>>()?; + + let variants_take_rest: Vec<TokenStream> = data + .variants + .iter() + .enumerate() + .map(|(i, variant)| { + let idx = i as u64; + let variant_name = &variant.ident; + construct_take_rest(&variant.fields) + .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) + }) + .collect::<Result<_>>()?; + + let count = data.variants.len() as u64; + + let arbitrary = with_recursive_count_guard( + recursive_count, + quote! { + // Use a multiply + shift to generate a ranged random number + // with slight bias. For details, see: + // https://lemire.me/blog/2016/06/30/fast-random-shuffling + Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 { + #(#variants,)* + _ => unreachable!() + }) + }, + ); + + let arbitrary_take_rest = with_recursive_count_guard( + recursive_count, + quote! { + // Use a multiply + shift to generate a ranged random number + // with slight bias. For details, see: + // https://lemire.me/blog/2016/06/30/fast-random-shuffling + Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 { + #(#variants_take_rest,)* + _ => unreachable!() + }) + }, + ); + + quote! { + fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + #arbitrary + } + + fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> { + #arbitrary_take_rest + } + } + } + }; + Ok(output) +} + +fn construct( + fields: &Fields, + ctor: impl Fn(usize, &Field) -> Result<TokenStream>, +) -> Result<TokenStream> { + let output = match fields { + Fields::Named(names) => { + let names: Vec<TokenStream> = names + .named + .iter() + .enumerate() + .map(|(i, f)| { + let name = f.ident.as_ref().unwrap(); + ctor(i, f).map(|ctor| quote! { #name: #ctor }) + }) + .collect::<Result<_>>()?; + quote! { { #(#names,)* } } + } + Fields::Unnamed(names) => { + let names: Vec<TokenStream> = names + .unnamed + .iter() + .enumerate() + .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor })) + .collect::<Result<_>>()?; + quote! { ( #(#names),* ) } + } + Fields::Unit => quote!(), + }; + Ok(output) +} + +fn construct_take_rest(fields: &Fields) -> Result<TokenStream> { + construct(fields, |idx, field| { + determine_field_constructor(field).map(|field_constructor| match field_constructor { + FieldConstructor::Default => quote!(Default::default()), + FieldConstructor::Arbitrary => { + if idx + 1 == fields.len() { + quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? } + } else { + quote! { arbitrary::Arbitrary::arbitrary(&mut u)? } + } + } + FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?), + FieldConstructor::Value(value) => quote!(#value), + }) + }) +} + +fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> { + let size_hint_fields = |fields: &Fields| { + fields + .iter() + .map(|f| { + let ty = &f.ty; + determine_field_constructor(f).map(|field_constructor| { + match field_constructor { + FieldConstructor::Default | FieldConstructor::Value(_) => { + quote!((0, Some(0))) + } + FieldConstructor::Arbitrary => { + quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) } + } + + // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is + // just an educated guess, although it's gonna be inaccurate for dynamically + // allocated types (Vec, HashMap, etc.). + FieldConstructor::With(_) => { + quote! { (::core::mem::size_of::<#ty>(), None) } + } + } + }) + }) + .collect::<Result<Vec<TokenStream>>>() + .map(|hints| { + quote! { + arbitrary::size_hint::and_all(&[ + #( #hints ),* + ]) + } + }) + }; + let size_hint_structlike = |fields: &Fields| { + size_hint_fields(fields).map(|hint| { + quote! { + #[inline] + fn size_hint(depth: usize) -> (usize, Option<usize>) { + arbitrary::size_hint::recursion_guard(depth, |depth| #hint) + } + } + }) + }; + match &input.data { + Data::Struct(data) => size_hint_structlike(&data.fields), + Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())), + Data::Enum(data) => data + .variants + .iter() + .map(|v| size_hint_fields(&v.fields)) + .collect::<Result<Vec<TokenStream>>>() + .map(|variants| { + quote! { + #[inline] + fn size_hint(depth: usize) -> (usize, Option<usize>) { + arbitrary::size_hint::and( + <u32 as arbitrary::Arbitrary>::size_hint(depth), + arbitrary::size_hint::recursion_guard(depth, |depth| { + arbitrary::size_hint::or_all(&[ #( #variants ),* ]) + }), + ) + } + } + }), + } +} + +fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> { + let ctor = match determine_field_constructor(field)? { + FieldConstructor::Default => quote!(Default::default()), + FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?), + FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?), + FieldConstructor::Value(value) => quote!(#value), + }; + Ok(ctor) +} |