diff options
Diffstat (limited to 'third_party/rust/thiserror-impl/src/valid.rs')
-rw-r--r-- | third_party/rust/thiserror-impl/src/valid.rs | 237 |
1 files changed, 237 insertions, 0 deletions
diff --git a/third_party/rust/thiserror-impl/src/valid.rs b/third_party/rust/thiserror-impl/src/valid.rs new file mode 100644 index 0000000000..cf5b8592c2 --- /dev/null +++ b/third_party/rust/thiserror-impl/src/valid.rs @@ -0,0 +1,237 @@ +use crate::ast::{Enum, Field, Input, Struct, Variant}; +use crate::attr::Attrs; +use quote::ToTokens; +use std::collections::BTreeSet as Set; +use syn::{Error, GenericArgument, Member, PathArguments, Result, Type}; + +impl Input<'_> { + pub(crate) fn validate(&self) -> Result<()> { + match self { + Input::Struct(input) => input.validate(), + Input::Enum(input) => input.validate(), + } + } +} + +impl Struct<'_> { + fn validate(&self) -> Result<()> { + check_non_field_attrs(&self.attrs)?; + if let Some(transparent) = self.attrs.transparent { + if self.fields.len() != 1 { + return Err(Error::new_spanned( + transparent.original, + "#[error(transparent)] requires exactly one field", + )); + } + if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) { + return Err(Error::new_spanned( + source, + "transparent error struct can't contain #[source]", + )); + } + } + check_field_attrs(&self.fields)?; + for field in &self.fields { + field.validate()?; + } + Ok(()) + } +} + +impl Enum<'_> { + fn validate(&self) -> Result<()> { + check_non_field_attrs(&self.attrs)?; + let has_display = self.has_display(); + for variant in &self.variants { + variant.validate()?; + if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none() + { + return Err(Error::new_spanned( + variant.original, + "missing #[error(\"...\")] display attribute", + )); + } + } + let mut from_types = Set::new(); + for variant in &self.variants { + if let Some(from_field) = variant.from_field() { + let repr = from_field.ty.to_token_stream().to_string(); + if !from_types.insert(repr) { + return Err(Error::new_spanned( + from_field.original, + "cannot derive From because another variant has the same source type", + )); + } + } + } + Ok(()) + } +} + +impl Variant<'_> { + fn validate(&self) -> Result<()> { + check_non_field_attrs(&self.attrs)?; + if self.attrs.transparent.is_some() { + if self.fields.len() != 1 { + return Err(Error::new_spanned( + self.original, + "#[error(transparent)] requires exactly one field", + )); + } + if let Some(source) = self.fields.iter().find_map(|f| f.attrs.source) { + return Err(Error::new_spanned( + source, + "transparent variant can't contain #[source]", + )); + } + } + check_field_attrs(&self.fields)?; + for field in &self.fields { + field.validate()?; + } + Ok(()) + } +} + +impl Field<'_> { + fn validate(&self) -> Result<()> { + if let Some(display) = &self.attrs.display { + return Err(Error::new_spanned( + display.original, + "not expected here; the #[error(...)] attribute belongs on top of a struct or an enum variant", + )); + } + Ok(()) + } +} + +fn check_non_field_attrs(attrs: &Attrs) -> Result<()> { + if let Some(from) = &attrs.from { + return Err(Error::new_spanned( + from, + "not expected here; the #[from] attribute belongs on a specific field", + )); + } + if let Some(source) = &attrs.source { + return Err(Error::new_spanned( + source, + "not expected here; the #[source] attribute belongs on a specific field", + )); + } + if let Some(backtrace) = &attrs.backtrace { + return Err(Error::new_spanned( + backtrace, + "not expected here; the #[backtrace] attribute belongs on a specific field", + )); + } + if let Some(display) = &attrs.display { + if attrs.transparent.is_some() { + return Err(Error::new_spanned( + display.original, + "cannot have both #[error(transparent)] and a display attribute", + )); + } + } + Ok(()) +} + +fn check_field_attrs(fields: &[Field]) -> Result<()> { + let mut from_field = None; + let mut source_field = None; + let mut backtrace_field = None; + let mut has_backtrace = false; + for field in fields { + if let Some(from) = field.attrs.from { + if from_field.is_some() { + return Err(Error::new_spanned(from, "duplicate #[from] attribute")); + } + from_field = Some(field); + } + if let Some(source) = field.attrs.source { + if source_field.is_some() { + return Err(Error::new_spanned(source, "duplicate #[source] attribute")); + } + source_field = Some(field); + } + if let Some(backtrace) = field.attrs.backtrace { + if backtrace_field.is_some() { + return Err(Error::new_spanned( + backtrace, + "duplicate #[backtrace] attribute", + )); + } + backtrace_field = Some(field); + has_backtrace = true; + } + if let Some(transparent) = field.attrs.transparent { + return Err(Error::new_spanned( + transparent.original, + "#[error(transparent)] needs to go outside the enum or struct, not on an individual field", + )); + } + has_backtrace |= field.is_backtrace(); + } + if let (Some(from_field), Some(source_field)) = (from_field, source_field) { + if !same_member(from_field, source_field) { + return Err(Error::new_spanned( + from_field.attrs.from, + "#[from] is only supported on the source field, not any other field", + )); + } + } + if let Some(from_field) = from_field { + let max_expected_fields = match backtrace_field { + Some(backtrace_field) => 1 + !same_member(from_field, backtrace_field) as usize, + None => 1 + has_backtrace as usize, + }; + if fields.len() > max_expected_fields { + return Err(Error::new_spanned( + from_field.attrs.from, + "deriving From requires no fields other than source and backtrace", + )); + } + } + if let Some(source_field) = source_field.or(from_field) { + if contains_non_static_lifetime(source_field.ty) { + return Err(Error::new_spanned( + &source_field.original.ty, + "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static", + )); + } + } + Ok(()) +} + +fn same_member(one: &Field, two: &Field) -> bool { + match (&one.member, &two.member) { + (Member::Named(one), Member::Named(two)) => one == two, + (Member::Unnamed(one), Member::Unnamed(two)) => one.index == two.index, + _ => unreachable!(), + } +} + +fn contains_non_static_lifetime(ty: &Type) -> bool { + match ty { + Type::Path(ty) => { + let bracketed = match &ty.path.segments.last().unwrap().arguments { + PathArguments::AngleBracketed(bracketed) => bracketed, + _ => return false, + }; + for arg in &bracketed.args { + match arg { + GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true, + GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => { + return true + } + _ => {} + } + } + false + } + Type::Reference(ty) => ty + .lifetime + .as_ref() + .map_or(false, |lifetime| lifetime.ident != "static"), + _ => false, // maybe implement later if there are common other cases + } +} |