// SPDX-License-Identifier: GPL-2.0

use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
use std::collections::HashSet;
use std::fmt::Write;

/// Expands the `#[vtable]` attribute on a trait definition or an impl block.
///
/// The method and constant names declared at the top level of the block's body are collected.
/// For a trait, a `USE_VTABLE_ATTR` marker declaration and a `HAS_<METHOD>: bool = false`
/// constant per method are prepended to the body; for an impl block, `USE_VTABLE_ATTR` is
/// defined and a `HAS_<METHOD>: bool = true` constant is generated per method. A `HAS_*`
/// constant that the body already declares is not generated, which lets users override it.
///
/// `_attr` (the attribute's own arguments) is ignored.
///
/// # Panics
///
/// Panics if `ts` contains neither a `trait` nor an `impl` keyword, or if its last token tree is
/// not a brace-delimited group.
pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream {
    let mut tokens: Vec<_> = ts.into_iter().collect();

    // Scan for the `trait` or `impl` keyword; whichever appears first decides the kind.
    let is_trait = tokens
        .iter()
        .find_map(|token| match token {
            TokenTree::Ident(ident) => match ident.to_string().as_str() {
                "trait" => Some(true),
                "impl" => Some(false),
                _ => None,
            },
            _ => None,
        })
        .expect("#[vtable] attribute should only be applied to trait or impl block");

    // Retrieve the main body. The main body should be the last token tree.
    let body = match tokens.pop() {
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
        _ => panic!("cannot locate main body of trait or impl block"),
    };

    let (functions, consts) = scan_body(body.stream());
    let const_items = generate_const_items(is_trait, &functions, consts);

    // Prepend the generated items to the original body tokens.
    let mut new_body: TokenStream = const_items
        .parse()
        .expect("generated #[vtable] items should be valid tokens");
    new_body.extend(body.stream());
    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body)));
    tokens.into_iter().collect()
}

/// Returns the names of the methods (in order of appearance) and of the associated constants
/// declared in `body`.
///
/// Only the top-level token stream is walked: method bodies and parameter lists are nested
/// `Group` tokens and are not descended into.
fn scan_body(body: TokenStream) -> (Vec<String>, HashSet<String>) {
    let mut functions = Vec::new();
    let mut consts = HashSet::new();
    let mut body_it = body.into_iter();
    while let Some(token) = body_it.next() {
        let keyword = match token {
            TokenTree::Ident(ident) => ident.to_string(),
            _ => continue,
        };
        if keyword != "fn" && keyword != "const" {
            continue;
        }
        // The item name is the identifier that follows the keyword. If something else follows,
        // we've possibly encountered a fn pointer type or an inline const block instead.
        let name = match body_it.next() {
            Some(TokenTree::Ident(ident)) => ident.to_string(),
            _ => continue,
        };
        if keyword == "fn" {
            functions.push(name);
        } else {
            consts.insert(name);
        }
    }
    (functions, consts)
}

/// Builds the source text of the items that `#[vtable]` prepends to a trait or impl body.
///
/// `functions` are the method names found in the body and `consts` the associated constants it
/// already declares. No `HAS_*` constant is emitted for a name in `consts`, and every emitted name
/// is recorded there, so a method name that appears more than once (for instance in
/// `#[cfg]`-gated alternatives) yields exactly one constant, for traits and impls alike.
fn generate_const_items(
    is_trait: bool,
    functions: &[String],
    mut consts: HashSet<String>,
) -> String {
    let mut const_items = if is_trait {
        "/// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable)\n\
         /// attribute when implementing this trait.\n\
         const USE_VTABLE_ATTR: ();\n"
            .to_owned()
    } else {
        "const USE_VTABLE_ATTR: () = ();".to_owned()
    };

    for f in functions {
        let gen_const_name = has_const_name(f);
        // Skip if it's declared already -- this allows user override (and dedups repeats).
        if consts.contains(&gen_const_name) {
            continue;
        }
        if is_trait {
            // We don't know on the implementation-site whether a method is required or provided
            // so we have to generate a const for all methods. The `\n` terminates the doc comment.
            write!(
                const_items,
                "/// Indicates if the `{f}` method is overridden by the implementor.\n\
                 const {gen_const_name}: bool = false;",
            )
            .unwrap();
        } else {
            write!(const_items, "const {gen_const_name}: bool = true;").unwrap();
        }
        consts.insert(gen_const_name);
    }

    const_items
}

/// Returns the name of the `HAS_*` constant generated for the method `fn_name`.
///
/// A raw-identifier prefix (`r#`) is dropped: `HAS_R#TYPE` would not be a single identifier,
/// whereas `HAS_TYPE` is.
fn has_const_name(fn_name: &str) -> String {
    let name = fn_name.strip_prefix("r#").unwrap_or(fn_name);
    format!("HAS_{}", name.to_uppercase())
}