/*! Shader validator. */ mod analyzer; mod compose; mod expression; mod function; mod handles; mod interface; mod r#type; #[cfg(feature = "validate")] use crate::arena::{Arena, UniqueArena}; use crate::{ arena::Handle, proc::{LayoutError, Layouter}, FastHashSet, }; use bit_set::BitSet; use std::ops; //TODO: analyze the model at the same time as we validate it, // merge the corresponding matches over expressions and statements. use crate::span::{AddSpan as _, WithSpan}; pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; pub use compose::ComposeError; pub use expression::ExpressionError; pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; pub use r#type::{Disalignment, TypeError, TypeFlags}; use self::handles::InvalidHandleError; bitflags::bitflags! { /// Validation flags. /// /// If you are working with trusted shaders, then you may be able /// to save some time by skipping validation. /// /// If you do not perform full validation, invalid shaders may /// cause Naga to panic. If you do perform full validation and /// [`Validator::validate`] returns `Ok`, then Naga promises that /// code generation will either succeed or return an error; it /// should never panic. /// /// The default value for `ValidationFlags` is /// `ValidationFlags::all()`. If Naga's `"validate"` feature is /// enabled, this requests full validation; otherwise, this /// requests no validation. (The `"validate"` feature is disabled /// by default.) #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ValidationFlags: u8 { /// Expressions. #[cfg(feature = "validate")] const EXPRESSIONS = 0x1; /// Statements and blocks of them. #[cfg(feature = "validate")] const BLOCKS = 0x2; /// Uniformity of control flow for operations that require it. #[cfg(feature = "validate")] const CONTROL_FLOW_UNIFORMITY = 0x4; /// Host-shareable structure layouts. #[cfg(feature = "validate")] const STRUCT_LAYOUTS = 0x8; /// Constants. #[cfg(feature = "validate")] const CONSTANTS = 0x10; /// Group, binding, and location attributes. #[cfg(feature = "validate")] const BINDINGS = 0x20; } } impl Default for ValidationFlags { fn default() -> Self { Self::all() } } bitflags::bitflags! { /// Allowed IR capabilities. #[must_use] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct Capabilities: u16 { /// Support for [`AddressSpace:PushConstant`]. const PUSH_CONSTANT = 0x1; /// Float values with width = 8. const FLOAT64 = 0x2; /// Support for [`Builtin:PrimitiveIndex`]. const PRIMITIVE_INDEX = 0x4; /// Support for non-uniform indexing of sampled textures and storage buffer arrays. const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8; /// Support for non-uniform indexing of uniform buffers and storage texture arrays. const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10; /// Support for non-uniform indexing of samplers. const SAMPLER_NON_UNIFORM_INDEXING = 0x20; /// Support for [`Builtin::ClipDistance`]. const CLIP_DISTANCE = 0x40; /// Support for [`Builtin::CullDistance`]. const CULL_DISTANCE = 0x80; /// Support for 16-bit normalized storage texture formats. const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100; /// Support for [`BuiltIn::ViewIndex`]. const MULTIVIEW = 0x200; /// Support for `early_depth_test`. const EARLY_DEPTH_TEST = 0x400; /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`]. const MULTISAMPLED_SHADING = 0x800; /// Support for ray queries and acceleration structures. const RAY_QUERY = 0x1000; } } impl Default for Capabilities { fn default() -> Self { Self::MULTISAMPLED_SHADING } } bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ShaderStages: u8 { const VERTEX = 0x1; const FRAGMENT = 0x2; const COMPUTE = 0x4; } } #[derive(Debug)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { type_flags: Vec, functions: Vec, entry_points: Vec, } impl ops::Index> for ModuleInfo { type Output = TypeFlags; fn index(&self, handle: Handle) -> &Self::Output { &self.type_flags[handle.index()] } } impl ops::Index> for ModuleInfo { type Output = FunctionInfo; fn index(&self, handle: Handle) -> &Self::Output { &self.functions[handle.index()] } } #[derive(Debug)] pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, types: Vec, layouter: Layouter, location_mask: BitSet, bind_group_masks: Vec, #[allow(dead_code)] switch_values: FastHashSet, valid_expression_list: Vec>, valid_expression_set: BitSet, } #[derive(Clone, Debug, thiserror::Error)] pub enum ConstantError { #[error("The type doesn't match the constant")] InvalidType, #[error("The component handle {0:?} can not be resolved")] UnresolvedComponent(Handle), #[error("The array size handle {0:?} can not be resolved")] UnresolvedSize(Handle), #[error(transparent)] Compose(#[from] ComposeError), } #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { #[error(transparent)] InvalidHandle(#[from] InvalidHandleError), #[error(transparent)] Layouter(#[from] LayoutError), #[error("Type {handle:?} '{name}' is invalid")] Type { handle: Handle, name: String, source: TypeError, }, #[error("Constant {handle:?} '{name}' is invalid")] Constant { handle: Handle, name: String, source: ConstantError, }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle, name: String, source: GlobalVariableError, }, #[error("Function {handle:?} '{name}' is invalid")] Function { handle: Handle, name: String, source: FunctionError, }, #[error("Entry point {name} at {stage:?} is invalid")] EntryPoint { stage: crate::ShaderStage, name: String, source: EntryPointError, }, #[error("Module is corrupted")] Corrupted, } impl crate::TypeInner { #[cfg(feature = "validate")] const fn is_sized(&self) -> bool { match *self { Self::Scalar { .. } | Self::Vector { .. } | Self::Matrix { .. } | Self::Array { size: crate::ArraySize::Constant(_), .. } | Self::Atomic { .. } | Self::Pointer { .. } | Self::ValuePointer { .. } | Self::Struct { .. } => true, Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } | Self::AccelerationStructure | Self::RayQuery | Self::BindingArray { .. } => false, } } /// Return the `ImageDimension` for which `self` is an appropriate coordinate. #[cfg(feature = "validate")] const fn image_storage_coordinates(&self) -> Option { match *self { Self::Scalar { kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, .. } => Some(crate::ImageDimension::D1), Self::Vector { size: crate::VectorSize::Bi, kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, .. } => Some(crate::ImageDimension::D2), Self::Vector { size: crate::VectorSize::Tri, kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, .. } => Some(crate::ImageDimension::D3), _ => None, } } } impl Validator { /// Construct a new validator instance. pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self { Validator { flags, capabilities, types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), bind_group_masks: Vec::new(), switch_values: FastHashSet::default(), valid_expression_list: Vec::new(), valid_expression_set: BitSet::new(), } } /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); self.layouter.clear(); self.location_mask.clear(); self.bind_group_masks.clear(); self.switch_values.clear(); self.valid_expression_list.clear(); self.valid_expression_set.clear(); } #[cfg(feature = "validate")] fn validate_constant( &self, handle: Handle, constants: &Arena, types: &UniqueArena, ) -> Result<(), ConstantError> { let con = &constants[handle]; match con.inner { crate::ConstantInner::Scalar { width, ref value } => { if self.check_width(value.scalar_kind(), width).is_err() { return Err(ConstantError::InvalidType); } } crate::ConstantInner::Composite { ty, ref components } => { match types[ty].inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. } if handle <= size_handle => { return Err(ConstantError::UnresolvedSize(size_handle)); } _ => {} } if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { return Err(ConstantError::UnresolvedComponent(comp)); } compose::validate_compose( ty, constants, types, components .iter() .map(|&component| constants[component].inner.resolve_type()), )?; } } Ok(()) } /// Check the given module to be valid. pub fn validate( &mut self, module: &crate::Module, ) -> Result> { self.reset(); self.reset_types(module.types.len()); #[cfg(feature = "validate")] Self::validate_module_handles(module).map_err(|e| e.with_span())?; self.layouter .update(&module.types, &module.constants) .map_err(|e| { let handle = e.ty; ValidationError::from(e).with_span_handle(handle, &module.types) })?; #[cfg(feature = "validate")] if self.flags.contains(ValidationFlags::CONSTANTS) { for (handle, constant) in module.constants.iter() { self.validate_constant(handle, &module.constants, &module.types) .map_err(|source| { ValidationError::Constant { handle, name: constant.name.clone().unwrap_or_default(), source, } .with_span_handle(handle, &module.constants) })? } } let mut mod_info = ModuleInfo { type_flags: Vec::with_capacity(module.types.len()), functions: Vec::with_capacity(module.functions.len()), entry_points: Vec::with_capacity(module.entry_points.len()), }; for (handle, ty) in module.types.iter() { let ty_info = self .validate_type(handle, &module.types, &module.constants) .map_err(|source| { ValidationError::Type { handle, name: ty.name.clone().unwrap_or_default(), source, } .with_span_handle(handle, &module.types) })?; mod_info.type_flags.push(ty_info.flags); self.types[handle.index()] = ty_info; } #[cfg(feature = "validate")] for (var_handle, var) in module.global_variables.iter() { self.validate_global_var(var, &module.types) .map_err(|source| { ValidationError::GlobalVariable { handle: var_handle, name: var.name.clone().unwrap_or_default(), source, } .with_span_handle(var_handle, &module.global_variables) })?; } for (handle, fun) in module.functions.iter() { match self.validate_function(fun, module, &mod_info, false) { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { ValidationError::Function { handle, name: fun.name.clone().unwrap_or_default(), source, } .with_span_handle(handle, &module.functions) })) } } } let mut ep_map = FastHashSet::default(); for ep in module.entry_points.iter() { if !ep_map.insert((ep.stage, &ep.name)) { return Err(ValidationError::EntryPoint { stage: ep.stage, name: ep.name.clone(), source: EntryPointError::Conflict, } .with_span()); // TODO: keep some EP span information? } match self.validate_entry_point(ep, module, &mod_info) { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { ValidationError::EntryPoint { stage: ep.stage, name: ep.name.clone(), source, } .with_span() })); } } } Ok(mod_info) } } #[cfg(feature = "validate")] fn validate_atomic_compare_exchange_struct( types: &UniqueArena, members: &[crate::StructMember], scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool, ) -> bool { members.len() == 2 && members[0].name.as_deref() == Some("old_value") && scalar_predicate(&types[members[0].ty].inner) && members[1].name.as_deref() == Some("exchanged") && types[members[1].ty].inner == crate::TypeInner::Scalar { kind: crate::ScalarKind::Bool, width: crate::BOOL_WIDTH, } }