diff options
Diffstat (limited to 'third_party/rust/naga/src/valid/mod.rs')
-rw-r--r-- | third_party/rust/naga/src/valid/mod.rs | 467 |
1 files changed, 467 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs new file mode 100644 index 0000000000..6b3a2e1456 --- /dev/null +++ b/third_party/rust/naga/src/valid/mod.rs @@ -0,0 +1,467 @@ +/*! +Shader validator. +*/ + +mod analyzer; +mod compose; +mod expression; +mod function; +mod handles; +mod interface; +mod r#type; + +#[cfg(feature = "validate")] +use crate::arena::{Arena, UniqueArena}; + +use crate::{ + arena::Handle, + proc::{LayoutError, Layouter}, + FastHashSet, +}; +use bit_set::BitSet; +use std::ops; + +//TODO: analyze the model at the same time as we validate it, +// merge the corresponding matches over expressions and statements. + +use crate::span::{AddSpan as _, WithSpan}; +pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; +pub use compose::ComposeError; +pub use expression::ExpressionError; +pub use function::{CallError, FunctionError, LocalVariableError}; +pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; +pub use r#type::{Disalignment, TypeError, TypeFlags}; + +use self::handles::InvalidHandleError; + +bitflags::bitflags! { + /// Validation flags. + /// + /// If you are working with trusted shaders, then you may be able + /// to save some time by skipping validation. + /// + /// If you do not perform full validation, invalid shaders may + /// cause Naga to panic. If you do perform full validation and + /// [`Validator::validate`] returns `Ok`, then Naga promises that + /// code generation will either succeed or return an error; it + /// should never panic. + /// + /// The default value for `ValidationFlags` is + /// `ValidationFlags::all()`. If Naga's `"validate"` feature is + /// enabled, this requests full validation; otherwise, this + /// requests no validation. (The `"validate"` feature is disabled + /// by default.) + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct ValidationFlags: u8 { + /// Expressions. + #[cfg(feature = "validate")] + const EXPRESSIONS = 0x1; + /// Statements and blocks of them. + #[cfg(feature = "validate")] + const BLOCKS = 0x2; + /// Uniformity of control flow for operations that require it. + #[cfg(feature = "validate")] + const CONTROL_FLOW_UNIFORMITY = 0x4; + /// Host-shareable structure layouts. + #[cfg(feature = "validate")] + const STRUCT_LAYOUTS = 0x8; + /// Constants. + #[cfg(feature = "validate")] + const CONSTANTS = 0x10; + /// Group, binding, and location attributes. + #[cfg(feature = "validate")] + const BINDINGS = 0x20; + } +} + +impl Default for ValidationFlags { + fn default() -> Self { + Self::all() + } +} + +bitflags::bitflags! { + /// Allowed IR capabilities. + #[must_use] + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct Capabilities: u16 { + /// Support for [`AddressSpace:PushConstant`]. + const PUSH_CONSTANT = 0x1; + /// Float values with width = 8. + const FLOAT64 = 0x2; + /// Support for [`Builtin:PrimitiveIndex`]. + const PRIMITIVE_INDEX = 0x4; + /// Support for non-uniform indexing of sampled textures and storage buffer arrays. + const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8; + /// Support for non-uniform indexing of uniform buffers and storage texture arrays. + const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10; + /// Support for non-uniform indexing of samplers. + const SAMPLER_NON_UNIFORM_INDEXING = 0x20; + /// Support for [`Builtin::ClipDistance`]. + const CLIP_DISTANCE = 0x40; + /// Support for [`Builtin::CullDistance`]. + const CULL_DISTANCE = 0x80; + /// Support for 16-bit normalized storage texture formats. + const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100; + /// Support for [`BuiltIn::ViewIndex`]. + const MULTIVIEW = 0x200; + /// Support for `early_depth_test`. + const EARLY_DEPTH_TEST = 0x400; + /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`]. + const MULTISAMPLED_SHADING = 0x800; + /// Support for ray queries and acceleration structures. + const RAY_QUERY = 0x1000; + } +} + +impl Default for Capabilities { + fn default() -> Self { + Self::MULTISAMPLED_SHADING + } +} + +bitflags::bitflags! { + /// Validation flags. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct ShaderStages: u8 { + const VERTEX = 0x1; + const FRAGMENT = 0x2; + const COMPUTE = 0x4; + } +} + +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct ModuleInfo { + type_flags: Vec<TypeFlags>, + functions: Vec<FunctionInfo>, + entry_points: Vec<FunctionInfo>, +} + +impl ops::Index<Handle<crate::Type>> for ModuleInfo { + type Output = TypeFlags; + fn index(&self, handle: Handle<crate::Type>) -> &Self::Output { + &self.type_flags[handle.index()] + } +} + +impl ops::Index<Handle<crate::Function>> for ModuleInfo { + type Output = FunctionInfo; + fn index(&self, handle: Handle<crate::Function>) -> &Self::Output { + &self.functions[handle.index()] + } +} + +#[derive(Debug)] +pub struct Validator { + flags: ValidationFlags, + capabilities: Capabilities, + types: Vec<r#type::TypeInfo>, + layouter: Layouter, + location_mask: BitSet, + bind_group_masks: Vec<BitSet>, + #[allow(dead_code)] + switch_values: FastHashSet<crate::SwitchValue>, + valid_expression_list: Vec<Handle<crate::Expression>>, + valid_expression_set: BitSet, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ConstantError { + #[error("The type doesn't match the constant")] + InvalidType, + #[error("The component handle {0:?} can not be resolved")] + UnresolvedComponent(Handle<crate::Constant>), + #[error("The array size handle {0:?} can not be resolved")] + UnresolvedSize(Handle<crate::Constant>), + #[error(transparent)] + Compose(#[from] ComposeError), +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ValidationError { + #[error(transparent)] + InvalidHandle(#[from] InvalidHandleError), + #[error(transparent)] + Layouter(#[from] LayoutError), + #[error("Type {handle:?} '{name}' is invalid")] + Type { + handle: Handle<crate::Type>, + name: String, + source: TypeError, + }, + #[error("Constant {handle:?} '{name}' is invalid")] + Constant { + handle: Handle<crate::Constant>, + name: String, + source: ConstantError, + }, + #[error("Global variable {handle:?} '{name}' is invalid")] + GlobalVariable { + handle: Handle<crate::GlobalVariable>, + name: String, + source: GlobalVariableError, + }, + #[error("Function {handle:?} '{name}' is invalid")] + Function { + handle: Handle<crate::Function>, + name: String, + source: FunctionError, + }, + #[error("Entry point {name} at {stage:?} is invalid")] + EntryPoint { + stage: crate::ShaderStage, + name: String, + source: EntryPointError, + }, + #[error("Module is corrupted")] + Corrupted, +} + +impl crate::TypeInner { + #[cfg(feature = "validate")] + const fn is_sized(&self) -> bool { + match *self { + Self::Scalar { .. } + | Self::Vector { .. } + | Self::Matrix { .. } + | Self::Array { + size: crate::ArraySize::Constant(_), + .. + } + | Self::Atomic { .. } + | Self::Pointer { .. } + | Self::ValuePointer { .. } + | Self::Struct { .. } => true, + Self::Array { .. } + | Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery + | Self::BindingArray { .. } => false, + } + } + + /// Return the `ImageDimension` for which `self` is an appropriate coordinate. + #[cfg(feature = "validate")] + const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> { + match *self { + Self::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + } => Some(crate::ImageDimension::D1), + Self::Vector { + size: crate::VectorSize::Bi, + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + } => Some(crate::ImageDimension::D2), + Self::Vector { + size: crate::VectorSize::Tri, + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. + } => Some(crate::ImageDimension::D3), + _ => None, + } + } +} + +impl Validator { + /// Construct a new validator instance. + pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { + Validator { + flags, + capabilities, + types: Vec::new(), + layouter: Layouter::default(), + location_mask: BitSet::new(), + bind_group_masks: Vec::new(), + switch_values: FastHashSet::default(), + valid_expression_list: Vec::new(), + valid_expression_set: BitSet::new(), + } + } + + /// Reset the validator internals + pub fn reset(&mut self) { + self.types.clear(); + self.layouter.clear(); + self.location_mask.clear(); + self.bind_group_masks.clear(); + self.switch_values.clear(); + self.valid_expression_list.clear(); + self.valid_expression_set.clear(); + } + + #[cfg(feature = "validate")] + fn validate_constant( + &self, + handle: Handle<crate::Constant>, + constants: &Arena<crate::Constant>, + types: &UniqueArena<crate::Type>, + ) -> Result<(), ConstantError> { + let con = &constants[handle]; + match con.inner { + crate::ConstantInner::Scalar { width, ref value } => { + if self.check_width(value.scalar_kind(), width).is_err() { + return Err(ConstantError::InvalidType); + } + } + crate::ConstantInner::Composite { ty, ref components } => { + match types[ty].inner { + crate::TypeInner::Array { + size: crate::ArraySize::Constant(size_handle), + .. + } if handle <= size_handle => { + return Err(ConstantError::UnresolvedSize(size_handle)); + } + _ => {} + } + if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { + return Err(ConstantError::UnresolvedComponent(comp)); + } + compose::validate_compose( + ty, + constants, + types, + components + .iter() + .map(|&component| constants[component].inner.resolve_type()), + )?; + } + } + Ok(()) + } + + /// Check the given module to be valid. + pub fn validate( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.reset(); + self.reset_types(module.types.len()); + + #[cfg(feature = "validate")] + Self::validate_module_handles(module).map_err(|e| e.with_span())?; + + self.layouter + .update(&module.types, &module.constants) + .map_err(|e| { + let handle = e.ty; + ValidationError::from(e).with_span_handle(handle, &module.types) + })?; + + #[cfg(feature = "validate")] + if self.flags.contains(ValidationFlags::CONSTANTS) { + for (handle, constant) in module.constants.iter() { + self.validate_constant(handle, &module.constants, &module.types) + .map_err(|source| { + ValidationError::Constant { + handle, + name: constant.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.constants) + })? + } + } + + let mut mod_info = ModuleInfo { + type_flags: Vec::with_capacity(module.types.len()), + functions: Vec::with_capacity(module.functions.len()), + entry_points: Vec::with_capacity(module.entry_points.len()), + }; + + for (handle, ty) in module.types.iter() { + let ty_info = self + .validate_type(handle, &module.types, &module.constants) + .map_err(|source| { + ValidationError::Type { + handle, + name: ty.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.types) + })?; + mod_info.type_flags.push(ty_info.flags); + self.types[handle.index()] = ty_info; + } + + #[cfg(feature = "validate")] + for (var_handle, var) in module.global_variables.iter() { + self.validate_global_var(var, &module.types) + .map_err(|source| { + ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(var_handle, &module.global_variables) + })?; + } + + for (handle, fun) in module.functions.iter() { + match self.validate_function(fun, module, &mod_info, false) { + Ok(info) => mod_info.functions.push(info), + Err(error) => { + return Err(error.and_then(|source| { + ValidationError::Function { + handle, + name: fun.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.functions) + })) + } + } + } + + let mut ep_map = FastHashSet::default(); + for ep in module.entry_points.iter() { + if !ep_map.insert((ep.stage, &ep.name)) { + return Err(ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + source: EntryPointError::Conflict, + } + .with_span()); // TODO: keep some EP span information? + } + + match self.validate_entry_point(ep, module, &mod_info) { + Ok(info) => mod_info.entry_points.push(info), + Err(error) => { + return Err(error.and_then(|source| { + ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + source, + } + .with_span() + })); + } + } + } + + Ok(mod_info) + } +} + +#[cfg(feature = "validate")] +fn validate_atomic_compare_exchange_struct( + types: &UniqueArena<crate::Type>, + members: &[crate::StructMember], + scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool, +) -> bool { + members.len() == 2 + && members[0].name.as_deref() == Some("old_value") + && scalar_predicate(&types[members[0].ty].inner) + && members[1].name.as_deref() == Some("exchanged") + && types[members[1].ty].inner + == crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + } +} |