summaryrefslogtreecommitdiffstats
path: root/third_party/rust/derive_arbitrary/src
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/derive_arbitrary/src
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/derive_arbitrary/src')
-rw-r--r--third_party/rust/derive_arbitrary/src/field_attributes.rs117
-rw-r--r--third_party/rust/derive_arbitrary/src/lib.rs348
2 files changed, 465 insertions, 0 deletions
diff --git a/third_party/rust/derive_arbitrary/src/field_attributes.rs b/third_party/rust/derive_arbitrary/src/field_attributes.rs
new file mode 100644
index 0000000000..ccaba743b9
--- /dev/null
+++ b/third_party/rust/derive_arbitrary/src/field_attributes.rs
@@ -0,0 +1,117 @@
+use proc_macro2::{Group, Span, TokenStream, TokenTree};
+use quote::quote;
+use syn::{spanned::Spanned, *};
+
+/// Used to filter out necessary field attribute and within error messages.
+static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
+
+/// Determines how a value for a field should be constructed.
+#[cfg_attr(test, derive(Debug))]
+pub enum FieldConstructor {
+ /// Assume that Arbitrary is defined for the type of this field and use it (default)
+ Arbitrary,
+
+ /// Places `Default::default()` as a field value.
+ Default,
+
+ /// Use custom function or closure to generate a value for a field.
+ With(TokenStream),
+
+ /// Set a field always to the given value.
+ Value(TokenStream),
+}
+
+pub fn determine_field_constructor(field: &Field) -> Result<FieldConstructor> {
+ let opt_attr = fetch_attr_from_field(field)?;
+ let ctor = match opt_attr {
+ Some(attr) => parse_attribute(attr)?,
+ None => FieldConstructor::Arbitrary,
+ };
+ Ok(ctor)
+}
+
+fn fetch_attr_from_field(field: &Field) -> Result<Option<&Attribute>> {
+ let found_attributes: Vec<_> = field
+ .attrs
+ .iter()
+ .filter(|a| {
+ let path = &a.path;
+ let name = quote!(#path).to_string();
+ name == ARBITRARY_ATTRIBUTE_NAME
+ })
+ .collect();
+ if found_attributes.len() > 1 {
+ let name = field.ident.as_ref().unwrap();
+ let msg = format!(
+ "Multiple conflicting #[{ARBITRARY_ATTRIBUTE_NAME}] attributes found on field `{name}`"
+ );
+ return Err(syn::Error::new(field.span(), msg));
+ }
+ Ok(found_attributes.into_iter().next())
+}
+
+fn parse_attribute(attr: &Attribute) -> Result<FieldConstructor> {
+ let group = {
+ let mut tokens_iter = attr.clone().tokens.into_iter();
+ let token = tokens_iter.next().ok_or_else(|| {
+ let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
+ syn::Error::new(attr.span(), msg)
+ })?;
+ match token {
+ TokenTree::Group(g) => g,
+ t => {
+ let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] must contain a group, got: {t})");
+ return Err(syn::Error::new(attr.span(), msg));
+ }
+ }
+ };
+ parse_attribute_internals(group)
+}
+
+fn parse_attribute_internals(group: Group) -> Result<FieldConstructor> {
+ let stream = group.stream();
+ let mut tokens_iter = stream.into_iter();
+ let token = tokens_iter.next().ok_or_else(|| {
+ let msg = format!("#[{ARBITRARY_ATTRIBUTE_NAME}] cannot be empty.");
+ syn::Error::new(group.span(), msg)
+ })?;
+ match token.to_string().as_ref() {
+ "default" => Ok(FieldConstructor::Default),
+ "with" => {
+ let func_path = parse_assigned_value("with", tokens_iter, group.span())?;
+ Ok(FieldConstructor::With(func_path))
+ }
+ "value" => {
+ let value = parse_assigned_value("value", tokens_iter, group.span())?;
+ Ok(FieldConstructor::Value(value))
+ }
+ _ => {
+ let msg = format!("Unknown option for #[{ARBITRARY_ATTRIBUTE_NAME}]: `{token}`");
+ Err(syn::Error::new(token.span(), msg))
+ }
+ }
+}
+
+// Input:
+// = 2 + 2
+// Output:
+// 2 + 2
+fn parse_assigned_value(
+ opt_name: &str,
+ mut tokens_iter: impl Iterator<Item = TokenTree>,
+ default_span: Span,
+) -> Result<TokenStream> {
+ let eq_sign = tokens_iter.next().ok_or_else(|| {
+ let msg = format!(
+ "Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], `{opt_name}` is missing assignment."
+ );
+ syn::Error::new(default_span, msg)
+ })?;
+
+ if eq_sign.to_string() == "=" {
+ Ok(tokens_iter.collect())
+ } else {
+ let msg = format!("Invalid syntax for #[{ARBITRARY_ATTRIBUTE_NAME}], expected `=` after `{opt_name}`, got: `{eq_sign}`");
+ Err(syn::Error::new(eq_sign.span(), msg))
+ }
+}
diff --git a/third_party/rust/derive_arbitrary/src/lib.rs b/third_party/rust/derive_arbitrary/src/lib.rs
new file mode 100644
index 0000000000..4ed3817262
--- /dev/null
+++ b/third_party/rust/derive_arbitrary/src/lib.rs
@@ -0,0 +1,348 @@
+extern crate proc_macro;
+
+use proc_macro2::{Span, TokenStream};
+use quote::quote;
+use syn::*;
+
+mod field_attributes;
+use field_attributes::{determine_field_constructor, FieldConstructor};
+
+static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
+
+#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
+pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
+ let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
+ expand_derive_arbitrary(input)
+ .unwrap_or_else(syn::Error::into_compile_error)
+ .into()
+}
+
+fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
+ let (lifetime_without_bounds, lifetime_with_bounds) =
+ build_arbitrary_lifetime(input.generics.clone());
+
+ let recursive_count = syn::Ident::new(
+ &format!("RECURSIVE_COUNT_{}", input.ident),
+ Span::call_site(),
+ );
+
+ let arbitrary_method =
+ gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
+ let size_hint_method = gen_size_hint_method(&input)?;
+ let name = input.ident;
+ // Add a bound `T: Arbitrary` to every type parameter T.
+ let generics = add_trait_bounds(input.generics, lifetime_without_bounds.clone());
+
+ // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
+ let mut generics_with_lifetime = generics.clone();
+ generics_with_lifetime
+ .params
+ .push(GenericParam::Lifetime(lifetime_with_bounds));
+ let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
+
+ // Build TypeGenerics and WhereClause without a lifetime
+ let (_, ty_generics, where_clause) = generics.split_for_impl();
+
+ Ok(quote! {
+ const _: () = {
+ std::thread_local! {
+ #[allow(non_upper_case_globals)]
+ static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
+ }
+
+ #[automatically_derived]
+ impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
+ #arbitrary_method
+ #size_hint_method
+ }
+ };
+ })
+}
+
+// Returns: (lifetime without bounds, lifetime with bounds)
+// Example: ("'arbitrary", "'arbitrary: 'a + 'b")
+fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeDef, LifetimeDef) {
+ let lifetime_without_bounds =
+ LifetimeDef::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
+ let mut lifetime_with_bounds = lifetime_without_bounds.clone();
+
+ for param in generics.params.iter() {
+ if let GenericParam::Lifetime(lifetime_def) = param {
+ lifetime_with_bounds
+ .bounds
+ .push(lifetime_def.lifetime.clone());
+ }
+ }
+
+ (lifetime_without_bounds, lifetime_with_bounds)
+}
+
+// Add a bound `T: Arbitrary` to every type parameter T.
+fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeDef) -> Generics {
+ for param in generics.params.iter_mut() {
+ if let GenericParam::Type(type_param) = param {
+ type_param
+ .bounds
+ .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
+ }
+ }
+ generics
+}
+
+fn with_recursive_count_guard(
+ recursive_count: &syn::Ident,
+ expr: impl quote::ToTokens,
+) -> impl quote::ToTokens {
+ quote! {
+ let guard_against_recursion = u.is_empty();
+ if guard_against_recursion {
+ #recursive_count.with(|count| {
+ if count.get() > 0 {
+ return Err(arbitrary::Error::NotEnoughData);
+ }
+ count.set(count.get() + 1);
+ Ok(())
+ })?;
+ }
+
+ let result = (|| { #expr })();
+
+ if guard_against_recursion {
+ #recursive_count.with(|count| {
+ count.set(count.get() - 1);
+ });
+ }
+
+ result
+ }
+}
+
+fn gen_arbitrary_method(
+ input: &DeriveInput,
+ lifetime: LifetimeDef,
+ recursive_count: &syn::Ident,
+) -> Result<TokenStream> {
+ fn arbitrary_structlike(
+ fields: &Fields,
+ ident: &syn::Ident,
+ lifetime: LifetimeDef,
+ recursive_count: &syn::Ident,
+ ) -> Result<TokenStream> {
+ let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
+ let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
+
+ let arbitrary_take_rest = construct_take_rest(fields)?;
+ let take_rest_body =
+ with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
+
+ Ok(quote! {
+ fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+ #body
+ }
+
+ fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+ #take_rest_body
+ }
+ })
+ }
+
+ let ident = &input.ident;
+ let output = match &input.data {
+ Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
+ Data::Union(data) => arbitrary_structlike(
+ &Fields::Named(data.fields.clone()),
+ ident,
+ lifetime,
+ recursive_count,
+ )?,
+ Data::Enum(data) => {
+ let variants: Vec<TokenStream> = data
+ .variants
+ .iter()
+ .enumerate()
+ .map(|(i, variant)| {
+ let idx = i as u64;
+ let variant_name = &variant.ident;
+ construct(&variant.fields, |_, field| gen_constructor_for_field(field))
+ .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
+ })
+ .collect::<Result<_>>()?;
+
+ let variants_take_rest: Vec<TokenStream> = data
+ .variants
+ .iter()
+ .enumerate()
+ .map(|(i, variant)| {
+ let idx = i as u64;
+ let variant_name = &variant.ident;
+ construct_take_rest(&variant.fields)
+ .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
+ })
+ .collect::<Result<_>>()?;
+
+ let count = data.variants.len() as u64;
+
+ let arbitrary = with_recursive_count_guard(
+ recursive_count,
+ quote! {
+ // Use a multiply + shift to generate a ranged random number
+ // with slight bias. For details, see:
+ // https://lemire.me/blog/2016/06/30/fast-random-shuffling
+ Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
+ #(#variants,)*
+ _ => unreachable!()
+ })
+ },
+ );
+
+ let arbitrary_take_rest = with_recursive_count_guard(
+ recursive_count,
+ quote! {
+ // Use a multiply + shift to generate a ranged random number
+ // with slight bias. For details, see:
+ // https://lemire.me/blog/2016/06/30/fast-random-shuffling
+ Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
+ #(#variants_take_rest,)*
+ _ => unreachable!()
+ })
+ },
+ );
+
+ quote! {
+ fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+ #arbitrary
+ }
+
+ fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
+ #arbitrary_take_rest
+ }
+ }
+ }
+ };
+ Ok(output)
+}
+
+fn construct(
+ fields: &Fields,
+ ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
+) -> Result<TokenStream> {
+ let output = match fields {
+ Fields::Named(names) => {
+ let names: Vec<TokenStream> = names
+ .named
+ .iter()
+ .enumerate()
+ .map(|(i, f)| {
+ let name = f.ident.as_ref().unwrap();
+ ctor(i, f).map(|ctor| quote! { #name: #ctor })
+ })
+ .collect::<Result<_>>()?;
+ quote! { { #(#names,)* } }
+ }
+ Fields::Unnamed(names) => {
+ let names: Vec<TokenStream> = names
+ .unnamed
+ .iter()
+ .enumerate()
+ .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
+ .collect::<Result<_>>()?;
+ quote! { ( #(#names),* ) }
+ }
+ Fields::Unit => quote!(),
+ };
+ Ok(output)
+}
+
+fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
+ construct(fields, |idx, field| {
+ determine_field_constructor(field).map(|field_constructor| match field_constructor {
+ FieldConstructor::Default => quote!(Default::default()),
+ FieldConstructor::Arbitrary => {
+ if idx + 1 == fields.len() {
+ quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
+ } else {
+ quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
+ }
+ }
+ FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
+ FieldConstructor::Value(value) => quote!(#value),
+ })
+ })
+}
+
+fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
+ let size_hint_fields = |fields: &Fields| {
+ fields
+ .iter()
+ .map(|f| {
+ let ty = &f.ty;
+ determine_field_constructor(f).map(|field_constructor| {
+ match field_constructor {
+ FieldConstructor::Default | FieldConstructor::Value(_) => {
+ quote!((0, Some(0)))
+ }
+ FieldConstructor::Arbitrary => {
+ quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
+ }
+
+ // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
+ // just an educated guess, although it's gonna be inaccurate for dynamically
+ // allocated types (Vec, HashMap, etc.).
+ FieldConstructor::With(_) => {
+ quote! { (::core::mem::size_of::<#ty>(), None) }
+ }
+ }
+ })
+ })
+ .collect::<Result<Vec<TokenStream>>>()
+ .map(|hints| {
+ quote! {
+ arbitrary::size_hint::and_all(&[
+ #( #hints ),*
+ ])
+ }
+ })
+ };
+ let size_hint_structlike = |fields: &Fields| {
+ size_hint_fields(fields).map(|hint| {
+ quote! {
+ #[inline]
+ fn size_hint(depth: usize) -> (usize, Option<usize>) {
+ arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
+ }
+ }
+ })
+ };
+ match &input.data {
+ Data::Struct(data) => size_hint_structlike(&data.fields),
+ Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
+ Data::Enum(data) => data
+ .variants
+ .iter()
+ .map(|v| size_hint_fields(&v.fields))
+ .collect::<Result<Vec<TokenStream>>>()
+ .map(|variants| {
+ quote! {
+ #[inline]
+ fn size_hint(depth: usize) -> (usize, Option<usize>) {
+ arbitrary::size_hint::and(
+ <u32 as arbitrary::Arbitrary>::size_hint(depth),
+ arbitrary::size_hint::recursion_guard(depth, |depth| {
+ arbitrary::size_hint::or_all(&[ #( #variants ),* ])
+ }),
+ )
+ }
+ }
+ }),
+ }
+}
+
+fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
+ let ctor = match determine_field_constructor(field)? {
+ FieldConstructor::Default => quote!(Default::default()),
+ FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
+ FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
+ FieldConstructor::Value(value) => quote!(#value),
+ };
+ Ok(ctor)
+}