use super::{
    compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ModuleInfo,
    ShaderStages, TypeFlags,
};
use crate::arena::UniqueArena;

use crate::{
    arena::Handle,
    proc::{IndexableLengthError, ResolveError},
};

/// Reasons an expression can be rejected by `Validator::validate_expression`.
///
/// The user-facing text of every variant comes from its `#[error]` attribute;
/// variants marked `transparent` forward to the wrapped error.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ExpressionError {
    #[error("Doesn't exist")]
    DoesntExist,
    #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
    NotInScope,
    #[error("Base type {0:?} is not compatible with this expression")]
    InvalidBaseType(Handle),
    #[error("Accessing with index {0:?} can't be done")]
    InvalidIndexType(Handle),
    #[error("Accessing {0:?} via a negative index is invalid")]
    NegativeIndex(Handle),
    #[error("Accessing index {1} is out of {0:?} bounds")]
    IndexOutOfBounds(Handle, u32),
    #[error("The expression {0:?} may only be indexed by a constant")]
    IndexMustBeConstant(Handle),
    #[error("Function argument {0:?} doesn't exist")]
    FunctionArgumentDoesntExist(u32),
    #[error("Loading of {0:?} can't be done")]
    InvalidPointerType(Handle),
    #[error("Array length of {0:?} can't be done")]
    InvalidArrayType(Handle),
    #[error("Get intersection of {0:?} can't be done")]
    InvalidRayQueryType(Handle),
    #[error("Splatting {0:?} can't be done")]
    InvalidSplatType(Handle),
    #[error("Swizzling {0:?} can't be done")]
    InvalidVectorType(Handle),
    #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
    InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
    #[error(transparent)]
    Compose(#[from] super::ComposeError),
    #[error(transparent)]
    IndexableLength(#[from] IndexableLengthError),
    #[error("Operation {0:?} can't work with {1:?}")]
    InvalidUnaryOperandType(crate::UnaryOperator, Handle),
    #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
    InvalidBinaryOperandTypes(crate::BinaryOperator, Handle, Handle),
    #[error("Selecting is not possible")]
    InvalidSelectTypes,
    #[error("Relational argument {0:?} is not a boolean vector")]
    InvalidBooleanVector(Handle),
    #[error("Relational argument {0:?} is not a float")]
    InvalidFloatArgument(Handle),
    #[error("Type resolution failed")]
    Type(#[from] ResolveError),
    #[error("Not a global variable")]
    ExpectedGlobalVariable,
    #[error("Not a global variable or a function argument")]
    ExpectedGlobalOrArgument,
    #[error("Needs to be a binding array instead of {0:?}")]
    ExpectedBindingArrayType(Handle),
    #[error("Needs to be an image instead of {0:?}")]
    ExpectedImageType(Handle),
    // This message used to be a copy of the image one, which made a sampler
    // mismatch indistinguishable from an image mismatch in diagnostics.
    #[error("Needs to be a sampler instead of {0:?}")]
    ExpectedSamplerType(Handle),
    #[error("Unable to operate on image class {0:?}")]
    InvalidImageClass(crate::ImageClass),
    #[error("Derivatives can only be taken from scalar and vector floats")]
    InvalidDerivative,
    #[error("Image array index parameter is misplaced")]
    InvalidImageArrayIndex,
    #[error("Inappropriate sample or level-of-detail index for texel access")]
    InvalidImageOtherIndex,
    #[error("Image array index type of {0:?} is not an integer scalar")]
    InvalidImageArrayIndexType(Handle),
    #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
    InvalidImageOtherIndexType(Handle),
    #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
    InvalidImageCoordinateType(crate::ImageDimension, Handle),
    #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
    ComparisonSamplingMismatch {
        image: crate::ImageClass,
        sampler: bool,
        has_ref: bool,
    },
    #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
    InvalidSampleOffset(crate::ImageDimension, Handle),
    #[error("Depth reference {0:?} is not a scalar float")]
    InvalidDepthReference(Handle),
    #[error("Depth sample level can only be Auto or Zero")]
    InvalidDepthSampleLevel,
    #[error("Gather level can only be Zero")]
    InvalidGatherLevel,
    #[error("Gather component {0:?} doesn't exist in the image")]
    InvalidGatherComponent(crate::SwizzleComponent),
    #[error("Gather can't be done for image dimension {0:?}")]
    InvalidGatherDimension(crate::ImageDimension),
    #[error("Sample level (exact) type {0:?} is not a scalar float")]
    InvalidSampleLevelExactType(Handle),
    #[error("Sample level (bias) type {0:?} is not a scalar float")]
    InvalidSampleLevelBiasType(Handle),
    #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
    InvalidSampleLevelGradientType(crate::ImageDimension, Handle),
    #[error("Unable to cast")]
    InvalidCastArgument,
    #[error("Invalid argument count for {0:?}")]
    WrongArgumentCount(crate::MathFunction),
    #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
    InvalidArgumentType(crate::MathFunction, u32, Handle),
    #[error("Atomic result type can't be {0:?}")]
    InvalidAtomicResultType(Handle),
    #[error(
        "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type."
    )]
    InvalidWorkGroupUniformLoadResultType(Handle),
    #[error("Shader requires capability {0:?}")]
    MissingCapabilities(super::Capabilities),
    #[error(transparent)]
    Literal(#[from] LiteralError),
}

/// Reasons a constant expression can be rejected by
/// `Validator::validate_const_expression`.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstExpressionError {
    #[error("The expression is not a constant expression")]
    NonConst,
    #[error(transparent)]
    Compose(#[from] super::ComposeError),
    #[error("Splatting {0:?} can't be done")]
    InvalidSplatType(Handle),
    #[error("Type resolution failed")]
    Type(#[from] ResolveError),
    #[error(transparent)]
    Literal(#[from] LiteralError),
    #[error(transparent)]
    Width(#[from] super::r#type::WidthError),
}

/// Reasons a literal value can be rejected.
///
/// Converts into both [`ExpressionError`] and [`ConstExpressionError`] through
/// their `Literal` variants, so `?` can be used on literal checks in either
/// validation path.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LiteralError {
    #[error("Float literal is NaN")]
    NaN,
    #[error("Float literal is infinite")]
    Infinity,
    #[error(transparent)]
    Width(#[from] super::r#type::WidthError),
}

/// Looks up the resolved type of expressions that come before `root`.
///
/// Indexing with an expression handle yields that expression's `TypeInner`,
/// taken from `info` and resolved against `types`.
struct ExpressionTypeResolver<'a> {
    /// The expression currently being validated; only handles strictly below
    /// it may be looked up.
    root: Handle,
    types: &'a UniqueArena,
    info: &'a FunctionInfo,
}

impl<'a> std::ops::Index<Handle> for ExpressionTypeResolver<'a> {
    type Output = crate::TypeInner;

    /// Returns the type of `handle`.
    ///
    /// # Panics
    ///
    /// Panics when `handle` is not strictly less than `self.root`, i.e. when
    /// the requested expression has not been processed yet.
    #[allow(clippy::panic)]
    fn index(&self, handle: Handle) -> &Self::Output {
        if handle < self.root {
            self.info[handle].ty.inner_with(self.types)
        } else {
            // `Validator::validate_module_handles` should have caught this.
            // Report `handle`: it is the dependency that is not available yet,
            // whereas `self.root` is the expression that asked for it.
            panic!(
                "Depends on {:?}, which has not been processed yet",
                handle
            )
        }
    }
}

impl super::Validator {
    /// Checks that the constant expression `handle` is well-formed.
    ///
    /// Accepted forms:
    /// - `Literal`: must pass `validate_literal`.
    /// - `Constant` and `ZeroValue`: accepted without further checks here.
    /// - `Compose`: the component types recorded in `mod_info` must pass
    ///   `validate_compose` for the composed type.
    /// - `Splat`: the splatted value must resolve to a scalar type.
    ///
    /// # Errors
    ///
    /// Returns `NonConst` for any other expression kind, `InvalidSplatType`
    /// for a non-scalar splat operand, and otherwise forwards the error of the
    /// failing literal or compose check.
    pub(super) fn validate_const_expression(
        &self,
        handle: Handle,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
    ) -> Result<(), ConstExpressionError> {
        use crate::Expression as E;

        match gctx.const_expressions[handle] {
            E::Literal(literal) => {
                self.validate_literal(literal)?;
            }
            E::Constant(_) | E::ZeroValue(_) => {}
            E::Compose { ref components, ty } => {
                validate_compose(
                    ty,
                    gctx,
                    components.iter().map(|&handle| mod_info[handle].clone()),
                )?;
            }
            E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
                crate::TypeInner::Scalar { .. } => {}
                _ => return Err(super::ConstExpressionError::InvalidSplatType(value)),
            },
            _ => return Err(super::ConstExpressionError::NonConst),
        }

        Ok(())
    }

    pub(super) fn validate_expression(
        &self,
        root: Handle,
        expression: &crate::Expression,
        function: &crate::Function,
        module: &crate::Module,
        info: &FunctionInfo,
        mod_info: &ModuleInfo,
    ) -> Result {
        use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};

        let resolver = ExpressionTypeResolver {
            root,
            types: &module.types,
            info,
        };

        let stages = match *expression {
            E::Access { base, index } => {
                let base_type = &resolver[base];
                // See the documentation for `Expression::Access`.
                let dynamic_indexing_restricted = match *base_type {
                    Ti::Vector { .. } => false,
                    Ti::Matrix { .. } | Ti::Array { .. } => true,
                    Ti::Pointer { .. }
                    | Ti::ValuePointer { size: Some(_), .. }
                    | Ti::BindingArray { ..
} => false, ref other => { log::error!("Indexing of {:?}", other); return Err(ExpressionError::InvalidBaseType(base)); } }; match resolver[index] { //TODO: only allow one of these Ti::Scalar(Sc { kind: Sk::Sint | Sk::Uint, .. }) => {} ref other => { log::error!("Indexing by {:?}", other); return Err(ExpressionError::InvalidIndexType(index)); } } if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index(module) { return Err(ExpressionError::IndexMustBeConstant(base)); } // If we know both the length and the index, we can do the // bounds check now. if let crate::proc::IndexableLength::Known(known_length) = base_type.indexable_length(module)? { match module .to_ctx() .eval_expr_to_u32_from(index, &function.expressions) { Ok(value) => { if value >= known_length { return Err(ExpressionError::IndexOutOfBounds(base, value)); } } Err(crate::proc::U32EvalError::Negative) => { return Err(ExpressionError::NegativeIndex(base)) } Err(crate::proc::U32EvalError::NonConst) => {} } } ShaderStages::all() } E::AccessIndex { base, index } => { fn resolve_index_limit( module: &crate::Module, top: Handle, ty: &crate::TypeInner, top_level: bool, ) -> Result { let limit = match *ty { Ti::Vector { size, .. } | Ti::ValuePointer { size: Some(size), .. } => size as u32, Ti::Matrix { columns, .. } => columns as u32, Ti::Array { size: crate::ArraySize::Constant(len), .. } => len.get(), Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks Ti::Pointer { base, .. } if top_level => { resolve_index_limit(module, top, &module.types[base].inner, false)? } Ti::Struct { ref members, .. 
} => members.len() as u32, ref other => { log::error!("Indexing of {:?}", other); return Err(ExpressionError::InvalidBaseType(top)); } }; Ok(limit) } let limit = resolve_index_limit(module, base, &resolver[base], true)?; if index >= limit { return Err(ExpressionError::IndexOutOfBounds(base, limit)); } ShaderStages::all() } E::Splat { size: _, value } => match resolver[value] { Ti::Scalar { .. } => ShaderStages::all(), ref other => { log::error!("Splat scalar type {:?}", other); return Err(ExpressionError::InvalidSplatType(value)); } }, E::Swizzle { size, vector, pattern, } => { let vec_size = match resolver[vector] { Ti::Vector { size: vec_size, .. } => vec_size, ref other => { log::error!("Swizzle vector type {:?}", other); return Err(ExpressionError::InvalidVectorType(vector)); } }; for &sc in pattern[..size as usize].iter() { if sc as u8 >= vec_size as u8 { return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size)); } } ShaderStages::all() } E::Literal(literal) => { self.validate_literal(literal)?; ShaderStages::all() } E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, module.to_ctx(), components.iter().map(|&handle| info[handle].ty.clone()), )?; ShaderStages::all() } E::FunctionArgument(index) => { if index >= function.arguments.len() as u32 { return Err(ExpressionError::FunctionArgumentDoesntExist(index)); } ShaderStages::all() } E::GlobalVariable(_handle) => ShaderStages::all(), E::LocalVariable(_handle) => ShaderStages::all(), E::Load { pointer } => { match resolver[pointer] { Ti::Pointer { base, .. } if self.types[base.index()] .flags .contains(TypeFlags::SIZED | TypeFlags::DATA) => {} Ti::ValuePointer { .. 
} => {} ref other => { log::error!("Loading {:?}", other); return Err(ExpressionError::InvalidPointerType(pointer)); } } ShaderStages::all() } E::ImageSample { image, sampler, gather, coordinate, array_index, offset, level, depth_ref, } => { // check the validity of expressions let image_ty = Self::global_var_ty(module, function, image)?; let sampler_ty = Self::global_var_ty(module, function, sampler)?; let comparison = match module.types[sampler_ty].inner { Ti::Sampler { comparison } => comparison, _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)), }; let (class, dim) = match module.types[image_ty].inner { Ti::Image { class, arrayed, dim, } => { // check the array property if arrayed != array_index.is_some() { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { match resolver[expr] { Ti::Scalar(Sc { kind: Sk::Sint | Sk::Uint, .. }) => {} _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), } } (class, dim) } _ => return Err(ExpressionError::ExpectedImageType(image_ty)), }; // check sampling and comparison properties let image_depth = match class { crate::ImageClass::Sampled { kind: crate::ScalarKind::Float, multi: false, } => false, crate::ImageClass::Sampled { kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint, multi: false, } if gather.is_some() => false, crate::ImageClass::Depth { multi: false } => true, _ => return Err(ExpressionError::InvalidImageClass(class)), }; if comparison != depth_ref.is_some() || (comparison && !image_depth) { return Err(ExpressionError::ComparisonSamplingMismatch { image: class, sampler: comparison, has_ref: depth_ref.is_some(), }); } // check texture coordinates type let num_components = match dim { crate::ImageDimension::D1 => 1, crate::ImageDimension::D2 => 2, crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, }; match resolver[coordinate] { Ti::Scalar(Sc { kind: Sk::Float, .. 
}) if num_components == 1 => {} Ti::Vector { size, scalar: Sc { kind: Sk::Float, .. }, } if size as u32 == num_components => {} _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)), } // check constant offset if let Some(const_expr) = offset { match *mod_info[const_expr].inner_with(&module.types) { Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {} Ti::Vector { size, scalar: Sc { kind: Sk::Sint, .. }, } if size as u32 == num_components => {} _ => { return Err(ExpressionError::InvalidSampleOffset(dim, const_expr)); } } } // check depth reference type if let Some(expr) = depth_ref { match resolver[expr] { Ti::Scalar(Sc { kind: Sk::Float, .. }) => {} _ => return Err(ExpressionError::InvalidDepthReference(expr)), } match level { crate::SampleLevel::Auto | crate::SampleLevel::Zero => {} _ => return Err(ExpressionError::InvalidDepthSampleLevel), } } if let Some(component) = gather { match dim { crate::ImageDimension::D2 | crate::ImageDimension::Cube => {} crate::ImageDimension::D1 | crate::ImageDimension::D3 => { return Err(ExpressionError::InvalidGatherDimension(dim)) } }; let max_component = match class { crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X, _ => crate::SwizzleComponent::W, }; if component > max_component { return Err(ExpressionError::InvalidGatherComponent(component)); } match level { crate::SampleLevel::Zero => {} _ => return Err(ExpressionError::InvalidGatherLevel), } } // check level properties match level { crate::SampleLevel::Auto => ShaderStages::FRAGMENT, crate::SampleLevel::Zero => ShaderStages::all(), crate::SampleLevel::Exact(expr) => { match resolver[expr] { Ti::Scalar(Sc { kind: Sk::Float, .. }) => {} _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)), } ShaderStages::all() } crate::SampleLevel::Bias(expr) => { match resolver[expr] { Ti::Scalar(Sc { kind: Sk::Float, .. 
}) => {} _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)), } ShaderStages::FRAGMENT } crate::SampleLevel::Gradient { x, y } => { match resolver[x] { Ti::Scalar(Sc { kind: Sk::Float, .. }) if num_components == 1 => {} Ti::Vector { size, scalar: Sc { kind: Sk::Float, .. }, } if size as u32 == num_components => {} _ => { return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) } } match resolver[y] { Ti::Scalar(Sc { kind: Sk::Float, .. }) if num_components == 1 => {} Ti::Vector { size, scalar: Sc { kind: Sk::Float, .. }, } if size as u32 == num_components => {} _ => { return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y)) } } ShaderStages::all() } } } E::ImageLoad { image, coordinate, array_index, sample, level, } => { let ty = Self::global_var_ty(module, function, image)?; match module.types[ty].inner { Ti::Image { class, arrayed, dim, } => { match resolver[coordinate].image_storage_coordinates() { Some(coord_dim) if coord_dim == dim => {} _ => { return Err(ExpressionError::InvalidImageCoordinateType( dim, coordinate, )) } }; if arrayed != array_index.is_some() { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { match resolver[expr] { Ti::Scalar(Sc { kind: Sk::Sint | Sk::Uint, width: _, }) => {} _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), } } match (sample, class.is_multisampled()) { (None, false) => {} (Some(sample), true) => { if resolver[sample].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType( sample, )); } } _ => { return Err(ExpressionError::InvalidImageOtherIndex); } } match (level, class.is_mipmapped()) { (None, false) => {} (Some(level), true) => { if resolver[level].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType(level)); } } _ => { return Err(ExpressionError::InvalidImageOtherIndex); } } } _ => return Err(ExpressionError::ExpectedImageType(ty)), } ShaderStages::all() } 
E::ImageQuery { image, query } => { let ty = Self::global_var_ty(module, function, image)?; match module.types[ty].inner { Ti::Image { class, arrayed, .. } => { let good = match query { crate::ImageQuery::NumLayers => arrayed, crate::ImageQuery::Size { level: None } => true, crate::ImageQuery::Size { level: Some(_) } | crate::ImageQuery::NumLevels => class.is_mipmapped(), crate::ImageQuery::NumSamples => class.is_multisampled(), }; if !good { return Err(ExpressionError::InvalidImageClass(class)); } } _ => return Err(ExpressionError::ExpectedImageType(ty)), } ShaderStages::all() } E::Unary { op, expr } => { use crate::UnaryOperator as Uo; let inner = &resolver[expr]; match (op, inner.scalar_kind()) { (Uo::Negate, Some(Sk::Float | Sk::Sint)) | (Uo::LogicalNot, Some(Sk::Bool)) | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {} other => { log::error!("Op {:?} kind {:?}", op, other); return Err(ExpressionError::InvalidUnaryOperandType(op, expr)); } } ShaderStages::all() } E::Binary { op, left, right } => { use crate::BinaryOperator as Bo; let left_inner = &resolver[left]; let right_inner = &resolver[right]; let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, Ti::Matrix { .. } => left_inner == right_inner, _ => false, }, Bo::Divide | Bo::Modulo => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, _ => false, }, Bo::Multiply => { let kind_allowed = match left_inner.scalar_kind() { Some(Sk::Uint | Sk::Sint | Sk::Float) => true, Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false, }; let types_match = match (left_inner, right_inner) { // Straight scalar and mixed scalar/vector. 
(&Ti::Scalar(scalar1), &Ti::Scalar(scalar2)) | ( &Ti::Vector { scalar: scalar1, .. }, &Ti::Scalar(scalar2), ) | ( &Ti::Scalar(scalar1), &Ti::Vector { scalar: scalar2, .. }, ) => scalar1 == scalar2, // Scalar/matrix. ( &Ti::Scalar(Sc { kind: Sk::Float, .. }), &Ti::Matrix { .. }, ) | ( &Ti::Matrix { .. }, &Ti::Scalar(Sc { kind: Sk::Float, .. }), ) => true, // Vector/vector. ( &Ti::Vector { size: size1, scalar: scalar1, }, &Ti::Vector { size: size2, scalar: scalar2, }, ) => scalar1 == scalar2 && size1 == size2, // Matrix * vector. ( &Ti::Matrix { columns, .. }, &Ti::Vector { size, scalar: Sc { kind: Sk::Float, .. }, }, ) => columns == size, // Vector * matrix. ( &Ti::Vector { size, scalar: Sc { kind: Sk::Float, .. }, }, &Ti::Matrix { rows, .. }, ) => size == rows, (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => { columns == rows } _ => false, }; let left_width = left_inner.scalar_width().unwrap_or(0); let right_width = right_inner.scalar_width().unwrap_or(0); kind_allowed && types_match && left_width == right_width } Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner, Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); false } } } Bo::LogicalAnd | Bo::LogicalOr => match *left_inner { Ti::Scalar(Sc { kind: Sk::Bool, .. }) | Ti::Vector { scalar: Sc { kind: Sk::Bool, .. }, .. } => left_inner == right_inner, ref other => { log::error!("Op {:?} left type {:?}", op, other); false } }, Bo::And | Bo::InclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. 
} => match scalar.kind { Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); false } }, Bo::ExclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Sint | Sk::Uint => left_inner == right_inner, Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); false } }, Bo::ShiftLeft | Bo::ShiftRight => { let (base_size, base_scalar) = match *left_inner { Ti::Scalar(scalar) => (Ok(None), scalar), Ti::Vector { size, scalar } => (Ok(Some(size)), scalar), ref other => { log::error!("Op {:?} base type {:?}", op, other); (Err(()), Sc::BOOL) } }; let shift_size = match *right_inner { Ti::Scalar(Sc { kind: Sk::Uint, .. }) => Ok(None), Ti::Vector { size, scalar: Sc { kind: Sk::Uint, .. }, } => Ok(Some(size)), ref other => { log::error!("Op {:?} shift type {:?}", op, other); Err(()) } }; match base_scalar.kind { Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false, } } }; if !good { log::error!( "Left: {:?} of type {:?}", function.expressions[left], left_inner ); log::error!( "Right: {:?} of type {:?}", function.expressions[right], right_inner ); return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right)); } ShaderStages::all() } E::Select { condition, accept, reject, } => { let accept_inner = &resolver[accept]; let reject_inner = &resolver[reject]; let condition_good = match resolver[condition] { Ti::Scalar(Sc { kind: Sk::Bool, width: _, }) => { // When `condition` is a single boolean, `accept` and // `reject` can be vectors or scalars. match *accept_inner { Ti::Scalar { .. } | Ti::Vector { .. 
} => true, _ => false, } } Ti::Vector { size, scalar: Sc { kind: Sk::Bool, width: _, }, } => match *accept_inner { Ti::Vector { size: other_size, .. } => size == other_size, _ => false, }, _ => false, }; if !condition_good || accept_inner != reject_inner { return Err(ExpressionError::InvalidSelectTypes); } ShaderStages::all() } E::Derivative { expr, .. } => { match resolver[expr] { Ti::Scalar(Sc { kind: Sk::Float, .. }) | Ti::Vector { scalar: Sc { kind: Sk::Float, .. }, .. } => {} _ => return Err(ExpressionError::InvalidDerivative), } ShaderStages::FRAGMENT } E::Relational { fun, argument } => { use crate::RelationalFunction as Rf; let argument_inner = &resolver[argument]; match fun { Rf::All | Rf::Any => match *argument_inner { Ti::Vector { scalar: Sc { kind: Sk::Bool, .. }, .. } => {} ref other => { log::error!("All/Any of type {:?}", other); return Err(ExpressionError::InvalidBooleanVector(argument)); } }, Rf::IsNan | Rf::IsInf => match *argument_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } if scalar.kind == Sk::Float => {} ref other => { log::error!("Float test of type {:?}", other); return Err(ExpressionError::InvalidFloatArgument(argument)); } }, } ShaderStages::all() } E::Math { fun, arg, arg1, arg2, arg3, } => { use crate::MathFunction as Mf; let resolve = |arg| &resolver[arg]; let arg_ty = resolve(arg); let arg1_ty = arg1.map(resolve); let arg2_ty = arg2.map(resolve); let arg3_ty = arg3.map(resolve); match fun { Mf::Abs => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } let good = match *arg_ty { Ti::Scalar(scalar) | Ti::Vector { scalar, .. 
} => { scalar.kind != Sk::Bool } _ => false, }; if !good { return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); } } Mf::Min | Mf::Max => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let good = match *arg_ty { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { scalar.kind != Sk::Bool } _ => false, }; if !good { return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } } Mf::Clamp => { let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let good = match *arg_ty { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { scalar.kind != Sk::Bool } _ => false, }; if !good { return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } if arg2_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg2.unwrap(), )); } } Mf::Saturate | Mf::Cos | Mf::Cosh | Mf::Sin | Mf::Sinh | Mf::Tan | Mf::Tanh | Mf::Acos | Mf::Asin | Mf::Atan | Mf::Asinh | Mf::Acosh | Mf::Atanh | Mf::Radians | Mf::Degrees | Mf::Ceil | Mf::Floor | Mf::Round | Mf::Fract | Mf::Trunc | Mf::Exp | Mf::Exp2 | Mf::Log | Mf::Log2 | Mf::Length | Mf::Sqrt | Mf::InverseSqrt => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } if scalar.kind == Sk::Float => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } Mf::Sign => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Scalar(Sc { kind: Sk::Float | Sk::Sint, .. 
}) | Ti::Vector { scalar: Sc { kind: Sk::Float | Sk::Sint, .. }, .. } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } if scalar.kind == Sk::Float => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } } Mf::Modf | Mf::Frexp => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } if !matches!(*arg_ty, Ti::Scalar(scalar) | Ti::Vector { scalar, .. } if scalar.kind == Sk::Float) { return Err(ExpressionError::InvalidArgumentType(fun, 1, arg)); } } Mf::Ldexp => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let size0 = match *arg_ty { Ti::Scalar(Sc { kind: Sk::Float, .. }) => None, Ti::Vector { scalar: Sc { kind: Sk::Float, .. }, size, } => Some(size), _ => { return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); } }; let good = match *arg1_ty { Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true, Ti::Vector { size, scalar: Sc { kind: Sk::Sint, .. }, } if Some(size) == size0 => true, _ => false, }; if !good { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } } Mf::Dot => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { Ti::Vector { scalar: Sc { kind: Sk::Float | Sk::Sint | Sk::Uint, .. }, .. 
} => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } } Mf::Outer | Mf::Cross | Mf::Reflect => { let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), None, None) => ty1, _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { Ti::Vector { scalar: Sc { kind: Sk::Float, .. }, .. } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } } Mf::Refract => { let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { Ti::Vector { scalar: Sc { kind: Sk::Float, .. }, .. } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } match (arg_ty, arg2_ty) { ( &Ti::Vector { scalar: Sc { width: vector_width, .. }, .. }, &Ti::Scalar(Sc { width: scalar_width, kind: Sk::Float, }), ) if vector_width == scalar_width => {} _ => { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg2.unwrap(), )) } } } Mf::Normalize => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Vector { scalar: Sc { kind: Sk::Float, .. }, .. } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } Mf::FaceForward | Mf::Fma | Mf::SmoothStep => { let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { Ti::Scalar(Sc { kind: Sk::Float, .. }) | Ti::Vector { scalar: Sc { kind: Sk::Float, .. }, .. 
} => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } if arg2_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg2.unwrap(), )); } } Mf::Mix => { let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; let arg_width = match *arg_ty { Ti::Scalar(Sc { kind: Sk::Float, width, }) | Ti::Vector { scalar: Sc { kind: Sk::Float, width, }, .. } => width, _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), }; if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } // the last argument can always be a scalar match *arg2_ty { Ti::Scalar(Sc { kind: Sk::Float, width, }) if width == arg_width => {} _ if arg2_ty == arg_ty => {} _ => { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg2.unwrap(), )); } } } Mf::Inverse | Mf::Determinant => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } let good = match *arg_ty { Ti::Matrix { columns, rows, .. } => columns == rows, _ => false, }; if !good { return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); } } Mf::Transpose => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Matrix { .. } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } Mf::CountTrailingZeros | Mf::CountLeadingZeros | Mf::CountOneBits | Mf::ReverseBits | Mf::FindLsb | Mf::FindMsb => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Scalar(Sc { kind: Sk::Sint | Sk::Uint, .. }) | Ti::Vector { scalar: Sc { kind: Sk::Sint | Sk::Uint, .. }, .. 
} => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } Mf::InsertBits => { let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { Ti::Scalar(Sc { kind: Sk::Sint | Sk::Uint, .. }) | Ti::Vector { scalar: Sc { kind: Sk::Sint | Sk::Uint, .. }, .. } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } if arg1_ty != arg_ty { return Err(ExpressionError::InvalidArgumentType( fun, 1, arg1.unwrap(), )); } match *arg2_ty { Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} _ => { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg2.unwrap(), )) } } match *arg3_ty { Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} _ => { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg3.unwrap(), )) } } } Mf::ExtractBits => { let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { (Some(ty1), Some(ty2), None) => (ty1, ty2), _ => return Err(ExpressionError::WrongArgumentCount(fun)), }; match *arg_ty { Ti::Scalar(Sc { kind: Sk::Sint | Sk::Uint, .. }) | Ti::Vector { scalar: Sc { kind: Sk::Sint | Sk::Uint, .. }, .. } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } match *arg1_ty { Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} _ => { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg1.unwrap(), )) } } match *arg2_ty { Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} _ => { return Err(ExpressionError::InvalidArgumentType( fun, 2, arg2.unwrap(), )) } } } Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Vector { size: crate::VectorSize::Bi, scalar: Sc { kind: Sk::Float, .. 
}, } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } Mf::Pack4x8snorm | Mf::Pack4x8unorm => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Vector { size: crate::VectorSize::Quad, scalar: Sc { kind: Sk::Float, .. }, } => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } Mf::Unpack2x16float | Mf::Unpack2x16snorm | Mf::Unpack2x16unorm | Mf::Unpack4x8snorm | Mf::Unpack4x8unorm => { if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { return Err(ExpressionError::WrongArgumentCount(fun)); } match *arg_ty { Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), } } } ShaderStages::all() } E::As { expr, kind, convert, } => { let mut base_scalar = match resolver[expr] { crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => { scalar } crate::TypeInner::Matrix { scalar, .. } => scalar, _ => return Err(ExpressionError::InvalidCastArgument), }; base_scalar.kind = kind; if let Some(width) = convert { base_scalar.width = width; } if self.check_width(base_scalar).is_err() { return Err(ExpressionError::InvalidCastArgument); } ShaderStages::all() } E::CallResult(function) => mod_info.functions[function.index()].available_stages, E::AtomicResult { ty, comparison } => { let scalar_predicate = |ty: &crate::TypeInner| match ty { &crate::TypeInner::Scalar( scalar @ Sc { kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint, .. }, ) => self.check_width(scalar).is_ok(), _ => false, }; let good = match &module.types[ty].inner { ty if !comparison => scalar_predicate(ty), &crate::TypeInner::Struct { ref members, .. 
} if comparison => { validate_atomic_compare_exchange_struct( &module.types, members, scalar_predicate, ) } _ => false, }; if !good { return Err(ExpressionError::InvalidAtomicResultType(ty)); } ShaderStages::all() } E::WorkGroupUniformLoadResult { ty } => { if self.types[ty.index()] .flags // Sized | Constructible is exactly the types currently supported by // WorkGroupUniformLoad .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE) { ShaderStages::COMPUTE } else { return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty)); } } E::ArrayLength(expr) => match resolver[expr] { Ti::Pointer { base, .. } => { let base_ty = &resolver.types[base]; if let Ti::Array { size: crate::ArraySize::Dynamic, .. } = base_ty.inner { ShaderStages::all() } else { return Err(ExpressionError::InvalidArrayType(expr)); } } ref other => { log::error!("Array length of {:?}", other); return Err(ExpressionError::InvalidArrayType(expr)); } }, E::RayQueryProceedResult => ShaderStages::all(), E::RayQueryGetIntersection { query, committed: _, } => match resolver[query] { Ti::Pointer { base, space: crate::AddressSpace::Function, } => match resolver.types[base].inner { Ti::RayQuery => ShaderStages::all(), ref other => { log::error!("Intersection result of a pointer to {:?}", other); return Err(ExpressionError::InvalidRayQueryType(query)); } }, ref other => { log::error!("Intersection result of {:?}", other); return Err(ExpressionError::InvalidRayQueryType(query)); } }, }; Ok(stages) } fn global_var_ty( module: &crate::Module, function: &crate::Function, expr: Handle, ) -> Result, ExpressionError> { use crate::Expression as Ex; match function.expressions[expr] { Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty), Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty), Ex::Access { base, .. } | Ex::AccessIndex { base, .. 
} => {
                // An indexed access is only accepted when the indexed value is a
                // global variable whose type is a binding array; the result is then
                // the binding array's `base` type.
                match function.expressions[base] {
                    Ex::GlobalVariable(var_handle) => {
                        let array_ty = module.global_variables[var_handle].ty;

                        match module.types[array_ty].inner {
                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
                        }
                    }
                    _ => Err(ExpressionError::ExpectedGlobalVariable),
                }
            }
            _ => Err(ExpressionError::ExpectedGlobalVariable),
        }
    }

    /// Validate a literal: its scalar must pass `check_width`, and its value
    /// must pass [`check_literal_value`].
    ///
    /// # Errors
    ///
    /// Returns the width error converted into a [`LiteralError`], or the error
    /// produced by [`check_literal_value`].
    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
        self.check_width(literal.scalar())?;
        check_literal_value(literal)
    }
}

/// Reject floating-point literals that are NaN or infinite.
///
/// Only `F64` and `F32` literals are inspected; every other literal is accepted.
///
/// # Errors
///
/// Returns [`LiteralError::NaN`] for a NaN value and [`LiteralError::Infinity`]
/// for an infinite one.
pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
    // Classify the value once; non-float literals are neither NaN nor infinite.
    let (is_nan, is_infinite) = match literal {
        crate::Literal::F64(v) => (v.is_nan(), v.is_infinite()),
        crate::Literal::F32(v) => (v.is_nan(), v.is_infinite()),
        _ => (false, false),
    };

    if is_nan {
        return Err(LiteralError::NaN);
    }
    if is_infinite {
        return Err(LiteralError::Infinity);
    }

    Ok(())
}

/// Validate a module whose single function contains the given expression.
///
/// The expression is appended to a fresh function's expression arena and
/// covered by one `Emit` statement. The module is validated with
/// `ValidationFlags::EXPRESSIONS` and the given capabilities, and the
/// validator's result is returned so that callers can assert on either outcome.
#[cfg(all(test, feature = "validate"))]
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());
    function.body.push(
        crate::Statement::Emit(function.expressions.range_from(0)),
        Span::default(),
    );

    let mut module = crate::Module::default();
    module.functions.append(function, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);

    validator.validate(&module)
}

/// Validate a module whose constant expression arena contains the given
/// expression, returning the validator's result.
#[cfg(all(test, feature = "validate"))]
fn validate_with_const_expression( expr: crate::Expression, caps: super::Capabilities, ) -> Result> { use crate::span::Span; let mut module = crate::Module::default(); module.const_expressions.append(expr, Span::default()); let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps); validator.validate(&module) } /// Using F64 in a function's expression arena is forbidden. #[cfg(feature = "validate")] #[test] fn f64_runtime_literals() { let result = validate_with_expression( crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), super::Capabilities::default(), ); let error = result.unwrap_err().into_inner(); assert!(matches!( error, crate::valid::ValidationError::Function { source: super::FunctionError::Expression { source: super::ExpressionError::Literal(super::LiteralError::Width( super::r#type::WidthError::MissingCapability { name: "f64", flag: "FLOAT64", } ),), .. }, .. } )); let result = validate_with_expression( crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), super::Capabilities::default() | super::Capabilities::FLOAT64, ); assert!(result.is_ok()); } /// Using F64 in a module's constant expression arena is forbidden. #[cfg(feature = "validate")] #[test] fn f64_const_literals() { let result = validate_with_const_expression( crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), super::Capabilities::default(), ); let error = result.unwrap_err().into_inner(); assert!(matches!( error, crate::valid::ValidationError::ConstExpression { source: super::ConstExpressionError::Literal(super::LiteralError::Width( super::r#type::WidthError::MissingCapability { name: "f64", flag: "FLOAT64", } )), .. } )); let result = validate_with_const_expression( crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), super::Capabilities::default() | super::Capabilities::FLOAT64, ); assert!(result.is_ok()); } /// Using I64 in a function's expression arena is forbidden. 
#[cfg(feature = "validate")]
#[test]
fn i64_runtime_literals() {
    let result = validate_with_expression(
        crate::Expression::Literal(crate::Literal::I64(1729)),
        // There is no capability that enables this.
        super::Capabilities::all(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(
        matches!(
            error,
            crate::valid::ValidationError::Function {
                source: super::FunctionError::Expression {
                    source: super::ExpressionError::Literal(super::LiteralError::Width(
                        super::r#type::WidthError::Unsupported64Bit
                    )),
                    ..
                },
                ..
            }
        ),
        "unexpected validation error: {:?}",
        error
    );
}

/// Using I64 in a module's constant expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn i64_const_literals() {
    let result = validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::I64(1729)),
        // There is no capability that enables this.
        super::Capabilities::all(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(
        matches!(
            error,
            crate::valid::ValidationError::ConstExpression {
                source: super::ConstExpressionError::Literal(super::LiteralError::Width(
                    super::r#type::WidthError::Unsupported64Bit
                )),
                ..
            }
        ),
        "unexpected validation error: {:?}",
        error
    );
}