summaryrefslogtreecommitdiffstats
path: root/third_party/rust/tracing-attributes/src/expand.rs
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--third_party/rust/tracing-attributes/src/expand.rs814
1 files changed, 814 insertions, 0 deletions
diff --git a/third_party/rust/tracing-attributes/src/expand.rs b/third_party/rust/tracing-attributes/src/expand.rs
new file mode 100644
index 0000000000..7005b4423e
--- /dev/null
+++ b/third_party/rust/tracing-attributes/src/expand.rs
@@ -0,0 +1,814 @@
+use std::iter;
+
+use proc_macro2::TokenStream;
+use quote::{quote, quote_spanned, ToTokens};
+use syn::visit_mut::VisitMut;
+use syn::{
+ punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FieldPat, FnArg,
+ Ident, Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType,
+ Path, ReturnType, Signature, Stmt, Token, Type, TypePath,
+};
+
+use crate::{
+ attr::{Field, Fields, FormatMode, InstrumentArgs},
+ MaybeItemFn, MaybeItemFnRef,
+};
+
+/// Given an existing function, generate an instrumented version of that function
+pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
+ input: MaybeItemFnRef<'a, B>,
+ args: InstrumentArgs,
+ instrumented_function_name: &str,
+ self_type: Option<&TypePath>,
+) -> proc_macro2::TokenStream {
+ // these are needed ahead of time, as ItemFn contains the function body _and_
+ // isn't representable inside a quote!/quote_spanned! macro
+ // (Syn's ToTokens isn't implemented for ItemFn)
+ let MaybeItemFnRef {
+ outer_attrs,
+ inner_attrs,
+ vis,
+ sig,
+ block,
+ } = input;
+
+ let Signature {
+ output,
+ inputs: params,
+ unsafety,
+ asyncness,
+ constness,
+ abi,
+ ident,
+ generics:
+ syn::Generics {
+ params: gen_params,
+ where_clause,
+ ..
+ },
+ ..
+ } = sig;
+
+ let warnings = args.warnings();
+
+ let (return_type, return_span) = if let ReturnType::Type(_, return_type) = &output {
+ (erase_impl_trait(return_type), return_type.span())
+ } else {
+ // Point at function name if we don't have an explicit return type
+ (syn::parse_quote! { () }, ident.span())
+ };
+ // Install a fake return statement as the first thing in the function
+ // body, so that we eagerly infer that the return type is what we
+ // declared in the async fn signature.
+ // The `#[allow(..)]` is given because the return statement is
+ // unreachable, but does affect inference, so it needs to be written
+ // exactly that way for it to do its magic.
+ let fake_return_edge = quote_spanned! {return_span=>
+ #[allow(unreachable_code, clippy::diverging_sub_expression, clippy::let_unit_value)]
+ if false {
+ let __tracing_attr_fake_return: #return_type =
+ unreachable!("this is just for type inference, and is unreachable code");
+ return __tracing_attr_fake_return;
+ }
+ };
+ let block = quote! {
+ {
+ #fake_return_edge
+ #block
+ }
+ };
+
+ let body = gen_block(
+ &block,
+ params,
+ asyncness.is_some(),
+ args,
+ instrumented_function_name,
+ self_type,
+ );
+
+ quote!(
+ #(#outer_attrs) *
+ #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #output
+ #where_clause
+ {
+ #(#inner_attrs) *
+ #warnings
+ #body
+ }
+ )
+}
+
+/// Instrument a block
+fn gen_block<B: ToTokens>(
+ block: &B,
+ params: &Punctuated<FnArg, Token![,]>,
+ async_context: bool,
+ mut args: InstrumentArgs,
+ instrumented_function_name: &str,
+ self_type: Option<&TypePath>,
+) -> proc_macro2::TokenStream {
+ // generate the span's name
+ let span_name = args
+ // did the user override the span's name?
+ .name
+ .as_ref()
+ .map(|name| quote!(#name))
+ .unwrap_or_else(|| quote!(#instrumented_function_name));
+
+ let level = args.level();
+
+ let follows_from = args.follows_from.iter();
+ let follows_from = quote! {
+ #(for cause in #follows_from {
+ __tracing_attr_span.follows_from(cause);
+ })*
+ };
+
+ // generate this inside a closure, so we can return early on errors.
+ let span = (|| {
+ // Pull out the arguments-to-be-skipped first, so we can filter results
+ // below.
+ let param_names: Vec<(Ident, (Ident, RecordType))> = params
+ .clone()
+ .into_iter()
+ .flat_map(|param| match param {
+ FnArg::Typed(PatType { pat, ty, .. }) => {
+ param_names(*pat, RecordType::parse_from_ty(&*ty))
+ }
+ FnArg::Receiver(_) => Box::new(iter::once((
+ Ident::new("self", param.span()),
+ RecordType::Debug,
+ ))),
+ })
+ // Little dance with new (user-exposed) names and old (internal)
+ // names of identifiers. That way, we could do the following
+ // even though async_trait (<=0.1.43) rewrites "self" as "_self":
+ // ```
+ // #[async_trait]
+ // impl Foo for FooImpl {
+ // #[instrument(skip(self))]
+ // async fn foo(&self, v: usize) {}
+ // }
+ // ```
+ .map(|(x, record_type)| {
+ // if we are inside a function generated by async-trait <=0.1.43, we need to
+ // take care to rewrite "_self" as "self" for 'user convenience'
+ if self_type.is_some() && x == "_self" {
+ (Ident::new("self", x.span()), (x, record_type))
+ } else {
+ (x.clone(), (x, record_type))
+ }
+ })
+ .collect();
+
+ for skip in &args.skips {
+ if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) {
+ return quote_spanned! {skip.span()=>
+ compile_error!("attempting to skip non-existent parameter")
+ };
+ }
+ }
+
+ let target = args.target();
+
+ let parent = args.parent.iter();
+
+ // filter out skipped fields
+ let quoted_fields: Vec<_> = param_names
+ .iter()
+ .filter(|(param, _)| {
+ if args.skip_all || args.skips.contains(param) {
+ return false;
+ }
+
+ // If any parameters have the same name as a custom field, skip
+ // and allow them to be formatted by the custom field.
+ if let Some(ref fields) = args.fields {
+ fields.0.iter().all(|Field { ref name, .. }| {
+ let first = name.first();
+ first != name.last() || !first.iter().any(|name| name == &param)
+ })
+ } else {
+ true
+ }
+ })
+ .map(|(user_name, (real_name, record_type))| match record_type {
+ RecordType::Value => quote!(#user_name = #real_name),
+ RecordType::Debug => quote!(#user_name = tracing::field::debug(&#real_name)),
+ })
+ .collect();
+
+ // replace every use of a variable with its original name
+ if let Some(Fields(ref mut fields)) = args.fields {
+ let mut replacer = IdentAndTypesRenamer {
+ idents: param_names.into_iter().map(|(a, (b, _))| (a, b)).collect(),
+ types: Vec::new(),
+ };
+
+ // when async-trait <=0.1.43 is in use, replace instances
+ // of the "Self" type inside the fields values
+ if let Some(self_type) = self_type {
+ replacer.types.push(("Self", self_type.clone()));
+ }
+
+ for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) {
+ syn::visit_mut::visit_expr_mut(&mut replacer, e);
+ }
+ }
+
+ let custom_fields = &args.fields;
+
+ quote!(tracing::span!(
+ target: #target,
+ #(parent: #parent,)*
+ #level,
+ #span_name,
+ #(#quoted_fields,)*
+ #custom_fields
+
+ ))
+ })();
+
+ let target = args.target();
+
+ let err_event = match args.err_mode {
+ Some(FormatMode::Default) | Some(FormatMode::Display) => {
+ Some(quote!(tracing::error!(target: #target, error = %e)))
+ }
+ Some(FormatMode::Debug) => Some(quote!(tracing::error!(target: #target, error = ?e))),
+ _ => None,
+ };
+
+ let ret_event = match args.ret_mode {
+ Some(FormatMode::Display) => Some(quote!(
+ tracing::event!(target: #target, #level, return = %x)
+ )),
+ Some(FormatMode::Default) | Some(FormatMode::Debug) => Some(quote!(
+ tracing::event!(target: #target, #level, return = ?x)
+ )),
+ _ => None,
+ };
+
+ // Generate the instrumented function body.
+ // If the function is an `async fn`, this will wrap it in an async block,
+ // which is `instrument`ed using `tracing-futures`. Otherwise, this will
+ // enter the span and then perform the rest of the body.
+ // If `err` is in args, instrument any resulting `Err`s.
+ // If `ret` is in args, instrument any resulting `Ok`s when the function
+ // returns `Result`s, otherwise instrument any resulting values.
+ if async_context {
+ let mk_fut = match (err_event, ret_event) {
+ (Some(err_event), Some(ret_event)) => quote_spanned!(block.span()=>
+ async move {
+ match async move #block.await {
+ #[allow(clippy::unit_arg)]
+ Ok(x) => {
+ #ret_event;
+ Ok(x)
+ },
+ Err(e) => {
+ #err_event;
+ Err(e)
+ }
+ }
+ }
+ ),
+ (Some(err_event), None) => quote_spanned!(block.span()=>
+ async move {
+ match async move #block.await {
+ #[allow(clippy::unit_arg)]
+ Ok(x) => Ok(x),
+ Err(e) => {
+ #err_event;
+ Err(e)
+ }
+ }
+ }
+ ),
+ (None, Some(ret_event)) => quote_spanned!(block.span()=>
+ async move {
+ let x = async move #block.await;
+ #ret_event;
+ x
+ }
+ ),
+ (None, None) => quote_spanned!(block.span()=>
+ async move #block
+ ),
+ };
+
+ return quote!(
+ let __tracing_attr_span = #span;
+ let __tracing_instrument_future = #mk_fut;
+ if !__tracing_attr_span.is_disabled() {
+ #follows_from
+ tracing::Instrument::instrument(
+ __tracing_instrument_future,
+ __tracing_attr_span
+ )
+ .await
+ } else {
+ __tracing_instrument_future.await
+ }
+ );
+ }
+
+ let span = quote!(
+ // These variables are left uninitialized and initialized only
+ // if the tracing level is statically enabled at this point.
+ // While the tracing level is also checked at span creation
+ // time, that will still create a dummy span, and a dummy guard
+ // and drop the dummy guard later. By lazily initializing these
+ // variables, Rust will generate a drop flag for them and thus
+ // only drop the guard if it was created. This creates code that
+ // is very straightforward for LLVM to optimize out if the tracing
+ // level is statically disabled, while not causing any performance
+ // regression in case the level is enabled.
+ let __tracing_attr_span;
+ let __tracing_attr_guard;
+ if tracing::level_enabled!(#level) {
+ __tracing_attr_span = #span;
+ #follows_from
+ __tracing_attr_guard = __tracing_attr_span.enter();
+ }
+ );
+
+ match (err_event, ret_event) {
+ (Some(err_event), Some(ret_event)) => quote_spanned! {block.span()=>
+ #span
+ #[allow(clippy::redundant_closure_call)]
+ match (move || #block)() {
+ #[allow(clippy::unit_arg)]
+ Ok(x) => {
+ #ret_event;
+ Ok(x)
+ },
+ Err(e) => {
+ #err_event;
+ Err(e)
+ }
+ }
+ },
+ (Some(err_event), None) => quote_spanned!(block.span()=>
+ #span
+ #[allow(clippy::redundant_closure_call)]
+ match (move || #block)() {
+ #[allow(clippy::unit_arg)]
+ Ok(x) => Ok(x),
+ Err(e) => {
+ #err_event;
+ Err(e)
+ }
+ }
+ ),
+ (None, Some(ret_event)) => quote_spanned!(block.span()=>
+ #span
+ #[allow(clippy::redundant_closure_call)]
+ let x = (move || #block)();
+ #ret_event;
+ x
+ ),
+ (None, None) => quote_spanned!(block.span() =>
+ // Because `quote` produces a stream of tokens _without_ whitespace, the
+ // `if` and the block will appear directly next to each other. This
+ // generates a clippy lint about suspicious `if/else` formatting.
+ // Therefore, suppress the lint inside the generated code...
+ #[allow(clippy::suspicious_else_formatting)]
+ {
+ #span
+ // ...but turn the lint back on inside the function body.
+ #[warn(clippy::suspicious_else_formatting)]
+ #block
+ }
+ ),
+ }
+}
+
+/// Indicates whether a field should be recorded as `Value` or `Debug`.
+enum RecordType {
+ /// The field should be recorded using its `Value` implementation.
+ Value,
+ /// The field should be recorded using `tracing::field::debug()`.
+ Debug,
+}
+
+impl RecordType {
+ /// Array of primitive types which should be recorded as [RecordType::Value].
+ const TYPES_FOR_VALUE: &'static [&'static str] = &[
+ "bool",
+ "str",
+ "u8",
+ "i8",
+ "u16",
+ "i16",
+ "u32",
+ "i32",
+ "u64",
+ "i64",
+ "f32",
+ "f64",
+ "usize",
+ "isize",
+ "NonZeroU8",
+ "NonZeroI8",
+ "NonZeroU16",
+ "NonZeroI16",
+ "NonZeroU32",
+ "NonZeroI32",
+ "NonZeroU64",
+ "NonZeroI64",
+ "NonZeroUsize",
+ "NonZeroIsize",
+ "Wrapping",
+ ];
+
+ /// Parse `RecordType` from [Type] by looking up
+ /// the [RecordType::TYPES_FOR_VALUE] array.
+ fn parse_from_ty(ty: &Type) -> Self {
+ match ty {
+ Type::Path(TypePath { path, .. })
+ if path
+ .segments
+ .iter()
+ .last()
+ .map(|path_segment| {
+ let ident = path_segment.ident.to_string();
+ Self::TYPES_FOR_VALUE.iter().any(|&t| t == ident)
+ })
+ .unwrap_or(false) =>
+ {
+ RecordType::Value
+ }
+ Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(elem),
+ _ => RecordType::Debug,
+ }
+ }
+}
+
+fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>> {
+ match pat {
+ Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once((ident, record_type))),
+ Pat::Reference(PatReference { pat, .. }) => param_names(*pat, record_type),
+ // We can't get the concrete type of fields in the struct/tuple
+ // patterns by using `syn`. e.g. `fn foo(Foo { x, y }: Foo) {}`.
+ // Therefore, the struct/tuple patterns in the arguments will just
+ // always be recorded as `RecordType::Debug`.
+ Pat::Struct(PatStruct { fields, .. }) => Box::new(
+ fields
+ .into_iter()
+ .flat_map(|FieldPat { pat, .. }| param_names(*pat, RecordType::Debug)),
+ ),
+ Pat::Tuple(PatTuple { elems, .. }) => Box::new(
+ elems
+ .into_iter()
+ .flat_map(|p| param_names(p, RecordType::Debug)),
+ ),
+ Pat::TupleStruct(PatTupleStruct {
+ pat: PatTuple { elems, .. },
+ ..
+ }) => Box::new(
+ elems
+ .into_iter()
+ .flat_map(|p| param_names(p, RecordType::Debug)),
+ ),
+
+ // The above *should* cover all cases of irrefutable patterns,
+ // but we purposefully don't do any funny business here
+ // (such as panicking) because that would obscure rustc's
+ // much more informative error message.
+ _ => Box::new(iter::empty()),
+ }
+}
+
+/// The specific async code pattern that was detected
+enum AsyncKind<'a> {
+ /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
+ /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
+ Function(&'a ItemFn),
+ /// A function returning an async (move) block, optionally `Box::pin`-ed,
+ /// as generated by `async-trait >= 0.1.44`:
+ /// `Box::pin(async move { ... })`
+ Async {
+ async_expr: &'a ExprAsync,
+ pinned_box: bool,
+ },
+}
+
+pub(crate) struct AsyncInfo<'block> {
+ // statement that must be patched
+ source_stmt: &'block Stmt,
+ kind: AsyncKind<'block>,
+ self_type: Option<TypePath>,
+ input: &'block ItemFn,
+}
+
+impl<'block> AsyncInfo<'block> {
+ /// Get the AST of the inner function we need to hook, if it looks like a
+ /// manual future implementation.
+ ///
+ /// When we are given a function that returns a (pinned) future containing the
+ /// user logic, it is that (pinned) future that needs to be instrumented.
+ /// Were we to instrument its parent, we would only collect information
+ /// regarding the allocation of that future, and not its own span of execution.
+ ///
+ /// We inspect the block of the function to find if it matches any of the
+ /// following patterns:
+ ///
+ /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
+ /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
+ ///
+ /// - A function returning an async (move) block, optionally `Box::pin`-ed,
+ /// as generated by `async-trait >= 0.1.44`:
+ /// `Box::pin(async move { ... })`
+ ///
+ /// We the return the statement that must be instrumented, along with some
+ /// other information.
+ /// 'gen_body' will then be able to use that information to instrument the
+ /// proper function/future.
+ ///
+ /// (this follows the approach suggested in
+ /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
+ pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> {
+ // are we in an async context? If yes, this isn't a manual async-like pattern
+ if input.sig.asyncness.is_some() {
+ return None;
+ }
+
+ let block = &input.block;
+
+ // list of async functions declared inside the block
+ let inside_funs = block.stmts.iter().filter_map(|stmt| {
+ if let Stmt::Item(Item::Fn(fun)) = &stmt {
+ // If the function is async, this is a candidate
+ if fun.sig.asyncness.is_some() {
+ return Some((stmt, fun));
+ }
+ }
+ None
+ });
+
+ // last expression of the block: it determines the return value of the
+ // block, this is quite likely a `Box::pin` statement or an async block
+ let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| {
+ if let Stmt::Expr(expr) = stmt {
+ Some((stmt, expr))
+ } else {
+ None
+ }
+ })?;
+
+ // is the last expression an async block?
+ if let Expr::Async(async_expr) = last_expr {
+ return Some(AsyncInfo {
+ source_stmt: last_expr_stmt,
+ kind: AsyncKind::Async {
+ async_expr,
+ pinned_box: false,
+ },
+ self_type: None,
+ input,
+ });
+ }
+
+ // is the last expression a function call?
+ let (outside_func, outside_args) = match last_expr {
+ Expr::Call(ExprCall { func, args, .. }) => (func, args),
+ _ => return None,
+ };
+
+ // is it a call to `Box::pin()`?
+ let path = match outside_func.as_ref() {
+ Expr::Path(path) => &path.path,
+ _ => return None,
+ };
+ if !path_to_string(path).ends_with("Box::pin") {
+ return None;
+ }
+
+ // Does the call take an argument? If it doesn't,
+ // it's not gonna compile anyway, but that's no reason
+ // to (try to) perform an out of bounds access
+ if outside_args.is_empty() {
+ return None;
+ }
+
+ // Is the argument to Box::pin an async block that
+ // captures its arguments?
+ if let Expr::Async(async_expr) = &outside_args[0] {
+ return Some(AsyncInfo {
+ source_stmt: last_expr_stmt,
+ kind: AsyncKind::Async {
+ async_expr,
+ pinned_box: true,
+ },
+ self_type: None,
+ input,
+ });
+ }
+
+ // Is the argument to Box::pin a function call itself?
+ let func = match &outside_args[0] {
+ Expr::Call(ExprCall { func, .. }) => func,
+ _ => return None,
+ };
+
+ // "stringify" the path of the function called
+ let func_name = match **func {
+ Expr::Path(ref func_path) => path_to_string(&func_path.path),
+ _ => return None,
+ };
+
+ // Was that function defined inside of the current block?
+ // If so, retrieve the statement where it was declared and the function itself
+ let (stmt_func_declaration, func) = inside_funs
+ .into_iter()
+ .find(|(_, fun)| fun.sig.ident == func_name)?;
+
+ // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the
+ // parameter type) with the type of "_self"
+ let mut self_type = None;
+ for arg in &func.sig.inputs {
+ if let FnArg::Typed(ty) = arg {
+ if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat {
+ if ident == "_self" {
+ let mut ty = *ty.ty.clone();
+ // extract the inner type if the argument is "&self" or "&mut self"
+ if let Type::Reference(syn::TypeReference { elem, .. }) = ty {
+ ty = *elem;
+ }
+
+ if let Type::Path(tp) = ty {
+ self_type = Some(tp);
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ Some(AsyncInfo {
+ source_stmt: stmt_func_declaration,
+ kind: AsyncKind::Function(func),
+ self_type,
+ input,
+ })
+ }
+
+ pub(crate) fn gen_async(
+ self,
+ args: InstrumentArgs,
+ instrumented_function_name: &str,
+ ) -> Result<proc_macro::TokenStream, syn::Error> {
+ // let's rewrite some statements!
+ let mut out_stmts: Vec<TokenStream> = self
+ .input
+ .block
+ .stmts
+ .iter()
+ .map(|stmt| stmt.to_token_stream())
+ .collect();
+
+ if let Some((iter, _stmt)) = self
+ .input
+ .block
+ .stmts
+ .iter()
+ .enumerate()
+ .find(|(_iter, stmt)| *stmt == self.source_stmt)
+ {
+ // instrument the future by rewriting the corresponding statement
+ out_stmts[iter] = match self.kind {
+ // `Box::pin(immediately_invoked_async_fn())`
+ AsyncKind::Function(fun) => {
+ let fun = MaybeItemFn::from(fun.clone());
+ gen_function(
+ fun.as_ref(),
+ args,
+ instrumented_function_name,
+ self.self_type.as_ref(),
+ )
+ }
+ // `async move { ... }`, optionally pinned
+ AsyncKind::Async {
+ async_expr,
+ pinned_box,
+ } => {
+ let instrumented_block = gen_block(
+ &async_expr.block,
+ &self.input.sig.inputs,
+ true,
+ args,
+ instrumented_function_name,
+ None,
+ );
+ let async_attrs = &async_expr.attrs;
+ if pinned_box {
+ quote! {
+ Box::pin(#(#async_attrs) * async move { #instrumented_block })
+ }
+ } else {
+ quote! {
+ #(#async_attrs) * async move { #instrumented_block }
+ }
+ }
+ }
+ };
+ }
+
+ let vis = &self.input.vis;
+ let sig = &self.input.sig;
+ let attrs = &self.input.attrs;
+ Ok(quote!(
+ #(#attrs) *
+ #vis #sig {
+ #(#out_stmts) *
+ }
+ )
+ .into())
+ }
+}
+
+// Return a path as a String
+fn path_to_string(path: &Path) -> String {
+ use std::fmt::Write;
+ // some heuristic to prevent too many allocations
+ let mut res = String::with_capacity(path.segments.len() * 5);
+ for i in 0..path.segments.len() {
+ write!(&mut res, "{}", path.segments[i].ident)
+ .expect("writing to a String should never fail");
+ if i < path.segments.len() - 1 {
+ res.push_str("::");
+ }
+ }
+ res
+}
+
+/// A visitor struct to replace idents and types in some piece
+/// of code (e.g. the "self" and "Self" tokens in user-supplied
+/// fields expressions when the function is generated by an old
+/// version of async-trait).
+struct IdentAndTypesRenamer<'a> {
+ types: Vec<(&'a str, TypePath)>,
+ idents: Vec<(Ident, Ident)>,
+}
+
+impl<'a> VisitMut for IdentAndTypesRenamer<'a> {
+ // we deliberately compare strings because we want to ignore the spans
+ // If we apply clippy's lint, the behavior changes
+ #[allow(clippy::cmp_owned)]
+ fn visit_ident_mut(&mut self, id: &mut Ident) {
+ for (old_ident, new_ident) in &self.idents {
+ if id.to_string() == old_ident.to_string() {
+ *id = new_ident.clone();
+ }
+ }
+ }
+
+ fn visit_type_mut(&mut self, ty: &mut Type) {
+ for (type_name, new_type) in &self.types {
+ if let Type::Path(TypePath { path, .. }) = ty {
+ if path_to_string(path) == *type_name {
+ *ty = Type::Path(new_type.clone());
+ }
+ }
+ }
+ }
+}
+
+// A visitor struct that replace an async block by its patched version
+struct AsyncTraitBlockReplacer<'a> {
+ block: &'a Block,
+ patched_block: Block,
+}
+
+impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> {
+ fn visit_block_mut(&mut self, i: &mut Block) {
+ if i == self.block {
+ *i = self.patched_block.clone();
+ }
+ }
+}
+
+// Replaces any `impl Trait` with `_` so it can be used as the type in
+// a `let` statement's LHS.
+struct ImplTraitEraser;
+
+impl VisitMut for ImplTraitEraser {
+ fn visit_type_mut(&mut self, t: &mut Type) {
+ if let Type::ImplTrait(..) = t {
+ *t = syn::TypeInfer {
+ underscore_token: Token![_](t.span()),
+ }
+ .into();
+ } else {
+ syn::visit_mut::visit_type_mut(self, t);
+ }
+ }
+}
+
+fn erase_impl_trait(ty: &Type) -> Type {
+ let mut ty = ty.clone();
+ ImplTraitEraser.visit_type_mut(&mut ty);
+ ty
+}