summaryrefslogtreecommitdiffstats
path: root/third_party/rust/extend/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/extend/src/lib.rs')
-rw-r--r--third_party/rust/extend/src/lib.rs620
1 files changed, 620 insertions, 0 deletions
diff --git a/third_party/rust/extend/src/lib.rs b/third_party/rust/extend/src/lib.rs
new file mode 100644
index 0000000000..9b8188a4a0
--- /dev/null
+++ b/third_party/rust/extend/src/lib.rs
@@ -0,0 +1,620 @@
+//! Create extensions for types you don't own with [extension traits] but without the boilerplate.
+//!
+//! Example:
+//!
+//! ```rust
+//! use extend::ext;
+//!
+//! #[ext]
+//! impl<T: Ord> Vec<T> {
+//! fn sorted(mut self) -> Self {
+//! self.sort();
+//! self
+//! }
+//! }
+//!
+//! fn main() {
+//! assert_eq!(
+//! vec![1, 2, 3],
+//! vec![2, 3, 1].sorted(),
+//! );
+//! }
+//! ```
+//!
+//! # How does it work?
+//!
+//! Under the hood it generates a trait with methods in your `impl` and implements those for the
+//! type you specify. The code shown above expands roughly to:
+//!
+//! ```rust
+//! trait VecExt<T: Ord> {
+//! fn sorted(self) -> Self;
+//! }
+//!
+//! impl<T: Ord> VecExt<T> for Vec<T> {
+//! fn sorted(mut self) -> Self {
+//! self.sort();
+//! self
+//! }
+//! }
+//! ```
+//!
+//! # Supported items
+//!
+//! Extensions can contain methods or associated constants:
+//!
+//! ```rust
+//! use extend::ext;
+//!
+//! #[ext]
+//! impl String {
+//! const CONSTANT: &'static str = "FOO";
+//!
+//! fn method() {
+//! // ...
+//! # todo!()
+//! }
+//! }
+//! ```
+//!
+//! # Configuration
+//!
+//! You can configure:
+//!
+//! - The visibility of the trait. Use `pub impl ...` to generate `pub trait ...`. The default
+//! visibility is private.
+//! - The name of the generated extension trait. Example: `#[ext(name = MyExt)]`. By default we
+//! generate a name based on what you extend.
+//! - Which supertraits the generated extension trait should have. Default is no supertraits.
+//! Example: `#[ext(supertraits = Default + Clone)]`.
+//!
+//! More examples:
+//!
+//! ```rust
+//! use extend::ext;
+//!
+//! #[ext(name = SortedVecExt)]
+//! impl<T: Ord> Vec<T> {
+//! fn sorted(mut self) -> Self {
+//! self.sort();
+//! self
+//! }
+//! }
+//!
+//! #[ext]
+//! pub(crate) impl i32 {
+//! fn double(self) -> i32 {
+//! self * 2
+//! }
+//! }
+//!
+//! #[ext(name = ResultSafeUnwrapExt)]
+//! pub impl<T> Result<T, std::convert::Infallible> {
+//! fn safe_unwrap(self) -> T {
+//! match self {
+//! Ok(t) => t,
+//! Err(_) => unreachable!(),
+//! }
+//! }
+//! }
+//!
+//! #[ext(supertraits = Default + Clone)]
+//! impl String {
+//! fn my_length(self) -> usize {
+//! self.len()
+//! }
+//! }
+//! ```
+//!
+//! For backwards compatibility you can also declare the visibility as the first argument to `#[ext]`:
+//!
+//! ```
+//! use extend::ext;
+//!
+//! #[ext(pub)]
+//! impl i32 {
+//! fn double(self) -> i32 {
+//! self * 2
+//! }
+//! }
+//! ```
+//!
+//! # async-trait compatibility
+//!
+//! Async extensions are supported via [async-trait](https://crates.io/crates/async-trait).
+//!
+//! Be aware that you need to add `#[async_trait]` _below_ `#[ext]`. Otherwise the `ext` macro
+//! cannot see the `#[async_trait]` attribute and pass it along in the generated code.
+//!
+//! Example:
+//!
+//! ```
+//! use extend::ext;
+//! use async_trait::async_trait;
+//!
+//! #[ext]
+//! #[async_trait]
+//! impl String {
+//! async fn read_file() -> String {
+//! // ...
+//! # todo!()
+//! }
+//! }
+//! ```
+//!
+//! # Other attributes
+//!
+//! Other attributes provided _below_ `#[ext]` will be passed along to both the generated trait and
+//! the implementation. See [async-trait compatibility](#async-trait-compatibility) above for an
+//! example.
+//!
+//! [extension traits]: https://dev.to/matsimitsu/extending-existing-functionality-in-rust-with-traits-in-rust-3622
+
+#![doc(html_root_url = "https://docs.rs/extend/1.1.2")]
+#![allow(clippy::let_and_return)]
+#![deny(
+ unused_variables,
+ mutable_borrow_reservation_conflict,
+ dead_code,
+ unused_must_use,
+ unused_imports
+)]
+
+use proc_macro2::TokenStream;
+use proc_macro_error::*;
+use quote::{format_ident, quote, ToTokens};
+use syn::{
+ parse::{self, Parse, ParseStream},
+ parse_macro_input, parse_quote,
+ punctuated::Punctuated,
+ spanned::Spanned,
+ token::{Add, Semi},
+ Ident, ImplItem, ItemImpl, Token, TraitItemConst, TraitItemMethod, Type, TypeArray, TypeBareFn,
+ TypeGroup, TypeNever, TypeParamBound, TypeParen, TypePath, TypePtr, TypeReference, TypeSlice,
+ TypeTraitObject, TypeTuple, Visibility,
+};
+
+#[derive(Debug)]
+struct Input {
+ item_impl: ItemImpl,
+ vis: Option<Visibility>,
+}
+
+impl Parse for Input {
+ fn parse(input: ParseStream) -> syn::Result<Self> {
+ let mut attributes = Vec::new();
+ if input.peek(syn::Token![#]) {
+ attributes.extend(syn::Attribute::parse_outer(input)?);
+ }
+
+ let vis = input
+ .parse::<Visibility>()
+ .ok()
+ .filter(|vis| vis != &Visibility::Inherited);
+
+ let mut item_impl = input.parse::<ItemImpl>()?;
+ item_impl.attrs.extend(attributes);
+
+ Ok(Self { item_impl, vis })
+ }
+}
+
+/// See crate docs for more info.
+#[proc_macro_attribute]
+#[proc_macro_error]
+#[allow(clippy::unneeded_field_pattern)]
+pub fn ext(
+ attr: proc_macro::TokenStream,
+ item: proc_macro::TokenStream,
+) -> proc_macro::TokenStream {
+ let item = parse_macro_input!(item as Input);
+ let config = parse_macro_input!(attr as Config);
+ go(item, config)
+}
+
+/// Like [`ext`](macro@crate::ext) but always add `Sized` as a supertrait.
+///
+/// This is provided as a convenience for generating extension traits that require `Self: Sized`
+/// such as:
+///
+/// ```
+/// use extend::ext_sized;
+///
+/// #[ext_sized]
+/// impl i32 {
+/// fn requires_sized(self) -> Option<Self> {
+/// Some(self)
+/// }
+/// }
+/// ```
+#[proc_macro_attribute]
+#[proc_macro_error]
+#[allow(clippy::unneeded_field_pattern)]
+pub fn ext_sized(
+ attr: proc_macro::TokenStream,
+ item: proc_macro::TokenStream,
+) -> proc_macro::TokenStream {
+ let item = parse_macro_input!(item as Input);
+ let mut config: Config = parse_macro_input!(attr as Config);
+
+ config.supertraits = if let Some(supertraits) = config.supertraits.take() {
+ Some(parse_quote!(#supertraits + Sized))
+ } else {
+ Some(parse_quote!(Sized))
+ };
+
+ go(item, config)
+}
+
+fn go(item: Input, mut config: Config) -> proc_macro::TokenStream {
+ if let Some(vis) = item.vis {
+ if config.visibility != Visibility::Inherited {
+ abort!(
+ config.visibility.span(),
+ "Cannot set visibility on `#[ext]` and `impl` block"
+ );
+ }
+
+ config.visibility = vis;
+ }
+
+ let ItemImpl {
+ attrs,
+ unsafety,
+ generics,
+ trait_,
+ self_ty,
+ items,
+ // What is defaultness?
+ defaultness: _,
+ impl_token: _,
+ brace_token: _,
+ } = item.item_impl;
+
+ if let Some((_, path, _)) = trait_ {
+ abort!(path.span(), "Trait impls cannot be used for #[ext]");
+ }
+
+ let self_ty = parse_self_ty(&self_ty);
+
+ let ext_trait_name = config
+ .ext_trait_name
+ .unwrap_or_else(|| ext_trait_name(&self_ty));
+
+ let MethodsAndConsts {
+ trait_methods,
+ trait_consts,
+ } = extract_allowed_items(&items);
+
+ let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
+
+ let visibility = &config.visibility;
+
+ let mut all_supertraits = Vec::<TypeParamBound>::new();
+
+ if let Some(supertraits_from_config) = config.supertraits {
+ all_supertraits.extend(supertraits_from_config);
+ }
+
+ let supertraits_quoted = if all_supertraits.is_empty() {
+ quote! {}
+ } else {
+ let supertraits_quoted = punctuated_from_iter::<_, _, Add>(all_supertraits);
+ quote! { : #supertraits_quoted }
+ };
+
+ let code = (quote! {
+ #[allow(non_camel_case_types)]
+ #(#attrs)*
+ #visibility
+ #unsafety
+ trait #ext_trait_name #impl_generics #supertraits_quoted #where_clause {
+ #(
+ #trait_consts
+ )*
+
+ #(
+ #[allow(
+ patterns_in_fns_without_body,
+ clippy::inline_fn_without_body,
+ unused_attributes
+ )]
+ #trait_methods
+ )*
+ }
+
+ #(#attrs)*
+ impl #impl_generics #ext_trait_name #ty_generics for #self_ty #where_clause {
+ #(#items)*
+ }
+ })
+ .into();
+
+ code
+}
+
+#[derive(Debug, Clone)]
+enum ExtType<'a> {
+ Array(&'a TypeArray),
+ Group(&'a TypeGroup),
+ Never(&'a TypeNever),
+ Paren(&'a TypeParen),
+ Path(&'a TypePath),
+ Ptr(&'a TypePtr),
+ Reference(&'a TypeReference),
+ Slice(&'a TypeSlice),
+ Tuple(&'a TypeTuple),
+ BareFn(&'a TypeBareFn),
+ TraitObject(&'a TypeTraitObject),
+}
+
+#[allow(clippy::wildcard_in_or_patterns)]
+fn parse_self_ty(self_ty: &Type) -> ExtType {
+ match self_ty {
+ Type::Array(inner) => ExtType::Array(inner),
+ Type::Group(inner) => ExtType::Group(inner),
+ Type::Never(inner) => ExtType::Never(inner),
+ Type::Paren(inner) => ExtType::Paren(inner),
+ Type::Path(inner) => ExtType::Path(inner),
+ Type::Ptr(inner) => ExtType::Ptr(inner),
+ Type::Reference(inner) => ExtType::Reference(inner),
+ Type::Slice(inner) => ExtType::Slice(inner),
+ Type::Tuple(inner) => ExtType::Tuple(inner),
+ Type::BareFn(inner) => ExtType::BareFn(inner),
+ Type::TraitObject(inner) => ExtType::TraitObject(inner),
+
+ Type::ImplTrait(_) | Type::Infer(_) | Type::Macro(_) | Type::Verbatim(_) | _ => abort!(
+ self_ty.span(),
+ "#[ext] is not supported for this kind of type"
+ ),
+ }
+}
+
+impl<'a> From<&'a Type> for ExtType<'a> {
+ fn from(inner: &'a Type) -> ExtType<'a> {
+ parse_self_ty(inner)
+ }
+}
+
+impl<'a> ToTokens for ExtType<'a> {
+ fn to_tokens(&self, tokens: &mut TokenStream) {
+ match self {
+ ExtType::Array(inner) => inner.to_tokens(tokens),
+ ExtType::Group(inner) => inner.to_tokens(tokens),
+ ExtType::Never(inner) => inner.to_tokens(tokens),
+ ExtType::Paren(inner) => inner.to_tokens(tokens),
+ ExtType::Path(inner) => inner.to_tokens(tokens),
+ ExtType::Ptr(inner) => inner.to_tokens(tokens),
+ ExtType::Reference(inner) => inner.to_tokens(tokens),
+ ExtType::Slice(inner) => inner.to_tokens(tokens),
+ ExtType::Tuple(inner) => inner.to_tokens(tokens),
+ ExtType::BareFn(inner) => inner.to_tokens(tokens),
+ ExtType::TraitObject(inner) => inner.to_tokens(tokens),
+ }
+ }
+}
+
+fn ext_trait_name(self_ty: &ExtType) -> Ident {
+ fn inner_self_ty(self_ty: &ExtType) -> Ident {
+ match self_ty {
+ ExtType::Path(inner) => find_and_combine_idents(inner),
+ ExtType::Reference(inner) => {
+ let name = inner_self_ty(&(&*inner.elem).into());
+ if inner.mutability.is_some() {
+ format_ident!("RefMut{}", name)
+ } else {
+ format_ident!("Ref{}", name)
+ }
+ }
+ ExtType::Array(inner) => {
+ let name = inner_self_ty(&(&*inner.elem).into());
+ format_ident!("ListOf{}", name)
+ }
+ ExtType::Group(inner) => {
+ let name = inner_self_ty(&(&*inner.elem).into());
+ format_ident!("Group{}", name)
+ }
+ ExtType::Paren(inner) => {
+ let name = inner_self_ty(&(&*inner.elem).into());
+ format_ident!("Paren{}", name)
+ }
+ ExtType::Ptr(inner) => {
+ let name = inner_self_ty(&(&*inner.elem).into());
+ format_ident!("PointerTo{}", name)
+ }
+ ExtType::Slice(inner) => {
+ let name = inner_self_ty(&(&*inner.elem).into());
+ format_ident!("SliceOf{}", name)
+ }
+ ExtType::Tuple(inner) => {
+ let mut name = format_ident!("TupleOf");
+ for elem in &inner.elems {
+ name = format_ident!("{}{}", name, inner_self_ty(&elem.into()));
+ }
+ name
+ }
+ ExtType::Never(_) => format_ident!("Never"),
+ ExtType::BareFn(inner) => {
+ let mut name = format_ident!("BareFn");
+ for input in inner.inputs.iter() {
+ name = format_ident!("{}{}", name, inner_self_ty(&(&input.ty).into()));
+ }
+ match &inner.output {
+ syn::ReturnType::Default => {
+ name = format_ident!("{}Unit", name);
+ }
+ syn::ReturnType::Type(_, ty) => {
+ name = format_ident!("{}{}", name, inner_self_ty(&(&**ty).into()));
+ }
+ }
+ name
+ }
+ ExtType::TraitObject(inner) => {
+ let mut name = format_ident!("TraitObject");
+ for bound in inner.bounds.iter() {
+ match bound {
+ TypeParamBound::Trait(bound) => {
+ for segment in bound.path.segments.iter() {
+ name = format_ident!("{}{}", name, segment.ident);
+ }
+ }
+ TypeParamBound::Lifetime(lifetime) => {
+ name = format_ident!("{}{}", name, lifetime.ident);
+ }
+ }
+ }
+ name
+ }
+ }
+ }
+
+ format_ident!("{}Ext", inner_self_ty(self_ty))
+}
+
+fn find_and_combine_idents(type_path: &TypePath) -> Ident {
+ use syn::visit::{self, Visit};
+
+ struct IdentVisitor<'a>(Vec<&'a Ident>);
+
+ impl<'a> Visit<'a> for IdentVisitor<'a> {
+ fn visit_ident(&mut self, i: &'a Ident) {
+ self.0.push(i);
+ }
+ }
+
+ let mut visitor = IdentVisitor(Vec::new());
+ visit::visit_type_path(&mut visitor, type_path);
+ let idents = visitor.0;
+
+ if idents.is_empty() {
+ abort!(type_path.span(), "Empty type path")
+ } else {
+ let start = &idents[0].span();
+ let combined_span = idents
+ .iter()
+ .map(|i| i.span())
+ .fold(*start, |a, b| a.join(b).unwrap_or(a));
+
+ let combined_name = idents.iter().map(|i| i.to_string()).collect::<String>();
+
+ Ident::new(&combined_name, combined_span)
+ }
+}
+
+#[derive(Debug, Default)]
+struct MethodsAndConsts {
+ trait_methods: Vec<TraitItemMethod>,
+ trait_consts: Vec<TraitItemConst>,
+}
+
+#[allow(clippy::wildcard_in_or_patterns)]
+fn extract_allowed_items(items: &[ImplItem]) -> MethodsAndConsts {
+ let mut acc = MethodsAndConsts::default();
+ for item in items {
+ match item {
+ ImplItem::Method(method) => acc.trait_methods.push(TraitItemMethod {
+ attrs: method.attrs.clone(),
+ sig: method.sig.clone(),
+ default: None,
+ semi_token: Some(Semi::default()),
+ }),
+ ImplItem::Const(const_) => acc.trait_consts.push(TraitItemConst {
+ attrs: const_.attrs.clone(),
+ const_token: Default::default(),
+ ident: const_.ident.clone(),
+ colon_token: Default::default(),
+ ty: const_.ty.clone(),
+ default: None,
+ semi_token: Default::default(),
+ }),
+ ImplItem::Type(_) => abort!(
+ item.span(),
+ "Associated types are not allowed in #[ext] impls"
+ ),
+ ImplItem::Macro(_) => abort!(item.span(), "Macros are not allowed in #[ext] impls"),
+ ImplItem::Verbatim(_) | _ => abort!(item.span(), "Not allowed in #[ext] impls"),
+ }
+ }
+ acc
+}
+
+#[derive(Debug)]
+struct Config {
+ ext_trait_name: Option<Ident>,
+ visibility: Visibility,
+ supertraits: Option<Punctuated<TypeParamBound, Add>>,
+}
+
+impl Parse for Config {
+ fn parse(input: ParseStream) -> parse::Result<Self> {
+ let mut config = Config::default();
+
+ if let Ok(visibility) = input.parse::<Visibility>() {
+ config.visibility = visibility;
+ }
+
+ input.parse::<Token![,]>().ok();
+
+ while !input.is_empty() {
+ let ident = input.parse::<Ident>()?;
+ input.parse::<Token![=]>()?;
+
+ match &*ident.to_string() {
+ "name" => {
+ config.ext_trait_name = Some(input.parse()?);
+ }
+ "supertraits" => {
+ config.supertraits =
+ Some(Punctuated::<TypeParamBound, Add>::parse_terminated(input)?);
+ }
+ _ => abort!(ident.span(), "Unknown configuration name"),
+ }
+
+ input.parse::<Token![,]>().ok();
+ }
+
+ Ok(config)
+ }
+}
+
+impl Default for Config {
+ fn default() -> Self {
+ Self {
+ ext_trait_name: None,
+ visibility: Visibility::Inherited,
+ supertraits: None,
+ }
+ }
+}
+
+fn punctuated_from_iter<I, T, P>(i: I) -> Punctuated<T, P>
+where
+ P: Default,
+ I: IntoIterator<Item = T>,
+{
+ let mut iter = i.into_iter().peekable();
+ let mut acc = Punctuated::default();
+
+ while let Some(item) = iter.next() {
+ acc.push_value(item);
+
+ if iter.peek().is_some() {
+ acc.push_punct(P::default());
+ }
+ }
+
+ acc
+}
+
+#[cfg(test)]
+mod test {
+ #[allow(unused_imports)]
+ use super::*;
+
+ #[test]
+ fn test_ui() {
+ let t = trybuild::TestCases::new();
+ t.pass("tests/compile_pass/*.rs");
+ t.compile_fail("tests/compile_fail/*.rs");
+ }
+}