summaryrefslogtreecommitdiffstats
path: root/third_party/rust/async-trait/src/expand.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/async-trait/src/expand.rs')
-rw-r--r--third_party/rust/async-trait/src/expand.rs502
1 files changed, 502 insertions, 0 deletions
diff --git a/third_party/rust/async-trait/src/expand.rs b/third_party/rust/async-trait/src/expand.rs
new file mode 100644
index 0000000000..88338db9bc
--- /dev/null
+++ b/third_party/rust/async-trait/src/expand.rs
@@ -0,0 +1,502 @@
+use crate::bound::{has_bound, InferredBound, Supertraits};
+use crate::lifetime::{AddLifetimeToImplTrait, CollectLifetimes};
+use crate::parse::Item;
+use crate::receiver::{has_self_in_block, has_self_in_sig, mut_pat, ReplaceSelf};
+use proc_macro2::{Span, TokenStream};
+use quote::{format_ident, quote, quote_spanned, ToTokens};
+use std::collections::BTreeSet as Set;
+use std::mem;
+use syn::punctuated::Punctuated;
+use syn::visit_mut::{self, VisitMut};
+use syn::{
+ parse_quote, parse_quote_spanned, Attribute, Block, FnArg, GenericArgument, GenericParam,
+ Generics, Ident, ImplItem, Lifetime, LifetimeDef, Pat, PatIdent, PathArguments, Receiver,
+ ReturnType, Signature, Stmt, Token, TraitItem, Type, TypePath, WhereClause,
+};
+
+impl ToTokens for Item {
+ fn to_tokens(&self, tokens: &mut TokenStream) {
+ match self {
+ Item::Trait(item) => item.to_tokens(tokens),
+ Item::Impl(item) => item.to_tokens(tokens),
+ }
+ }
+}
+
+#[derive(Clone, Copy)]
+enum Context<'a> {
+ Trait {
+ generics: &'a Generics,
+ supertraits: &'a Supertraits,
+ },
+ Impl {
+ impl_generics: &'a Generics,
+ associated_type_impl_traits: &'a Set<Ident>,
+ },
+}
+
+impl Context<'_> {
+ fn lifetimes<'a>(&'a self, used: &'a [Lifetime]) -> impl Iterator<Item = &'a LifetimeDef> {
+ let generics = match self {
+ Context::Trait { generics, .. } => generics,
+ Context::Impl { impl_generics, .. } => impl_generics,
+ };
+ generics.params.iter().filter_map(move |param| {
+ if let GenericParam::Lifetime(param) = param {
+ if used.contains(&param.lifetime) {
+ return Some(param);
+ }
+ }
+ None
+ })
+ }
+}
+
+pub fn expand(input: &mut Item, is_local: bool) {
+ match input {
+ Item::Trait(input) => {
+ let context = Context::Trait {
+ generics: &input.generics,
+ supertraits: &input.supertraits,
+ };
+ for inner in &mut input.items {
+ if let TraitItem::Method(method) = inner {
+ let sig = &mut method.sig;
+ if sig.asyncness.is_some() {
+ let block = &mut method.default;
+ let mut has_self = has_self_in_sig(sig);
+ method.attrs.push(parse_quote!(#[must_use]));
+ if let Some(block) = block {
+ has_self |= has_self_in_block(block);
+ transform_block(context, sig, block);
+ method.attrs.push(lint_suppress_with_body());
+ } else {
+ method.attrs.push(lint_suppress_without_body());
+ }
+ let has_default = method.default.is_some();
+ transform_sig(context, sig, has_self, has_default, is_local);
+ }
+ }
+ }
+ }
+ Item::Impl(input) => {
+ let mut lifetimes = CollectLifetimes::new("'impl");
+ lifetimes.visit_type_mut(&mut *input.self_ty);
+ lifetimes.visit_path_mut(&mut input.trait_.as_mut().unwrap().1);
+ let params = &input.generics.params;
+ let elided = lifetimes.elided;
+ input.generics.params = parse_quote!(#(#elided,)* #params);
+
+ let mut associated_type_impl_traits = Set::new();
+ for inner in &input.items {
+ if let ImplItem::Type(assoc) = inner {
+ if let Type::ImplTrait(_) = assoc.ty {
+ associated_type_impl_traits.insert(assoc.ident.clone());
+ }
+ }
+ }
+
+ let context = Context::Impl {
+ impl_generics: &input.generics,
+ associated_type_impl_traits: &associated_type_impl_traits,
+ };
+ for inner in &mut input.items {
+ if let ImplItem::Method(method) = inner {
+ let sig = &mut method.sig;
+ if sig.asyncness.is_some() {
+ let block = &mut method.block;
+ let has_self = has_self_in_sig(sig) || has_self_in_block(block);
+ transform_block(context, sig, block);
+ transform_sig(context, sig, has_self, false, is_local);
+ method.attrs.push(lint_suppress_with_body());
+ }
+ }
+ }
+ }
+ }
+}
+
+fn lint_suppress_with_body() -> Attribute {
+ parse_quote! {
+ #[allow(
+ clippy::async_yields_async,
+ clippy::let_unit_value,
+ clippy::no_effect_underscore_binding,
+ clippy::shadow_same,
+ clippy::type_complexity,
+ clippy::type_repetition_in_bounds,
+ clippy::used_underscore_binding
+ )]
+ }
+}
+
+fn lint_suppress_without_body() -> Attribute {
+ parse_quote! {
+ #[allow(
+ clippy::type_complexity,
+ clippy::type_repetition_in_bounds
+ )]
+ }
+}
+
+// Input:
+// async fn f<T>(&self, x: &T) -> Ret;
+//
+// Output:
+// fn f<'life0, 'life1, 'async_trait, T>(
+// &'life0 self,
+// x: &'life1 T,
+// ) -> Pin<Box<dyn Future<Output = Ret> + Send + 'async_trait>>
+// where
+// 'life0: 'async_trait,
+// 'life1: 'async_trait,
+// T: 'async_trait,
+// Self: Sync + 'async_trait;
+fn transform_sig(
+ context: Context,
+ sig: &mut Signature,
+ has_self: bool,
+ has_default: bool,
+ is_local: bool,
+) {
+ let default_span = sig.asyncness.take().unwrap().span;
+ sig.fn_token.span = default_span;
+
+ let (ret_arrow, ret) = match &sig.output {
+ ReturnType::Default => (Token![->](default_span), quote_spanned!(default_span=> ())),
+ ReturnType::Type(arrow, ret) => (*arrow, quote!(#ret)),
+ };
+
+ let mut lifetimes = CollectLifetimes::new("'life");
+ for arg in sig.inputs.iter_mut() {
+ match arg {
+ FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
+ FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
+ }
+ }
+
+ for param in &mut sig.generics.params {
+ match param {
+ GenericParam::Type(param) => {
+ let param_name = &param.ident;
+ let span = match param.colon_token.take() {
+ Some(colon_token) => colon_token.span,
+ None => param_name.span(),
+ };
+ let bounds = mem::replace(&mut param.bounds, Punctuated::new());
+ where_clause_or_default(&mut sig.generics.where_clause)
+ .predicates
+ .push(parse_quote_spanned!(span=> #param_name: 'async_trait + #bounds));
+ }
+ GenericParam::Lifetime(param) => {
+ let param_name = &param.lifetime;
+ let span = match param.colon_token.take() {
+ Some(colon_token) => colon_token.span,
+ None => param_name.span(),
+ };
+ let bounds = mem::replace(&mut param.bounds, Punctuated::new());
+ where_clause_or_default(&mut sig.generics.where_clause)
+ .predicates
+ .push(parse_quote_spanned!(span=> #param: 'async_trait + #bounds));
+ }
+ GenericParam::Const(_) => {}
+ }
+ }
+
+ for param in context.lifetimes(&lifetimes.explicit) {
+ let param = &param.lifetime;
+ let span = param.span();
+ where_clause_or_default(&mut sig.generics.where_clause)
+ .predicates
+ .push(parse_quote_spanned!(span=> #param: 'async_trait));
+ }
+
+ if sig.generics.lt_token.is_none() {
+ sig.generics.lt_token = Some(Token![<](sig.ident.span()));
+ }
+ if sig.generics.gt_token.is_none() {
+ sig.generics.gt_token = Some(Token![>](sig.paren_token.span));
+ }
+
+ for elided in lifetimes.elided {
+ sig.generics.params.push(parse_quote!(#elided));
+ where_clause_or_default(&mut sig.generics.where_clause)
+ .predicates
+ .push(parse_quote_spanned!(elided.span()=> #elided: 'async_trait));
+ }
+
+ sig.generics
+ .params
+ .push(parse_quote_spanned!(default_span=> 'async_trait));
+
+ if has_self {
+ let bounds: &[InferredBound] = match sig.inputs.iter().next() {
+ Some(FnArg::Receiver(Receiver {
+ reference: Some(_),
+ mutability: None,
+ ..
+ })) => &[InferredBound::Sync],
+ Some(FnArg::Typed(arg))
+ if match arg.pat.as_ref() {
+ Pat::Ident(pat) => pat.ident == "self",
+ _ => false,
+ } =>
+ {
+ match arg.ty.as_ref() {
+ // self: &Self
+ Type::Reference(ty) if ty.mutability.is_none() => &[InferredBound::Sync],
+ // self: Arc<Self>
+ Type::Path(ty)
+ if {
+ let segment = ty.path.segments.last().unwrap();
+ segment.ident == "Arc"
+ && match &segment.arguments {
+ PathArguments::AngleBracketed(arguments) => {
+ arguments.args.len() == 1
+ && match &arguments.args[0] {
+ GenericArgument::Type(Type::Path(arg)) => {
+ arg.path.is_ident("Self")
+ }
+ _ => false,
+ }
+ }
+ _ => false,
+ }
+ } =>
+ {
+ &[InferredBound::Sync, InferredBound::Send]
+ }
+ _ => &[InferredBound::Send],
+ }
+ }
+ _ => &[InferredBound::Send],
+ };
+
+ let bounds = bounds.iter().filter_map(|bound| {
+ let assume_bound = match context {
+ Context::Trait { supertraits, .. } => !has_default || has_bound(supertraits, bound),
+ Context::Impl { .. } => true,
+ };
+ if assume_bound || is_local {
+ None
+ } else {
+ Some(bound.spanned_path(default_span))
+ }
+ });
+
+ where_clause_or_default(&mut sig.generics.where_clause)
+ .predicates
+ .push(parse_quote_spanned! {default_span=>
+ Self: #(#bounds +)* 'async_trait
+ });
+ }
+
+ for (i, arg) in sig.inputs.iter_mut().enumerate() {
+ match arg {
+ FnArg::Receiver(Receiver {
+ reference: Some(_), ..
+ }) => {}
+ FnArg::Receiver(arg) => arg.mutability = None,
+ FnArg::Typed(arg) => {
+ let type_is_reference = match *arg.ty {
+ Type::Reference(_) => true,
+ _ => false,
+ };
+ if let Pat::Ident(pat) = &mut *arg.pat {
+ if pat.ident == "self" || !type_is_reference {
+ pat.by_ref = None;
+ pat.mutability = None;
+ }
+ } else if !type_is_reference {
+ let positional = positional_arg(i, &arg.pat);
+ let m = mut_pat(&mut arg.pat);
+ arg.pat = parse_quote!(#m #positional);
+ }
+ AddLifetimeToImplTrait.visit_type_mut(&mut arg.ty);
+ }
+ }
+ }
+
+ let bounds = if is_local {
+ quote_spanned!(default_span=> 'async_trait)
+ } else {
+ quote_spanned!(default_span=> ::core::marker::Send + 'async_trait)
+ };
+ sig.output = parse_quote_spanned! {default_span=>
+ #ret_arrow ::core::pin::Pin<Box<
+ dyn ::core::future::Future<Output = #ret> + #bounds
+ >>
+ };
+}
+
+// Input:
+// async fn f<T>(&self, x: &T, (a, b): (A, B)) -> Ret {
+// self + x + a + b
+// }
+//
+// Output:
+// Box::pin(async move {
+// let ___ret: Ret = {
+// let __self = self;
+// let x = x;
+// let (a, b) = __arg1;
+//
+// __self + x + a + b
+// };
+//
+// ___ret
+// })
+fn transform_block(context: Context, sig: &mut Signature, block: &mut Block) {
+ if let Some(Stmt::Item(syn::Item::Verbatim(item))) = block.stmts.first() {
+ if block.stmts.len() == 1 && item.to_string() == ";" {
+ return;
+ }
+ }
+
+ let mut self_span = None;
+ let decls = sig
+ .inputs
+ .iter()
+ .enumerate()
+ .map(|(i, arg)| match arg {
+ FnArg::Receiver(Receiver {
+ self_token,
+ mutability,
+ ..
+ }) => {
+ let ident = Ident::new("__self", self_token.span);
+ self_span = Some(self_token.span);
+ quote!(let #mutability #ident = #self_token;)
+ }
+ FnArg::Typed(arg) => {
+ // If there is a #[cfg(...)] attribute that selectively enables
+ // the parameter, forward it to the variable.
+ //
+ // This is currently not applied to the `self` parameter.
+ let attrs = arg.attrs.iter().filter(|attr| attr.path.is_ident("cfg"));
+
+ if let Pat::Ident(PatIdent {
+ ident, mutability, ..
+ }) = &*arg.pat
+ {
+ if ident == "self" {
+ self_span = Some(ident.span());
+ let prefixed = Ident::new("__self", ident.span());
+ quote!(let #mutability #prefixed = #ident;)
+ } else if let Type::Reference(_) = *arg.ty {
+ quote!()
+ } else {
+ quote! {
+ #(#attrs)*
+ let #mutability #ident = #ident;
+ }
+ }
+ } else if let Type::Reference(_) = *arg.ty {
+ quote!()
+ } else {
+ let pat = &arg.pat;
+ let ident = positional_arg(i, pat);
+ if let Pat::Wild(_) = **pat {
+ quote! {
+ #(#attrs)*
+ let #ident = #ident;
+ }
+ } else {
+ quote! {
+ #(#attrs)*
+ let #pat = {
+ let #ident = #ident;
+ #ident
+ };
+ }
+ }
+ }
+ }
+ })
+ .collect::<Vec<_>>();
+
+ if let Some(span) = self_span {
+ let mut replace_self = ReplaceSelf(span);
+ replace_self.visit_block_mut(block);
+ }
+
+ let stmts = &block.stmts;
+ let let_ret = match &mut sig.output {
+ ReturnType::Default => quote_spanned! {block.brace_token.span=>
+ #(#decls)*
+ let _: () = { #(#stmts)* };
+ },
+ ReturnType::Type(_, ret) => {
+ if contains_associated_type_impl_trait(context, ret) {
+ if decls.is_empty() {
+ quote!(#(#stmts)*)
+ } else {
+ quote!(#(#decls)* { #(#stmts)* })
+ }
+ } else {
+ quote_spanned! {block.brace_token.span=>
+ if let ::core::option::Option::Some(__ret) = ::core::option::Option::None::<#ret> {
+ return __ret;
+ }
+ #(#decls)*
+ let __ret: #ret = { #(#stmts)* };
+ #[allow(unreachable_code)]
+ __ret
+ }
+ }
+ }
+ };
+ let box_pin = quote_spanned!(block.brace_token.span=>
+ Box::pin(async move { #let_ret })
+ );
+ block.stmts = parse_quote!(#box_pin);
+}
+
+fn positional_arg(i: usize, pat: &Pat) -> Ident {
+ let span: Span = syn::spanned::Spanned::span(pat);
+ #[cfg(not(no_span_mixed_site))]
+ let span = span.resolved_at(Span::mixed_site());
+ format_ident!("__arg{}", i, span = span)
+}
+
+fn contains_associated_type_impl_trait(context: Context, ret: &mut Type) -> bool {
+ struct AssociatedTypeImplTraits<'a> {
+ set: &'a Set<Ident>,
+ contains: bool,
+ }
+
+ impl<'a> VisitMut for AssociatedTypeImplTraits<'a> {
+ fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
+ if ty.qself.is_none()
+ && ty.path.segments.len() == 2
+ && ty.path.segments[0].ident == "Self"
+ && self.set.contains(&ty.path.segments[1].ident)
+ {
+ self.contains = true;
+ }
+ visit_mut::visit_type_path_mut(self, ty);
+ }
+ }
+
+ match context {
+ Context::Trait { .. } => false,
+ Context::Impl {
+ associated_type_impl_traits,
+ ..
+ } => {
+ let mut visit = AssociatedTypeImplTraits {
+ set: associated_type_impl_traits,
+ contains: false,
+ };
+ visit.visit_type_mut(ret);
+ visit.contains
+ }
+ }
+}
+
+fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
+ clause.get_or_insert_with(|| WhereClause {
+ where_token: Default::default(),
+ predicates: Punctuated::new(),
+ })
+}