diff options
Diffstat (limited to 'third_party/rust/async-trait/src/expand.rs')
-rw-r--r-- | third_party/rust/async-trait/src/expand.rs | 502 |
1 files changed, 502 insertions, 0 deletions
diff --git a/third_party/rust/async-trait/src/expand.rs b/third_party/rust/async-trait/src/expand.rs new file mode 100644 index 0000000000..88338db9bc --- /dev/null +++ b/third_party/rust/async-trait/src/expand.rs @@ -0,0 +1,502 @@ +use crate::bound::{has_bound, InferredBound, Supertraits}; +use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes}; +use crate::parse::Item; +use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf}; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, quote_spanned, ToTokens}; +use std::collections::BTreeSet as Set; +use std::mem; +use syn::punctuated::Punctuated; +use syn::visit_mut::{self, VisitMut}; +use syn::{ + parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam, + Generics, Ident, ImplItem, Lifetime, LifetimeDef, Pat, PatIdent, PathArguments, Receiver, + ReturnType, Signature, Stmt, Token, TraitItem, Type, TypePath, WhereClause, +}; + +impl ToTokens for Item { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Item::Trait(item) => item.to_tokens(tokens), + Item::Impl(item) => item.to_tokens(tokens), + } + } +} + +#[derive(Clone, Copy)] +enum Context<'a> { + Trait { + generics: &'a Generics, + supertraits: &'a Supertraits, + }, + Impl { + impl_generics: &'a Generics, + associated_type_impl_traits: &'a Set<Ident>, + }, +} + +impl Context<'_> { + fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeDef> { + let generics = match self { + Context::Trait { generics, .. } => generics, + Context::Impl { impl_generics, .. } => impl_generics, + }; + generics.params.iter().filter_map(move |param| { + if let GenericParam::Lifetime(param) = param { + if used.contains(¶m.lifetime) { + return Some(param); + } + } + None + }) + } +} + +pub fn expand(input: &mut Item, is_local: bool) { + match input { + Item::Trait(input) => { + let context = Context::Trait { + generics: &input.generics, + supertraits: &input.supertraits, + }; + for inner in &mut input.items { + if let TraitItem::Method(method) = inner { + let sig = &mut method.sig; + if sig.asyncness.is_some() { + let block = &mut method.default; + let mut has_self = has_self_in_sig(sig); + method.attrs.push(parse_quote!(#[must_use])); + if let Some(block) = block { + has_self |= has_self_in_block(block); + transform_block(context, sig, block); + method.attrs.push(lint_suppress_with_body()); + } else { + method.attrs.push(lint_suppress_without_body()); + } + let has_default = method.default.is_some(); + transform_sig(context, sig, has_self, has_default, is_local); + } + } + } + } + Item::Impl(input) => { + let mut lifetimes = CollectLifetimes::new("'impl"); + lifetimes.visit_type_mut(&mut *input.self_ty); + lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1); + let params = &input.generics.params; + let elided = lifetimes.elided; + input.generics.params = parse_quote!(#(#elided,)* #params); + + let mut associated_type_impl_traits = Set::new(); + for inner in &input.items { + if let ImplItem::Type(assoc) = inner { + if let Type::ImplTrait(_) = assoc.ty { + associated_type_impl_traits.insert(assoc.ident.clone()); + } + } + } + + let context = Context::Impl { + impl_generics: &input.generics, + associated_type_impl_traits: &associated_type_impl_traits, + }; + for inner in &mut input.items { + if let ImplItem::Method(method) = inner { + let sig = &mut method.sig; + if sig.asyncness.is_some() { + let block = &mut method.block; + let has_self = has_self_in_sig(sig) || has_self_in_block(block); + transform_block(context, sig, block); + transform_sig(context, sig, has_self, false, is_local); + method.attrs.push(lint_suppress_with_body()); + } + } + } + } + } +} + +fn lint_suppress_with_body() -> Attribute { + parse_quote! { + #[allow( + clippy::async_yields_async, + clippy::let_unit_value, + clippy::no_effect_underscore_binding, + clippy::shadow_same, + clippy::type_complexity, + clippy::type_repetition_in_bounds, + clippy::used_underscore_binding + )] + } +} + +fn lint_suppress_without_body() -> Attribute { + parse_quote! { + #[allow( + clippy::type_complexity, + clippy::type_repetition_in_bounds + )] + } +} + +// Input: +// async fn f<T>(&self, x: &T) -> Ret; +// +// Output: +// fn f<'life0, 'life1, 'async_trait, T>( +// &'life0 self, +// x: &'life1 T, +// ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>> +// where +// 'life0: 'async_trait, +// 'life1: 'async_trait, +// T: 'async_trait, +// Self: Sync + 'async_trait; +fn transform_sig( + context: Context, + sig: &mut Signature, + has_self: bool, + has_default: bool, + is_local: bool, +) { + let default_span = sig.asyncness.take().unwrap().span; + sig.fn_token.span = default_span; + + let (ret_arrow, ret) = match &sig.output { + ReturnType::Default => (Token![->](default_span), quote_spanned!(default_span=> ())), + ReturnType::Type(arrow, ret) => (*arrow, quote!(#ret)), + }; + + let mut lifetimes = CollectLifetimes::new("'life"); + for arg in sig.inputs.iter_mut() { + match arg { + FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg), + FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty), + } + } + + for param in &mut sig.generics.params { + match param { + GenericParam::Type(param) => { + let param_name = ¶m.ident; + let span = match param.colon_token.take() { + Some(colon_token) => colon_token.span, + None => param_name.span(), + }; + let bounds = mem::replace(&mut param.bounds, Punctuated::new()); + where_clause_or_default(&mut sig.generics.where_clause) + .predicates + .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds)); + } + GenericParam::Lifetime(param) => { + let param_name = ¶m.lifetime; + let span = match param.colon_token.take() { + Some(colon_token) => colon_token.span, + None => param_name.span(), + }; + let bounds = mem::replace(&mut param.bounds, Punctuated::new()); + where_clause_or_default(&mut sig.generics.where_clause) + .predicates + .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds)); + } + GenericParam::Const(_) => {} + } + } + + for param in context.lifetimes(&lifetimes.explicit) { + let param = ¶m.lifetime; + let span = param.span(); + where_clause_or_default(&mut sig.generics.where_clause) + .predicates + .push(parse_quote_spanned!(span=> #param: 'async_trait)); + } + + if sig.generics.lt_token.is_none() { + sig.generics.lt_token = Some(Token![<](sig.ident.span())); + } + if sig.generics.gt_token.is_none() { + sig.generics.gt_token = Some(Token![>](sig.paren_token.span)); + } + + for elided in lifetimes.elided { + sig.generics.params.push(parse_quote!(#elided)); + where_clause_or_default(&mut sig.generics.where_clause) + .predicates + .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait)); + } + + sig.generics + .params + .push(parse_quote_spanned!(default_span=> 'async_trait)); + + if has_self { + let bounds: &[InferredBound] = match sig.inputs.iter().next() { + Some(FnArg::Receiver(Receiver { + reference: Some(_), + mutability: None, + .. + })) => &[InferredBound::Sync], + Some(FnArg::Typed(arg)) + if match arg.pat.as_ref() { + Pat::Ident(pat) => pat.ident == "self", + _ => false, + } => + { + match arg.ty.as_ref() { + // self: &Self + Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync], + // self: Arc<Self> + Type::Path(ty) + if { + let segment = ty.path.segments.last().unwrap(); + segment.ident == "Arc" + && match &segment.arguments { + PathArguments::AngleBracketed(arguments) => { + arguments.args.len() == 1 + && match &arguments.args[0] { + GenericArgument::Type(Type::Path(arg)) => { + arg.path.is_ident("Self") + } + _ => false, + } + } + _ => false, + } + } => + { + &[InferredBound::Sync, InferredBound::Send] + } + _ => &[InferredBound::Send], + } + } + _ => &[InferredBound::Send], + }; + + let bounds = bounds.iter().filter_map(|bound| { + let assume_bound = match context { + Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, bound), + Context::Impl { .. } => true, + }; + if assume_bound || is_local { + None + } else { + Some(bound.spanned_path(default_span)) + } + }); + + where_clause_or_default(&mut sig.generics.where_clause) + .predicates + .push(parse_quote_spanned! {default_span=> + Self: #(#bounds +)* 'async_trait + }); + } + + for (i, arg) in sig.inputs.iter_mut().enumerate() { + match arg { + FnArg::Receiver(Receiver { + reference: Some(_), .. + }) => {} + FnArg::Receiver(arg) => arg.mutability = None, + FnArg::Typed(arg) => { + let type_is_reference = match *arg.ty { + Type::Reference(_) => true, + _ => false, + }; + if let Pat::Ident(pat) = &mut *arg.pat { + if pat.ident == "self" || !type_is_reference { + pat.by_ref = None; + pat.mutability = None; + } + } else if !type_is_reference { + let positional = positional_arg(i, &arg.pat); + let m = mut_pat(&mut arg.pat); + arg.pat = parse_quote!(#m #positional); + } + AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty); + } + } + } + + let bounds = if is_local { + quote_spanned!(default_span=> 'async_trait) + } else { + quote_spanned!(default_span=> ::core::marker::Send + 'async_trait) + }; + sig.output = parse_quote_spanned! {default_span=> + #ret_arrow ::core::pin::Pin<Box< + dyn ::core::future::Future<Output = #ret> + #bounds + >> + }; +} + +// Input: +// async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret { +// self + x + a + b +// } +// +// Output: +// Box::pin(async move { +// let ___ret: Ret = { +// let __self = self; +// let x = x; +// let (a, b) = __arg1; +// +// __self + x + a + b +// }; +// +// ___ret +// }) +fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) { + if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() { + if block.stmts.len() == 1 && item.to_string() == ";" { + return; + } + } + + let mut self_span = None; + let decls = sig + .inputs + .iter() + .enumerate() + .map(|(i, arg)| match arg { + FnArg::Receiver(Receiver { + self_token, + mutability, + .. + }) => { + let ident = Ident::new("__self", self_token.span); + self_span = Some(self_token.span); + quote!(let #mutability #ident = #self_token;) + } + FnArg::Typed(arg) => { + // If there is a #[cfg(...)] attribute that selectively enables + // the parameter, forward it to the variable. + // + // This is currently not applied to the `self` parameter. + let attrs = arg.attrs.iter().filter(|attr| attr.path.is_ident("cfg")); + + if let Pat::Ident(PatIdent { + ident, mutability, .. + }) = &*arg.pat + { + if ident == "self" { + self_span = Some(ident.span()); + let prefixed = Ident::new("__self", ident.span()); + quote!(let #mutability #prefixed = #ident;) + } else if let Type::Reference(_) = *arg.ty { + quote!() + } else { + quote! { + #(#attrs)* + let #mutability #ident = #ident; + } + } + } else if let Type::Reference(_) = *arg.ty { + quote!() + } else { + let pat = &arg.pat; + let ident = positional_arg(i, pat); + if let Pat::Wild(_) = **pat { + quote! { + #(#attrs)* + let #ident = #ident; + } + } else { + quote! { + #(#attrs)* + let #pat = { + let #ident = #ident; + #ident + }; + } + } + } + } + }) + .collect::<Vec<_>>(); + + if let Some(span) = self_span { + let mut replace_self = ReplaceSelf(span); + replace_self.visit_block_mut(block); + } + + let stmts = &block.stmts; + let let_ret = match &mut sig.output { + ReturnType::Default => quote_spanned! {block.brace_token.span=> + #(#decls)* + let _: () = { #(#stmts)* }; + }, + ReturnType::Type(_, ret) => { + if contains_associated_type_impl_trait(context, ret) { + if decls.is_empty() { + quote!(#(#stmts)*) + } else { + quote!(#(#decls)* { #(#stmts)* }) + } + } else { + quote_spanned! {block.brace_token.span=> + if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> { + return __ret; + } + #(#decls)* + let __ret: #ret = { #(#stmts)* }; + #[allow(unreachable_code)] + __ret + } + } + } + }; + let box_pin = quote_spanned!(block.brace_token.span=> + Box::pin(async move { #let_ret }) + ); + block.stmts = parse_quote!(#box_pin); +} + +fn positional_arg(i: usize, pat: &Pat) -> Ident { + let span: Span = syn::spanned::Spanned::span(pat); + #[cfg(not(no_span_mixed_site))] + let span = span.resolved_at(Span::mixed_site()); + format_ident!("__arg{}", i, span = span) +} + +fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool { + struct AssociatedTypeImplTraits<'a> { + set: &'a Set<Ident>, + contains: bool, + } + + impl<'a> VisitMut for AssociatedTypeImplTraits<'a> { + fn visit_type_path_mut(&mut self, ty: &mut TypePath) { + if ty.qself.is_none() + && ty.path.segments.len() == 2 + && ty.path.segments[0].ident == "Self" + && self.set.contains(&ty.path.segments[1].ident) + { + self.contains = true; + } + visit_mut::visit_type_path_mut(self, ty); + } + } + + match context { + Context::Trait { .. } => false, + Context::Impl { + associated_type_impl_traits, + .. + } => { + let mut visit = AssociatedTypeImplTraits { + set: associated_type_impl_traits, + contains: false, + }; + visit.visit_type_mut(ret); + visit.contains + } + } +} + +fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause { + clause.get_or_insert_with(|| WhereClause { + where_token: Default::default(), + predicates: Punctuated::new(), + }) +} |