use proc_macro::TokenStream; use proc_macro2::{Ident, Span}; use quote::{quote, quote_spanned, ToTokens}; use syn::parse::Parser; // syn::AttributeArgs does not implement syn::Parse type AttributeArgs = syn::punctuated::Punctuated; #[derive(Clone, Copy, PartialEq)] enum RuntimeFlavor { CurrentThread, Threaded, } impl RuntimeFlavor { fn from_str(s: &str) -> Result { match s { "current_thread" => Ok(RuntimeFlavor::CurrentThread), "multi_thread" => Ok(RuntimeFlavor::Threaded), "single_thread" => Err("The single threaded runtime flavor is called `current_thread`.".to_string()), "basic_scheduler" => Err("The `basic_scheduler` runtime flavor has been renamed to `current_thread`.".to_string()), "threaded_scheduler" => Err("The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.".to_string()), _ => Err(format!("No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.", s)), } } } struct FinalConfig { flavor: RuntimeFlavor, worker_threads: Option, start_paused: Option, crate_name: Option, } /// Config used in case of the attribute not being able to build a valid config const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { flavor: RuntimeFlavor::CurrentThread, worker_threads: None, start_paused: None, crate_name: None, }; struct Configuration { rt_multi_thread_available: bool, default_flavor: RuntimeFlavor, flavor: Option, worker_threads: Option<(usize, Span)>, start_paused: Option<(bool, Span)>, is_test: bool, crate_name: Option, } impl Configuration { fn new(is_test: bool, rt_multi_thread: bool) -> Self { Configuration { rt_multi_thread_available: rt_multi_thread, default_flavor: match is_test { true => RuntimeFlavor::CurrentThread, false => RuntimeFlavor::Threaded, }, flavor: None, worker_threads: None, start_paused: None, is_test, crate_name: None, } } fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { if self.flavor.is_some() { return Err(syn::Error::new(span, "`flavor` set multiple times.")); } let runtime_str = parse_string(runtime, span, "flavor")?; let runtime = RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; self.flavor = Some(runtime); Ok(()) } fn set_worker_threads( &mut self, worker_threads: syn::Lit, span: Span, ) -> Result<(), syn::Error> { if self.worker_threads.is_some() { return Err(syn::Error::new( span, "`worker_threads` set multiple times.", )); } let worker_threads = parse_int(worker_threads, span, "worker_threads")?; if worker_threads == 0 { return Err(syn::Error::new(span, "`worker_threads` may not be 0.")); } self.worker_threads = Some((worker_threads, span)); Ok(()) } fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> { if self.start_paused.is_some() { return Err(syn::Error::new(span, "`start_paused` set multiple times.")); } let start_paused = parse_bool(start_paused, span, "start_paused")?; self.start_paused = Some((start_paused, span)); Ok(()) } fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { if self.crate_name.is_some() { return Err(syn::Error::new(span, "`crate` set multiple times.")); } let name_ident = parse_ident(name, span, "crate")?; self.crate_name = Some(name_ident.to_string()); Ok(()) } fn macro_name(&self) -> &'static str { if self.is_test { "tokio::test" } else { "tokio::main" } } fn build(&self) -> Result { let flavor = self.flavor.unwrap_or(self.default_flavor); use RuntimeFlavor::*; let worker_threads = match (flavor, self.worker_threads) { (CurrentThread, Some((_, worker_threads_span))) => { let msg = format!( "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`", self.macro_name(), ); return Err(syn::Error::new(worker_threads_span, msg)); } (CurrentThread, None) => None, (Threaded, worker_threads) if self.rt_multi_thread_available => { worker_threads.map(|(val, _span)| val) } (Threaded, _) => { let msg = if self.flavor.is_none() { "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." } else { "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." }; return Err(syn::Error::new(Span::call_site(), msg)); } }; let start_paused = match (flavor, self.start_paused) { (Threaded, Some((_, start_paused_span))) => { let msg = format!( "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`", self.macro_name(), ); return Err(syn::Error::new(start_paused_span, msg)); } (CurrentThread, Some((start_paused, _))) => Some(start_paused), (_, None) => None, }; Ok(FinalConfig { crate_name: self.crate_name.clone(), flavor, worker_threads, start_paused, }) } } fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result { match int { syn::Lit::Int(lit) => match lit.base10_parse::() { Ok(value) => Ok(value), Err(e) => Err(syn::Error::new( span, format!("Failed to parse value of `{}` as integer: {}", field, e), )), }, _ => Err(syn::Error::new( span, format!("Failed to parse value of `{}` as integer.", field), )), } } fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result { match int { syn::Lit::Str(s) => Ok(s.value()), syn::Lit::Verbatim(s) => Ok(s.to_string()), _ => Err(syn::Error::new( span, format!("Failed to parse value of `{}` as string.", field), )), } } fn parse_ident(lit: syn::Lit, span: Span, field: &str) -> Result { match lit { syn::Lit::Str(s) => { let err = syn::Error::new( span, format!( "Failed to parse value of `{}` as ident: \"{}\"", field, s.value() ), ); let path = s.parse::().map_err(|_| err.clone())?; path.get_ident().cloned().ok_or(err) } _ => Err(syn::Error::new( span, format!("Failed to parse value of `{}` as ident.", field), )), } } fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result { match bool { syn::Lit::Bool(b) => Ok(b.value), _ => Err(syn::Error::new( span, format!("Failed to parse value of `{}` as bool.", field), )), } } fn build_config( input: syn::ItemFn, args: AttributeArgs, is_test: bool, rt_multi_thread: bool, ) -> Result { if input.sig.asyncness.is_none() { let msg = "the `async` keyword is missing from the function declaration"; return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); } let mut config = Configuration::new(is_test, rt_multi_thread); let macro_name = config.macro_name(); for arg in args { match arg { syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => { let ident = namevalue .path .get_ident() .ok_or_else(|| { syn::Error::new_spanned(&namevalue, "Must have specified ident") })? .to_string() .to_lowercase(); match ident.as_str() { "worker_threads" => { config.set_worker_threads( namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit), )?; } "flavor" => { config.set_flavor( namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit), )?; } "start_paused" => { config.set_start_paused( namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit), )?; } "core_threads" => { let msg = "Attribute `core_threads` is renamed to `worker_threads`"; return Err(syn::Error::new_spanned(namevalue, msg)); } "crate" => { config.set_crate_name( namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit), )?; } name => { let msg = format!( "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name, ); return Err(syn::Error::new_spanned(namevalue, msg)); } } } syn::NestedMeta::Meta(syn::Meta::Path(path)) => { let name = path .get_ident() .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))? .to_string() .to_lowercase(); let msg = match name.as_str() { "threaded_scheduler" | "multi_thread" => { format!( "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].", macro_name ) } "basic_scheduler" | "current_thread" | "single_threaded" => { format!( "Set the runtime flavor with #[{}(flavor = \"current_thread\")].", macro_name ) } "flavor" | "worker_threads" | "start_paused" => { format!("The `{}` attribute requires an argument.", name) } name => { format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name) } }; return Err(syn::Error::new_spanned(path, msg)); } other => { return Err(syn::Error::new_spanned( other, "Unknown attribute inside the macro", )); } } } config.build() } fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> TokenStream { input.sig.asyncness = None; // If type mismatch occurs, the current rustc points to the last statement. let (last_stmt_start_span, last_stmt_end_span) = { let mut last_stmt = input .block .stmts .last() .map(ToTokens::into_token_stream) .unwrap_or_default() .into_iter(); // `Span` on stable Rust has a limitation that only points to the first // token, not the whole tokens. We can work around this limitation by // using the first/last span of the tokens like // `syn::Error::new_spanned` does. let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span()); let end = last_stmt.last().map_or(start, |t| t.span()); (start, end) }; let crate_name = config.crate_name.as_deref().unwrap_or("tokio"); let crate_ident = Ident::new(crate_name, last_stmt_start_span); let mut rt = match config.flavor { RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> #crate_ident::runtime::Builder::new_current_thread() }, RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=> #crate_ident::runtime::Builder::new_multi_thread() }, }; if let Some(v) = config.worker_threads { rt = quote! { #rt.worker_threads(#v) }; } if let Some(v) = config.start_paused { rt = quote! { #rt.start_paused(#v) }; } let header = if is_test { quote! { #[::core::prelude::v1::test] } } else { quote! {} }; let body = &input.block; let brace_token = input.block.brace_token; let body_ident = quote! { body }; let block_expr = quote_spanned! {last_stmt_end_span=> #[allow(clippy::expect_used, clippy::diverging_sub_expression)] { return #rt .enable_all() .build() .expect("Failed building the Runtime") .block_on(#body_ident); } }; // For test functions pin the body to the stack and use `Pin<&mut dyn // Future>` to reduce the amount of `Runtime::block_on` (and related // functions) copies we generate during compilation due to the generic // parameter `F` (the future to block on). This could have an impact on // performance, but because it's only for testing it's unlikely to be very // large. // // We don't do this for the main function as it should only be used once so // there will be no benefit. let body = if is_test { let output_type = match &input.sig.output { // For functions with no return value syn doesn't print anything, // but that doesn't work as `Output` for our boxed `Future`, so // default to `()` (the same type as the function output). syn::ReturnType::Default => quote! { () }, syn::ReturnType::Type(_, ret_type) => quote! { #ret_type }, }; quote! { let body = async #body; #crate_ident::pin!(body); let body: ::std::pin::Pin<&mut dyn ::std::future::Future> = body; } } else { quote! { let body = async #body; } }; input.block = syn::parse2(quote! { { #body #block_expr } }) .expect("Parsing failure"); input.block.brace_token = brace_token; let result = quote! { #header #input }; result.into() } fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { tokens.extend(TokenStream::from(error.into_compile_error())); tokens } #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub(crate) fn main(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { // If any of the steps for this macro fail, we still want to expand to an item that is as close // to the expected output as possible. This helps out IDEs such that completions and other // related features keep working. let input: syn::ItemFn = match syn::parse(item.clone()) { Ok(it) => it, Err(e) => return token_stream_with_error(item, e), }; let config = if input.sig.ident == "main" && !input.sig.inputs.is_empty() { let msg = "the main function cannot accept arguments"; Err(syn::Error::new_spanned(&input.sig.ident, msg)) } else { AttributeArgs::parse_terminated .parse(args) .and_then(|args| build_config(input.clone(), args, false, rt_multi_thread)) }; match config { Ok(config) => parse_knobs(input, false, config), Err(e) => token_stream_with_error(parse_knobs(input, false, DEFAULT_ERROR_CONFIG), e), } } pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream { // If any of the steps for this macro fail, we still want to expand to an item that is as close // to the expected output as possible. This helps out IDEs such that completions and other // related features keep working. let input: syn::ItemFn = match syn::parse(item.clone()) { Ok(it) => it, Err(e) => return token_stream_with_error(item, e), }; let config = if let Some(attr) = input.attrs.iter().find(|attr| attr.path.is_ident("test")) { let msg = "second test attribute is supplied"; Err(syn::Error::new_spanned(attr, msg)) } else { AttributeArgs::parse_terminated .parse(args) .and_then(|args| build_config(input.clone(), args, true, rt_multi_thread)) }; match config { Ok(config) => parse_knobs(input, true, config), Err(e) => token_stream_with_error(parse_knobs(input, true, DEFAULT_ERROR_CONFIG), e), } }