diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-19 00:47:55 +0000 |
commit | 26a029d407be480d791972afb5975cf62c9360a6 (patch) | |
tree | f435a8308119effd964b339f76abb83a57c29483 /third_party/rust/naga/src/valid | |
parent | Initial commit. (diff) | |
download | firefox-26a029d407be480d791972afb5975cf62c9360a6.tar.xz firefox-26a029d407be480d791972afb5975cf62c9360a6.zip |
Adding upstream version 124.0.1.upstream/124.0.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/valid')
-rw-r--r-- | third_party/rust/naga/src/valid/analyzer.rs | 1281 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/compose.rs | 128 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/expression.rs | 1797 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/function.rs | 1056 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/handles.rs | 699 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/interface.rs | 709 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/mod.rs | 477 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/type.rs | 643 |
8 files changed, 6790 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs new file mode 100644 index 0000000000..df6fc5e9b0 --- /dev/null +++ b/third_party/rust/naga/src/valid/analyzer.rs @@ -0,0 +1,1281 @@ +/*! Module analyzer. + +Figures out the following properties: + - control flow uniformity + - texture/sampler pairs + - expression reference counts +!*/ + +use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags}; +use crate::span::{AddSpan as _, WithSpan}; +use crate::{ + arena::{Arena, Handle}, + proc::{ResolveContext, TypeResolution}, +}; +use std::ops; + +pub type NonUniformResult = Option<Handle<crate::Expression>>; + +// Remove this once we update our uniformity analysis and +// add support for the `derivative_uniformity` diagnostic +const DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE: bool = true; + +bitflags::bitflags! { + /// Kinds of expressions that require uniform control flow. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct UniformityRequirements: u8 { + const WORK_GROUP_BARRIER = 0x1; + const DERIVATIVE = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x2 }; + const IMPLICIT_LEVEL = if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { 0 } else { 0x4 }; + } +} + +/// Uniform control flow characteristics. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(test, derive(PartialEq))] +pub struct Uniformity { + /// A child expression with non-uniform result. + /// + /// This means, when the relevant invocations are scheduled on a compute unit, + /// they have to use vector registers to store an individual value + /// per invocation. 
+ /// + /// Whenever the control flow is conditioned on such value, + /// the hardware needs to keep track of the mask of invocations, + /// and process all branches of the control flow. + /// + /// Any operations that depend on non-uniform results also produce non-uniform. + pub non_uniform_result: NonUniformResult, + /// If this expression requires uniform control flow, store the reason here. + pub requirements: UniformityRequirements, +} + +impl Uniformity { + const fn new() -> Self { + Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::empty(), + } + } +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, PartialEq)] + struct ExitFlags: u8 { + /// Control flow may return from the function, which makes all the + /// subsequent statements within the current function (only!) + /// to be executed in a non-uniform control flow. + const MAY_RETURN = 0x1; + /// Control flow may be killed. Anything after `Statement::Kill` is + /// considered inside non-uniform context. + const MAY_KILL = 0x2; + } +} + +/// Uniformity characteristics of a function. +#[cfg_attr(test, derive(Debug, PartialEq))] +struct FunctionUniformity { + result: Uniformity, + exit: ExitFlags, +} + +impl ops::BitOr for FunctionUniformity { + type Output = Self; + fn bitor(self, other: Self) -> Self { + FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .result + .non_uniform_result + .or(other.result.non_uniform_result), + requirements: self.result.requirements | other.result.requirements, + }, + exit: self.exit | other.exit, + } + } +} + +impl FunctionUniformity { + const fn new() -> Self { + FunctionUniformity { + result: Uniformity::new(), + exit: ExitFlags::empty(), + } + } + + /// Returns a disruptor based on the stored exit flags, if any. 
+ const fn exit_disruptor(&self) -> Option<UniformityDisruptor> { + if self.exit.contains(ExitFlags::MAY_RETURN) { + Some(UniformityDisruptor::Return) + } else if self.exit.contains(ExitFlags::MAY_KILL) { + Some(UniformityDisruptor::Discard) + } else { + None + } + } +} + +bitflags::bitflags! { + /// Indicates how a global variable is used. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Eq, PartialEq)] + pub struct GlobalUse: u8 { + /// Data will be read from the variable. + const READ = 0x1; + /// Data will be written to the variable. + const WRITE = 0x2; + /// The information about the data is queried. + const QUERY = 0x4; + } +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct SamplingKey { + pub image: Handle<crate::GlobalVariable>, + pub sampler: Handle<crate::GlobalVariable>, +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct ExpressionInfo { + pub uniformity: Uniformity, + pub ref_count: usize, + assignable_global: Option<Handle<crate::GlobalVariable>>, + pub ty: TypeResolution, +} + +impl ExpressionInfo { + const fn new() -> Self { + ExpressionInfo { + uniformity: Uniformity::new(), + ref_count: 0, + assignable_global: None, + // this doesn't matter at this point, will be overwritten + ty: TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + width: 0, + })), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +enum GlobalOrArgument { + Global(Handle<crate::GlobalVariable>), + Argument(u32), +} + +impl 
GlobalOrArgument { + fn from_expression( + expression_arena: &Arena<crate::Expression>, + expression: Handle<crate::Expression>, + ) -> Result<GlobalOrArgument, ExpressionError> { + Ok(match expression_arena[expression] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i), + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + _ => return Err(ExpressionError::ExpectedGlobalOrArgument), + }, + _ => return Err(ExpressionError::ExpectedGlobalOrArgument), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +struct Sampling { + image: GlobalOrArgument, + sampler: GlobalOrArgument, +} + +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct FunctionInfo { + /// Validation flags. + #[allow(dead_code)] + flags: ValidationFlags, + /// Set of shader stages where calling this function is valid. + pub available_stages: ShaderStages, + /// Uniformity characteristics. + pub uniformity: Uniformity, + /// Function may kill the invocation. + pub may_kill: bool, + + /// All pairs of (texture, sampler) globals that may be used together in + /// sampling operations by this function and its callees. This includes + /// pairings that arise when this function passes textures and samplers as + /// arguments to its callees. + /// + /// This table does not include uses of textures and samplers passed as + /// arguments to this function itself, since we do not know which globals + /// those will be. 
However, this table *is* exhaustive when computed for an + /// entry point function: entry points never receive textures or samplers as + /// arguments, so all an entry point's sampling can be reported in terms of + /// globals. + /// + /// The GLSL back end uses this table to construct reflection info that + /// clients need to construct texture-combined sampler values. + pub sampling_set: crate::FastHashSet<SamplingKey>, + + /// How this function and its callees use this module's globals. + /// + /// This is indexed by `Handle<GlobalVariable>` indices. However, + /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`, + /// so you can simply index this struct with a global handle to retrieve + /// its usage information. + global_uses: Box<[GlobalUse]>, + + /// Information about each expression in this function's body. + /// + /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo` + /// implements `std::ops::Index<Handle<Expression>>`, so you can simply + /// index this struct with an expression handle to retrieve its + /// `ExpressionInfo`. + expressions: Box<[ExpressionInfo]>, + + /// All (texture, sampler) pairs that may be used together in sampling + /// operations by this function and its callees, whether they are accessed + /// as globals or passed as arguments. + /// + /// Participants are represented by [`GlobalVariable`] handles whenever + /// possible, and otherwise by indices of this function's arguments. + /// + /// When analyzing a function call, we combine this data about the callee + /// with the actual arguments being passed to produce the callers' own + /// `sampling_set` and `sampling` tables. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + sampling: crate::FastHashSet<Sampling>, + + /// Indicates that the function is using dual source blending. 
+ pub dual_source_blending: bool, +} + +impl FunctionInfo { + pub const fn global_variable_count(&self) -> usize { + self.global_uses.len() + } + pub const fn expression_count(&self) -> usize { + self.expressions.len() + } + pub fn dominates_global_use(&self, other: &Self) -> bool { + for (self_global_uses, other_global_uses) in + self.global_uses.iter().zip(other.global_uses.iter()) + { + if !self_global_uses.contains(*other_global_uses) { + return false; + } + } + true + } +} + +impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo { + type Output = GlobalUse; + fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse { + &self.global_uses[handle.index()] + } +} + +impl ops::Index<Handle<crate::Expression>> for FunctionInfo { + type Output = ExpressionInfo; + fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo { + &self.expressions[handle.index()] + } +} + +/// Disruptor of the uniform control flow. +#[derive(Clone, Copy, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum UniformityDisruptor { + #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")] + Expression(Handle<crate::Expression>), + #[error("There is a Return earlier in the control flow of the function")] + Return, + #[error("There is a Discard earlier in the entry point across all called functions")] + Discard, +} + +impl FunctionInfo { + /// Adds a value-type reference to an expression. + #[must_use] + fn add_ref_impl( + &mut self, + handle: Handle<crate::Expression>, + global_use: GlobalUse, + ) -> NonUniformResult { + let info = &mut self.expressions[handle.index()]; + info.ref_count += 1; + // mark the used global as read + if let Some(global) = info.assignable_global { + self.global_uses[global.index()] |= global_use; + } + info.uniformity.non_uniform_result + } + + /// Adds a value-type reference to an expression. 
+ #[must_use] + fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult { + self.add_ref_impl(handle, GlobalUse::READ) + } + + /// Adds a potentially assignable reference to an expression. + /// These are destinations for `Store` and `ImageStore` statements, + /// which can transit through `Access` and `AccessIndex`. + #[must_use] + fn add_assignable_ref( + &mut self, + handle: Handle<crate::Expression>, + assignable_global: &mut Option<Handle<crate::GlobalVariable>>, + ) -> NonUniformResult { + let info = &mut self.expressions[handle.index()]; + info.ref_count += 1; + // propagate the assignable global up the chain, till it either hits + // a value-type expression, or the assignment statement. + if let Some(global) = info.assignable_global { + if let Some(_old) = assignable_global.replace(global) { + unreachable!() + } + } + info.uniformity.non_uniform_result + } + + /// Inherit information from a called function. + fn process_call( + &mut self, + callee: &Self, + arguments: &[Handle<crate::Expression>], + expression_arena: &Arena<crate::Expression>, + ) -> Result<FunctionUniformity, WithSpan<FunctionError>> { + self.sampling_set + .extend(callee.sampling_set.iter().cloned()); + for sampling in callee.sampling.iter() { + // If the callee was passed the texture or sampler as an argument, + // we may now be able to determine which globals those referred to. + let image_storage = match sampling.image { + GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var), + GlobalOrArgument::Argument(i) => { + let handle = arguments[i as usize]; + GlobalOrArgument::from_expression(expression_arena, handle).map_err( + |source| { + FunctionError::Expression { handle, source } + .with_span_handle(handle, expression_arena) + }, + )? 
+ } + }; + + let sampler_storage = match sampling.sampler { + GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var), + GlobalOrArgument::Argument(i) => { + let handle = arguments[i as usize]; + GlobalOrArgument::from_expression(expression_arena, handle).map_err( + |source| { + FunctionError::Expression { handle, source } + .with_span_handle(handle, expression_arena) + }, + )? + } + }; + + // If we've managed to pin both the image and sampler down to + // specific globals, record that in our `sampling_set`. Otherwise, + // record as much as we do know in our own `sampling` table, for our + // callers to sort out. + match (image_storage, sampler_storage) { + (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => { + self.sampling_set.insert(SamplingKey { image, sampler }); + } + (image, sampler) => { + self.sampling.insert(Sampling { image, sampler }); + } + } + } + + // Inherit global use from our callees. + for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) { + *mine |= *other; + } + + Ok(FunctionUniformity { + result: callee.uniformity.clone(), + exit: if callee.may_kill { + ExitFlags::MAY_KILL + } else { + ExitFlags::empty() + }, + }) + } + + /// Compute the [`ExpressionInfo`] for `handle`. + /// + /// Replace the dummy entry in [`self.expressions`] for `handle` + /// with a real `ExpressionInfo` value describing that expression. + /// + /// This function is called as part of a forward sweep through the + /// arena, so we can assume that all earlier expressions in the + /// arena already have valid info. Since expressions only depend + /// on earlier expressions, this includes all our subexpressions. + /// + /// Adjust the reference counts on all expressions we use. + /// + /// Also populate the [`sampling_set`], [`sampling`] and + /// [`global_uses`] fields of `self`. 
+ /// + /// [`self.expressions`]: FunctionInfo::expressions + /// [`sampling_set`]: FunctionInfo::sampling_set + /// [`sampling`]: FunctionInfo::sampling + /// [`global_uses`]: FunctionInfo::global_uses + #[allow(clippy::or_fun_call)] + fn process_expression( + &mut self, + handle: Handle<crate::Expression>, + expression_arena: &Arena<crate::Expression>, + other_functions: &[FunctionInfo], + resolve_context: &ResolveContext, + capabilities: super::Capabilities, + ) -> Result<(), ExpressionError> { + use crate::{Expression as E, SampleLevel as Sl}; + + let expression = &expression_arena[handle]; + let mut assignable_global = None; + let uniformity = match *expression { + E::Access { base, index } => { + let base_ty = self[base].ty.inner_with(resolve_context.types); + + // build up the caps needed if this is indexed non-uniformly + let mut needed_caps = super::Capabilities::empty(); + let is_binding_array = match *base_ty { + crate::TypeInner::BindingArray { + base: array_element_ty_handle, + .. + } => { + // these are nasty aliases, but these idents are too long and break rustfmt + let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING; + let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING; + let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING; + + // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it. + let array_element_ty = + &resolve_context.types[array_element_ty_handle].inner; + + needed_caps |= match *array_element_ty { + // If we're an image, use the appropriate limit. + crate::TypeInner::Image { class, .. } => match class { + crate::ImageClass::Storage { .. } => ub_st, + _ => st_sb, + }, + crate::TypeInner::Sampler { .. } => sampler, + // If we're anything but an image, assume we're a buffer and use the address space. 
+ _ => { + if let E::GlobalVariable(global_handle) = expression_arena[base] { + let global = &resolve_context.global_vars[global_handle]; + match global.space { + crate::AddressSpace::Uniform => ub_st, + crate::AddressSpace::Storage { .. } => st_sb, + _ => unreachable!(), + } + } else { + unreachable!() + } + } + }; + + true + } + _ => false, + }; + + if self[index].uniformity.non_uniform_result.is_some() + && !capabilities.contains(needed_caps) + && is_binding_array + { + return Err(ExpressionError::MissingCapabilities(needed_caps)); + } + + Uniformity { + non_uniform_result: self + .add_assignable_ref(base, &mut assignable_global) + .or(self.add_ref(index)), + requirements: UniformityRequirements::empty(), + } + } + E::AccessIndex { base, .. } => Uniformity { + non_uniform_result: self.add_assignable_ref(base, &mut assignable_global), + requirements: UniformityRequirements::empty(), + }, + // always uniform + E::Splat { size: _, value } => Uniformity { + non_uniform_result: self.add_ref(value), + requirements: UniformityRequirements::empty(), + }, + E::Swizzle { vector, .. } => Uniformity { + non_uniform_result: self.add_ref(vector), + requirements: UniformityRequirements::empty(), + }, + E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Compose { ref components, .. 
} => { + let non_uniform_result = components + .iter() + .fold(None, |nur, &comp| nur.or(self.add_ref(comp))); + Uniformity { + non_uniform_result, + requirements: UniformityRequirements::empty(), + } + } + // depends on the builtin or interpolation + E::FunctionArgument(index) => { + let arg = &resolve_context.arguments[index as usize]; + let uniform = match arg.binding { + Some(crate::Binding::BuiltIn(built_in)) => match built_in { + // per-polygon built-ins are uniform + crate::BuiltIn::FrontFacing + // per-work-group built-ins are uniform + | crate::BuiltIn::WorkGroupId + | crate::BuiltIn::WorkGroupSize + | crate::BuiltIn::NumWorkGroups => true, + _ => false, + }, + // only flat inputs are uniform + Some(crate::Binding::Location { + interpolation: Some(crate::Interpolation::Flat), + .. + }) => true, + _ => false, + }; + Uniformity { + non_uniform_result: if uniform { None } else { Some(handle) }, + requirements: UniformityRequirements::empty(), + } + } + // depends on the address space + E::GlobalVariable(gh) => { + use crate::AddressSpace as As; + assignable_global = Some(gh); + let var = &resolve_context.global_vars[gh]; + let uniform = match var.space { + // local data is non-uniform + As::Function | As::Private => false, + // workgroup memory is exclusively accessed by the group + As::WorkGroup => true, + // uniform data + As::Uniform | As::PushConstant => true, + // storage data is only uniform when read-only + As::Storage { access } => !access.contains(crate::StorageAccess::STORE), + As::Handle => false, + }; + Uniformity { + non_uniform_result: if uniform { None } else { Some(handle) }, + requirements: UniformityRequirements::empty(), + } + } + E::LocalVariable(_) => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::Load { pointer } => Uniformity { + non_uniform_result: self.add_ref(pointer), + requirements: UniformityRequirements::empty(), + }, + E::ImageSample { + image, + sampler, + gather: _, 
+ coordinate, + array_index, + offset: _, + level, + depth_ref, + } => { + let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?; + let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?; + + match (image_storage, sampler_storage) { + (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => { + self.sampling_set.insert(SamplingKey { image, sampler }); + } + _ => { + self.sampling.insert(Sampling { + image: image_storage, + sampler: sampler_storage, + }); + } + } + + // "nur" == "Non-Uniform Result" + let array_nur = array_index.and_then(|h| self.add_ref(h)); + let level_nur = match level { + Sl::Auto | Sl::Zero => None, + Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h), + Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)), + }; + let dref_nur = depth_ref.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self + .add_ref(image) + .or(self.add_ref(sampler)) + .or(self.add_ref(coordinate)) + .or(array_nur) + .or(level_nur) + .or(dref_nur), + requirements: if level.implicit_derivatives() { + UniformityRequirements::IMPLICIT_LEVEL + } else { + UniformityRequirements::empty() + }, + } + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + let array_nur = array_index.and_then(|h| self.add_ref(h)); + let sample_nur = sample.and_then(|h| self.add_ref(h)); + let level_nur = level.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self + .add_ref(image) + .or(self.add_ref(coordinate)) + .or(array_nur) + .or(sample_nur) + .or(level_nur), + requirements: UniformityRequirements::empty(), + } + } + E::ImageQuery { image, query } => { + let query_nur = match query { + crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h), + _ => None, + }; + Uniformity { + non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur), + requirements: UniformityRequirements::empty(), + } + } + E::Unary { expr, .. 
} => Uniformity { + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::empty(), + }, + E::Binary { left, right, .. } => Uniformity { + non_uniform_result: self.add_ref(left).or(self.add_ref(right)), + requirements: UniformityRequirements::empty(), + }, + E::Select { + condition, + accept, + reject, + } => Uniformity { + non_uniform_result: self + .add_ref(condition) + .or(self.add_ref(accept)) + .or(self.add_ref(reject)), + requirements: UniformityRequirements::empty(), + }, + // explicit derivatives require uniform + E::Derivative { expr, .. } => Uniformity { + //Note: taking a derivative of a uniform doesn't make it non-uniform + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::DERIVATIVE, + }, + E::Relational { argument, .. } => Uniformity { + non_uniform_result: self.add_ref(argument), + requirements: UniformityRequirements::empty(), + }, + E::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + let arg1_nur = arg1.and_then(|h| self.add_ref(h)); + let arg2_nur = arg2.and_then(|h| self.add_ref(h)); + let arg3_nur = arg3.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur).or(arg3_nur), + requirements: UniformityRequirements::empty(), + } + } + E::As { expr, .. } => Uniformity { + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::empty(), + }, + E::CallResult(function) => other_functions[function.index()].uniformity.clone(), + E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::WorkGroupUniformLoadResult { .. } => Uniformity { + // The result of WorkGroupUniformLoad is always uniform by definition + non_uniform_result: None, + // The call is what cares about uniformity, not the expression + // This expression is never emitted, so this requirement should never be used anyway? 
+ requirements: UniformityRequirements::empty(), + }, + E::ArrayLength(expr) => Uniformity { + non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY), + requirements: UniformityRequirements::empty(), + }, + E::RayQueryGetIntersection { + query, + committed: _, + } => Uniformity { + non_uniform_result: self.add_ref(query), + requirements: UniformityRequirements::empty(), + }, + }; + + let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; + self.expressions[handle.index()] = ExpressionInfo { + uniformity, + ref_count: 0, + assignable_global, + ty, + }; + Ok(()) + } + + /// Analyzes the uniformity requirements of a block (as a sequence of statements). + /// Returns the uniformity characteristics at the *function* level, i.e. + /// whether or not the function requires to be called in uniform control flow, + /// and whether the produced result is not disrupting the control flow. + /// + /// The parent control flow is uniform if `disruptor.is_none()`. + /// + /// Returns a `NonUniformControlFlow` error if any of the expressions in the block + /// require uniformity, but the current flow is non-uniform. 
+ #[allow(clippy::or_fun_call)] + fn process_block( + &mut self, + statements: &crate::Block, + other_functions: &[FunctionInfo], + mut disruptor: Option<UniformityDisruptor>, + expression_arena: &Arena<crate::Expression>, + ) -> Result<FunctionUniformity, WithSpan<FunctionError>> { + use crate::Statement as S; + + let mut combined_uniformity = FunctionUniformity::new(); + for statement in statements { + let uniformity = match *statement { + S::Emit(ref range) => { + let mut requirements = UniformityRequirements::empty(); + for expr in range.clone() { + let req = self.expressions[expr.index()].uniformity.requirements; + if self + .flags + .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY) + && !req.is_empty() + { + if let Some(cause) = disruptor { + return Err(FunctionError::NonUniformControlFlow(req, expr, cause) + .with_span_handle(expr, expression_arena)); + } + } + requirements |= req; + } + FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements, + }, + exit: ExitFlags::empty(), + } + } + S::Break | S::Continue => FunctionUniformity::new(), + S::Kill => FunctionUniformity { + result: Uniformity::new(), + exit: if disruptor.is_some() { + ExitFlags::MAY_KILL + } else { + ExitFlags::empty() + }, + }, + S::Barrier(_) => FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::WORK_GROUP_BARRIER, + }, + exit: ExitFlags::empty(), + }, + S::WorkGroupUniformLoad { pointer, .. } => { + let _condition_nur = self.add_ref(pointer); + + // Don't check that this call occurs in uniform control flow until Naga implements WGSL's standard + // uniformity analysis (https://github.com/gfx-rs/naga/issues/1744). + // The uniformity analysis Naga uses now is less accurate than the one in the WGSL standard, + // causing Naga to reject correct uses of `workgroupUniformLoad` in some interesting programs. 
+ + /* + if self + .flags + .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY) + { + let condition_nur = self.add_ref(pointer); + let this_disruptor = + disruptor.or(condition_nur.map(UniformityDisruptor::Expression)); + if let Some(cause) = this_disruptor { + return Err(FunctionError::NonUniformWorkgroupUniformLoad(cause) + .with_span_static(*span, "WorkGroupUniformLoad")); + } + } */ + FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::WORK_GROUP_BARRIER, + }, + exit: ExitFlags::empty(), + } + } + S::Block(ref b) => { + self.process_block(b, other_functions, disruptor, expression_arena)? + } + S::If { + condition, + ref accept, + ref reject, + } => { + let condition_nur = self.add_ref(condition); + let branch_disruptor = + disruptor.or(condition_nur.map(UniformityDisruptor::Expression)); + let accept_uniformity = self.process_block( + accept, + other_functions, + branch_disruptor, + expression_arena, + )?; + let reject_uniformity = self.process_block( + reject, + other_functions, + branch_disruptor, + expression_arena, + )?; + accept_uniformity | reject_uniformity + } + S::Switch { + selector, + ref cases, + } => { + let selector_nur = self.add_ref(selector); + let branch_disruptor = + disruptor.or(selector_nur.map(UniformityDisruptor::Expression)); + let mut uniformity = FunctionUniformity::new(); + let mut case_disruptor = branch_disruptor; + for case in cases.iter() { + let case_uniformity = self.process_block( + &case.body, + other_functions, + case_disruptor, + expression_arena, + )?; + case_disruptor = if case.fall_through { + case_disruptor.or(case_uniformity.exit_disruptor()) + } else { + branch_disruptor + }; + uniformity = uniformity | case_uniformity; + } + uniformity + } + S::Loop { + ref body, + ref continuing, + break_if, + } => { + let body_uniformity = + self.process_block(body, other_functions, disruptor, expression_arena)?; + let continuing_disruptor = 
disruptor.or(body_uniformity.exit_disruptor()); + let continuing_uniformity = self.process_block( + continuing, + other_functions, + continuing_disruptor, + expression_arena, + )?; + if let Some(expr) = break_if { + let _ = self.add_ref(expr); + } + body_uniformity | continuing_uniformity + } + S::Return { value } => FunctionUniformity { + result: Uniformity { + non_uniform_result: value.and_then(|expr| self.add_ref(expr)), + requirements: UniformityRequirements::empty(), + }, + exit: if disruptor.is_some() { + ExitFlags::MAY_RETURN + } else { + ExitFlags::empty() + }, + }, + // Here and below, the used expressions are already emitted, + // and their results do not affect the function return value, + // so we can ignore their non-uniformity. + S::Store { pointer, value } => { + let _ = self.add_ref_impl(pointer, GlobalUse::WRITE); + let _ = self.add_ref(value); + FunctionUniformity::new() + } + S::ImageStore { + image, + coordinate, + array_index, + value, + } => { + let _ = self.add_ref_impl(image, GlobalUse::WRITE); + if let Some(expr) = array_index { + let _ = self.add_ref(expr); + } + let _ = self.add_ref(coordinate); + let _ = self.add_ref(value); + FunctionUniformity::new() + } + S::Call { + function, + ref arguments, + result: _, + } => { + for &argument in arguments { + let _ = self.add_ref(argument); + } + let info = &other_functions[function.index()]; + //Note: the result is validated by the Validator, not here + self.process_call(info, arguments, expression_arena)? 
+ } + S::Atomic { + pointer, + ref fun, + value, + result: _, + } => { + let _ = self.add_ref_impl(pointer, GlobalUse::WRITE); + let _ = self.add_ref(value); + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + let _ = self.add_ref(cmp); + } + FunctionUniformity::new() + } + S::RayQuery { query, ref fun } => { + let _ = self.add_ref(query); + if let crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } = *fun + { + let _ = self.add_ref(acceleration_structure); + let _ = self.add_ref(descriptor); + } + FunctionUniformity::new() + } + }; + + disruptor = disruptor.or(uniformity.exit_disruptor()); + combined_uniformity = combined_uniformity | uniformity; + } + Ok(combined_uniformity) + } +} + +impl ModuleInfo { + /// Populates `self.const_expression_types` + pub(super) fn process_const_expression( + &mut self, + handle: Handle<crate::Expression>, + resolve_context: &ResolveContext, + gctx: crate::proc::GlobalCtx, + ) -> Result<(), super::ConstExpressionError> { + self.const_expression_types[handle.index()] = + resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?; + Ok(()) + } + + /// Builds the `FunctionInfo` based on the function, and validates the + /// uniform control flow if required by the expressions of this function. 
+ pub(super) fn process_function( + &self, + fun: &crate::Function, + module: &crate::Module, + flags: ValidationFlags, + capabilities: super::Capabilities, + ) -> Result<FunctionInfo, WithSpan<FunctionError>> { + let mut info = FunctionInfo { + flags, + available_stages: ShaderStages::all(), + uniformity: Uniformity::new(), + may_kill: false, + sampling_set: crate::FastHashSet::default(), + global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(), + expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(), + sampling: crate::FastHashSet::default(), + dual_source_blending: false, + }; + let resolve_context = + ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); + + for (handle, _) in fun.expressions.iter() { + if let Err(source) = info.process_expression( + handle, + &fun.expressions, + &self.functions, + &resolve_context, + capabilities, + ) { + return Err(FunctionError::Expression { handle, source } + .with_span_handle(handle, &fun.expressions)); + } + } + + for (_, expr) in fun.local_variables.iter() { + if let Some(init) = expr.init { + let _ = info.add_ref(init); + } + } + + let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?; + info.uniformity = uniformity.result; + info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL); + + Ok(info) + } + + pub fn get_entry_point(&self, index: usize) -> &FunctionInfo { + &self.entry_points[index] + } +} + +#[test] +fn uniform_control_flow() { + use crate::{Expression as E, Statement as S}; + + let mut type_arena = crate::UniqueArena::new(); + let ty = type_arena.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + scalar: crate::Scalar::F32, + }, + }, + Default::default(), + ); + let mut global_var_arena = Arena::new(); + let non_uniform_global = global_var_arena.append( + crate::GlobalVariable { + name: None, + init: None, + ty, + space: 
crate::AddressSpace::Handle, + binding: None, + }, + Default::default(), + ); + let uniform_global = global_var_arena.append( + crate::GlobalVariable { + name: None, + init: None, + ty, + binding: None, + space: crate::AddressSpace::Uniform, + }, + Default::default(), + ); + + let mut expressions = Arena::new(); + // checks the uniform control flow + let constant_expr = expressions.append(E::Literal(crate::Literal::U32(0)), Default::default()); + // checks the non-uniform control flow + let derivative_expr = expressions.append( + E::Derivative { + axis: crate::DerivativeAxis::X, + ctrl: crate::DerivativeControl::None, + expr: constant_expr, + }, + Default::default(), + ); + let emit_range_constant_derivative = expressions.range_from(0); + let non_uniform_global_expr = + expressions.append(E::GlobalVariable(non_uniform_global), Default::default()); + let uniform_global_expr = + expressions.append(E::GlobalVariable(uniform_global), Default::default()); + let emit_range_globals = expressions.range_from(2); + + // checks the QUERY flag + let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default()); + // checks the transitive WRITE flag + let access_expr = expressions.append( + E::AccessIndex { + base: non_uniform_global_expr, + index: 1, + }, + Default::default(), + ); + let emit_range_query_access_globals = expressions.range_from(2); + + let mut info = FunctionInfo { + flags: ValidationFlags::all(), + available_stages: ShaderStages::all(), + uniformity: Uniformity::new(), + may_kill: false, + sampling_set: crate::FastHashSet::default(), + global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(), + expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(), + sampling: crate::FastHashSet::default(), + dual_source_blending: false, + }; + let resolve_context = ResolveContext { + constants: &Arena::new(), + types: &type_arena, + special_types: &crate::SpecialTypes::default(), + global_vars: 
&global_var_arena, + local_vars: &Arena::new(), + functions: &Arena::new(), + arguments: &[], + }; + for (handle, _) in expressions.iter() { + info.process_expression( + handle, + &expressions, + &[], + &resolve_context, + super::Capabilities::empty(), + ) + .unwrap(); + } + assert_eq!(info[non_uniform_global_expr].ref_count, 1); + assert_eq!(info[uniform_global_expr].ref_count, 1); + assert_eq!(info[query_expr].ref_count, 0); + assert_eq!(info[access_expr].ref_count, 0); + assert_eq!(info[non_uniform_global], GlobalUse::empty()); + assert_eq!(info[uniform_global], GlobalUse::QUERY); + + let stmt_emit1 = S::Emit(emit_range_globals.clone()); + let stmt_if_uniform = S::If { + condition: uniform_global_expr, + accept: crate::Block::new(), + reject: vec![ + S::Emit(emit_range_constant_derivative.clone()), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ] + .into(), + }; + assert_eq!( + info.process_block( + &vec![stmt_emit1, stmt_if_uniform].into(), + &[], + None, + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::DERIVATIVE, + }, + exit: ExitFlags::empty(), + }), + ); + assert_eq!(info[constant_expr].ref_count, 2); + assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY); + + let stmt_emit2 = S::Emit(emit_range_globals.clone()); + let stmt_if_non_uniform = S::If { + condition: non_uniform_global_expr, + accept: vec![ + S::Emit(emit_range_constant_derivative), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ] + .into(), + reject: crate::Block::new(), + }; + { + let block_info = info.process_block( + &vec![stmt_emit2, stmt_if_non_uniform].into(), + &[], + None, + &expressions, + ); + if DISABLE_UNIFORMITY_REQ_FOR_FRAGMENT_STAGE { + assert_eq!(info[derivative_expr].ref_count, 2); + } else { + assert_eq!( + block_info, + Err(FunctionError::NonUniformControlFlow( + UniformityRequirements::DERIVATIVE, + derivative_expr, + 
UniformityDisruptor::Expression(non_uniform_global_expr) + ) + .with_span()), + ); + assert_eq!(info[derivative_expr].ref_count, 1); + } + } + assert_eq!(info[non_uniform_global], GlobalUse::READ); + + let stmt_emit3 = S::Emit(emit_range_globals); + let stmt_return_non_uniform = S::Return { + value: Some(non_uniform_global_expr), + }; + assert_eq!( + info.process_block( + &vec![stmt_emit3, stmt_return_non_uniform].into(), + &[], + Some(UniformityDisruptor::Return), + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: Some(non_uniform_global_expr), + requirements: UniformityRequirements::empty(), + }, + exit: ExitFlags::MAY_RETURN, + }), + ); + assert_eq!(info[non_uniform_global_expr].ref_count, 3); + + // Check that uniformity requirements reach through a pointer + let stmt_emit4 = S::Emit(emit_range_query_access_globals); + let stmt_assign = S::Store { + pointer: access_expr, + value: query_expr, + }; + let stmt_return_pointer = S::Return { + value: Some(access_expr), + }; + let stmt_kill = S::Kill; + assert_eq!( + info.process_block( + &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(), + &[], + Some(UniformityDisruptor::Discard), + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: Some(non_uniform_global_expr), + requirements: UniformityRequirements::empty(), + }, + exit: ExitFlags::all(), + }), + ); + assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE); +} diff --git a/third_party/rust/naga/src/valid/compose.rs b/third_party/rust/naga/src/valid/compose.rs new file mode 100644 index 0000000000..c21e98c6f2 --- /dev/null +++ b/third_party/rust/naga/src/valid/compose.rs @@ -0,0 +1,128 @@ +use crate::proc::TypeResolution; + +use crate::arena::Handle; + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ComposeError { + #[error("Composing of type {0:?} can't be done")] + Type(Handle<crate::Type>), + 
#[error("Composing expects {expected} components but {given} were given")] + ComponentCount { given: u32, expected: u32 }, + #[error("Composing {index}'s component type is not expected")] + ComponentType { index: u32 }, +} + +pub fn validate_compose( + self_ty_handle: Handle<crate::Type>, + gctx: crate::proc::GlobalCtx, + component_resolutions: impl ExactSizeIterator<Item = TypeResolution>, +) -> Result<(), ComposeError> { + use crate::TypeInner as Ti; + + match gctx.types[self_ty_handle].inner { + // vectors are composed from scalars or other vectors + Ti::Vector { size, scalar } => { + let mut total = 0; + for (index, comp_res) in component_resolutions.enumerate() { + total += match *comp_res.inner_with(gctx.types) { + Ti::Scalar(comp_scalar) if comp_scalar == scalar => 1, + Ti::Vector { + size: comp_size, + scalar: comp_scalar, + } if comp_scalar == scalar => comp_size as u32, + ref other => { + log::error!( + "Vector component[{}] type {:?}, building {:?}", + index, + other, + scalar + ); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + }; + } + if size as u32 != total { + return Err(ComposeError::ComponentCount { + expected: size as u32, + given: total, + }); + } + } + // matrix are composed from column vectors + Ti::Matrix { + columns, + rows, + scalar, + } => { + let inner = Ti::Vector { size: rows, scalar }; + if columns as usize != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + expected: columns as u32, + given: component_resolutions.len() as u32, + }); + } + for (index, comp_res) in component_resolutions.enumerate() { + if comp_res.inner_with(gctx.types) != &inner { + log::error!("Matrix component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + Ti::Array { + base, + size: crate::ArraySize::Constant(count), + stride: _, + } => { + if count.get() as usize != component_resolutions.len() { + return Err(ComposeError::ComponentCount 
{ + expected: count.get(), + given: component_resolutions.len() as u32, + }); + } + for (index, comp_res) in component_resolutions.enumerate() { + let base_inner = &gctx.types[base].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); + // We don't support arrays of pointers, but it seems best not to + // embed that assumption here, so use `TypeInner::equivalent`. + if !base_inner.equivalent(comp_res_inner, gctx.types) { + log::error!("Array component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + Ti::Struct { ref members, .. } => { + if members.len() != component_resolutions.len() { + return Err(ComposeError::ComponentCount { + given: component_resolutions.len() as u32, + expected: members.len() as u32, + }); + } + for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate() + { + let member_inner = &gctx.types[member.ty].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); + // We don't support pointers in structs, but it seems best not to embed + // that assumption here, so use `TypeInner::equivalent`. 
+ if !comp_res_inner.equivalent(member_inner, gctx.types) { + log::error!("Struct component[{}] type {:?}", index, comp_res); + return Err(ComposeError::ComponentType { + index: index as u32, + }); + } + } + } + ref other => { + log::error!("Composing of {:?}", other); + return Err(ComposeError::Type(self_ty_handle)); + } + } + + Ok(()) +} diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs new file mode 100644 index 0000000000..c82d60f062 --- /dev/null +++ b/third_party/rust/naga/src/valid/expression.rs @@ -0,0 +1,1797 @@ +use super::{ + compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ModuleInfo, + ShaderStages, TypeFlags, +}; +use crate::arena::UniqueArena; + +use crate::{ + arena::Handle, + proc::{IndexableLengthError, ResolveError}, +}; + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum ExpressionError { + #[error("Doesn't exist")] + DoesntExist, + #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] + NotInScope, + #[error("Base type {0:?} is not compatible with this expression")] + InvalidBaseType(Handle<crate::Expression>), + #[error("Accessing with index {0:?} can't be done")] + InvalidIndexType(Handle<crate::Expression>), + #[error("Accessing {0:?} via a negative index is invalid")] + NegativeIndex(Handle<crate::Expression>), + #[error("Accessing index {1} is out of {0:?} bounds")] + IndexOutOfBounds(Handle<crate::Expression>, u32), + #[error("The expression {0:?} may only be indexed by a constant")] + IndexMustBeConstant(Handle<crate::Expression>), + #[error("Function argument {0:?} doesn't exist")] + FunctionArgumentDoesntExist(u32), + #[error("Loading of {0:?} can't be done")] + InvalidPointerType(Handle<crate::Expression>), + #[error("Array length of {0:?} can't be done")] + InvalidArrayType(Handle<crate::Expression>), + #[error("Get intersection of {0:?} can't be done")] + 
InvalidRayQueryType(Handle<crate::Expression>), + #[error("Splatting {0:?} can't be done")] + InvalidSplatType(Handle<crate::Expression>), + #[error("Swizzling {0:?} can't be done")] + InvalidVectorType(Handle<crate::Expression>), + #[error("Swizzle component {0:?} is outside of vector size {1:?}")] + InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize), + #[error(transparent)] + Compose(#[from] super::ComposeError), + #[error(transparent)] + IndexableLength(#[from] IndexableLengthError), + #[error("Operation {0:?} can't work with {1:?}")] + InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>), + #[error("Operation {0:?} can't work with {1:?} and {2:?}")] + InvalidBinaryOperandTypes( + crate::BinaryOperator, + Handle<crate::Expression>, + Handle<crate::Expression>, + ), + #[error("Selecting is not possible")] + InvalidSelectTypes, + #[error("Relational argument {0:?} is not a boolean vector")] + InvalidBooleanVector(Handle<crate::Expression>), + #[error("Relational argument {0:?} is not a float")] + InvalidFloatArgument(Handle<crate::Expression>), + #[error("Type resolution failed")] + Type(#[from] ResolveError), + #[error("Not a global variable")] + ExpectedGlobalVariable, + #[error("Not a global variable or a function argument")] + ExpectedGlobalOrArgument, + #[error("Needs to be an binding array instead of {0:?}")] + ExpectedBindingArrayType(Handle<crate::Type>), + #[error("Needs to be an image instead of {0:?}")] + ExpectedImageType(Handle<crate::Type>), + #[error("Needs to be an image instead of {0:?}")] + ExpectedSamplerType(Handle<crate::Type>), + #[error("Unable to operate on image class {0:?}")] + InvalidImageClass(crate::ImageClass), + #[error("Derivatives can only be taken from scalar and vector floats")] + InvalidDerivative, + #[error("Image array index parameter is misplaced")] + InvalidImageArrayIndex, + #[error("Inappropriate sample or level-of-detail index for texel access")] + InvalidImageOtherIndex, + 
#[error("Image array index type of {0:?} is not an integer scalar")] + InvalidImageArrayIndexType(Handle<crate::Expression>), + #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")] + InvalidImageOtherIndexType(Handle<crate::Expression>), + #[error("Image coordinate type of {1:?} does not match dimension {0:?}")] + InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>), + #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")] + ComparisonSamplingMismatch { + image: crate::ImageClass, + sampler: bool, + has_ref: bool, + }, + #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] + InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>), + #[error("Depth reference {0:?} is not a scalar float")] + InvalidDepthReference(Handle<crate::Expression>), + #[error("Depth sample level can only be Auto or Zero")] + InvalidDepthSampleLevel, + #[error("Gather level can only be Zero")] + InvalidGatherLevel, + #[error("Gather component {0:?} doesn't exist in the image")] + InvalidGatherComponent(crate::SwizzleComponent), + #[error("Gather can't be done for image dimension {0:?}")] + InvalidGatherDimension(crate::ImageDimension), + #[error("Sample level (exact) type {0:?} is not a scalar float")] + InvalidSampleLevelExactType(Handle<crate::Expression>), + #[error("Sample level (bias) type {0:?} is not a scalar float")] + InvalidSampleLevelBiasType(Handle<crate::Expression>), + #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")] + InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>), + #[error("Unable to cast")] + InvalidCastArgument, + #[error("Invalid argument count for {0:?}")] + WrongArgumentCount(crate::MathFunction), + #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")] + 
InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>), + #[error("Atomic result type can't be {0:?}")] + InvalidAtomicResultType(Handle<crate::Type>), + #[error( + "workgroupUniformLoad result type can't be {0:?}. It can only be a constructible type." + )] + InvalidWorkGroupUniformLoadResultType(Handle<crate::Type>), + #[error("Shader requires capability {0:?}")] + MissingCapabilities(super::Capabilities), + #[error(transparent)] + Literal(#[from] LiteralError), +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ConstExpressionError { + #[error("The expression is not a constant expression")] + NonConst, + #[error(transparent)] + Compose(#[from] super::ComposeError), + #[error("Splatting {0:?} can't be done")] + InvalidSplatType(Handle<crate::Expression>), + #[error("Type resolution failed")] + Type(#[from] ResolveError), + #[error(transparent)] + Literal(#[from] LiteralError), + #[error(transparent)] + Width(#[from] super::r#type::WidthError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum LiteralError { + #[error("Float literal is NaN")] + NaN, + #[error("Float literal is infinite")] + Infinity, + #[error(transparent)] + Width(#[from] super::r#type::WidthError), +} + +struct ExpressionTypeResolver<'a> { + root: Handle<crate::Expression>, + types: &'a UniqueArena<crate::Type>, + info: &'a FunctionInfo, +} + +impl<'a> std::ops::Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> { + type Output = crate::TypeInner; + + #[allow(clippy::panic)] + fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output { + if handle < self.root { + self.info[handle].ty.inner_with(self.types) + } else { + // `Validator::validate_module_handles` should have caught this. 
+ panic!( + "Depends on {:?}, which has not been processed yet", + self.root + ) + } + } +} + +impl super::Validator { + pub(super) fn validate_const_expression( + &self, + handle: Handle<crate::Expression>, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), ConstExpressionError> { + use crate::Expression as E; + + match gctx.const_expressions[handle] { + E::Literal(literal) => { + self.validate_literal(literal)?; + } + E::Constant(_) | E::ZeroValue(_) => {} + E::Compose { ref components, ty } => { + validate_compose( + ty, + gctx, + components.iter().map(|&handle| mod_info[handle].clone()), + )?; + } + E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) { + crate::TypeInner::Scalar { .. } => {} + _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), + }, + _ => return Err(super::ConstExpressionError::NonConst), + } + + Ok(()) + } + + pub(super) fn validate_expression( + &self, + root: Handle<crate::Expression>, + expression: &crate::Expression, + function: &crate::Function, + module: &crate::Module, + info: &FunctionInfo, + mod_info: &ModuleInfo, + ) -> Result<ShaderStages, ExpressionError> { + use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti}; + + let resolver = ExpressionTypeResolver { + root, + types: &module.types, + info, + }; + + let stages = match *expression { + E::Access { base, index } => { + let base_type = &resolver[base]; + // See the documentation for `Expression::Access`. + let dynamic_indexing_restricted = match *base_type { + Ti::Vector { .. } => false, + Ti::Matrix { .. } | Ti::Array { .. } => true, + Ti::Pointer { .. } + | Ti::ValuePointer { size: Some(_), .. } + | Ti::BindingArray { .. } => false, + ref other => { + log::error!("Indexing of {:?}", other); + return Err(ExpressionError::InvalidBaseType(base)); + } + }; + match resolver[index] { + //TODO: only allow one of these + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. 
+ }) => {} + ref other => { + log::error!("Indexing by {:?}", other); + return Err(ExpressionError::InvalidIndexType(index)); + } + } + if dynamic_indexing_restricted + && function.expressions[index].is_dynamic_index(module) + { + return Err(ExpressionError::IndexMustBeConstant(base)); + } + + // If we know both the length and the index, we can do the + // bounds check now. + if let crate::proc::IndexableLength::Known(known_length) = + base_type.indexable_length(module)? + { + match module + .to_ctx() + .eval_expr_to_u32_from(index, &function.expressions) + { + Ok(value) => { + if value >= known_length { + return Err(ExpressionError::IndexOutOfBounds(base, value)); + } + } + Err(crate::proc::U32EvalError::Negative) => { + return Err(ExpressionError::NegativeIndex(base)) + } + Err(crate::proc::U32EvalError::NonConst) => {} + } + } + + ShaderStages::all() + } + E::AccessIndex { base, index } => { + fn resolve_index_limit( + module: &crate::Module, + top: Handle<crate::Expression>, + ty: &crate::TypeInner, + top_level: bool, + ) -> Result<u32, ExpressionError> { + let limit = match *ty { + Ti::Vector { size, .. } + | Ti::ValuePointer { + size: Some(size), .. + } => size as u32, + Ti::Matrix { columns, .. } => columns as u32, + Ti::Array { + size: crate::ArraySize::Constant(len), + .. + } => len.get(), + Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks + Ti::Pointer { base, .. } if top_level => { + resolve_index_limit(module, top, &module.types[base].inner, false)? + } + Ti::Struct { ref members, .. 
} => members.len() as u32, + ref other => { + log::error!("Indexing of {:?}", other); + return Err(ExpressionError::InvalidBaseType(top)); + } + }; + Ok(limit) + } + + let limit = resolve_index_limit(module, base, &resolver[base], true)?; + if index >= limit { + return Err(ExpressionError::IndexOutOfBounds(base, limit)); + } + ShaderStages::all() + } + E::Splat { size: _, value } => match resolver[value] { + Ti::Scalar { .. } => ShaderStages::all(), + ref other => { + log::error!("Splat scalar type {:?}", other); + return Err(ExpressionError::InvalidSplatType(value)); + } + }, + E::Swizzle { + size, + vector, + pattern, + } => { + let vec_size = match resolver[vector] { + Ti::Vector { size: vec_size, .. } => vec_size, + ref other => { + log::error!("Swizzle vector type {:?}", other); + return Err(ExpressionError::InvalidVectorType(vector)); + } + }; + for &sc in pattern[..size as usize].iter() { + if sc as u8 >= vec_size as u8 { + return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size)); + } + } + ShaderStages::all() + } + E::Literal(literal) => { + self.validate_literal(literal)?; + ShaderStages::all() + } + E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Compose { ref components, ty } => { + validate_compose( + ty, + module.to_ctx(), + components.iter().map(|&handle| info[handle].ty.clone()), + )?; + ShaderStages::all() + } + E::FunctionArgument(index) => { + if index >= function.arguments.len() as u32 { + return Err(ExpressionError::FunctionArgumentDoesntExist(index)); + } + ShaderStages::all() + } + E::GlobalVariable(_handle) => ShaderStages::all(), + E::LocalVariable(_handle) => ShaderStages::all(), + E::Load { pointer } => { + match resolver[pointer] { + Ti::Pointer { base, .. } + if self.types[base.index()] + .flags + .contains(TypeFlags::SIZED | TypeFlags::DATA) => {} + Ti::ValuePointer { .. 
} => {} + ref other => { + log::error!("Loading {:?}", other); + return Err(ExpressionError::InvalidPointerType(pointer)); + } + } + ShaderStages::all() + } + E::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + // check the validity of expressions + let image_ty = Self::global_var_ty(module, function, image)?; + let sampler_ty = Self::global_var_ty(module, function, sampler)?; + + let comparison = match module.types[sampler_ty].inner { + Ti::Sampler { comparison } => comparison, + _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)), + }; + + let (class, dim) = match module.types[image_ty].inner { + Ti::Image { + class, + arrayed, + dim, + } => { + // check the array property + if arrayed != array_index.is_some() { + return Err(ExpressionError::InvalidImageArrayIndex); + } + if let Some(expr) = array_index { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) => {} + _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), + } + } + (class, dim) + } + _ => return Err(ExpressionError::ExpectedImageType(image_ty)), + }; + + // check sampling and comparison properties + let image_depth = match class { + crate::ImageClass::Sampled { + kind: crate::ScalarKind::Float, + multi: false, + } => false, + crate::ImageClass::Sampled { + kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint, + multi: false, + } if gather.is_some() => false, + crate::ImageClass::Depth { multi: false } => true, + _ => return Err(ExpressionError::InvalidImageClass(class)), + }; + if comparison != depth_ref.is_some() || (comparison && !image_depth) { + return Err(ExpressionError::ComparisonSamplingMismatch { + image: class, + sampler: comparison, + has_ref: depth_ref.is_some(), + }); + } + + // check texture coordinates type + let num_components = match dim { + crate::ImageDimension::D1 => 1, + crate::ImageDimension::D2 => 2, + crate::ImageDimension::D3 | 
crate::ImageDimension::Cube => 3, + }; + match resolver[coordinate] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + } if size as u32 == num_components => {} + _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)), + } + + // check constant offset + if let Some(const_expr) = offset { + match *mod_info[const_expr].inner_with(&module.types) { + Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: Sc { kind: Sk::Sint, .. }, + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleOffset(dim, const_expr)); + } + } + } + + // check depth reference type + if let Some(expr) = depth_ref { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) => {} + _ => return Err(ExpressionError::InvalidDepthReference(expr)), + } + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => {} + _ => return Err(ExpressionError::InvalidDepthSampleLevel), + } + } + + if let Some(component) = gather { + match dim { + crate::ImageDimension::D2 | crate::ImageDimension::Cube => {} + crate::ImageDimension::D1 | crate::ImageDimension::D3 => { + return Err(ExpressionError::InvalidGatherDimension(dim)) + } + }; + let max_component = match class { + crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X, + _ => crate::SwizzleComponent::W, + }; + if component > max_component { + return Err(ExpressionError::InvalidGatherComponent(component)); + } + match level { + crate::SampleLevel::Zero => {} + _ => return Err(ExpressionError::InvalidGatherLevel), + } + } + + // check level properties + match level { + crate::SampleLevel::Auto => ShaderStages::FRAGMENT, + crate::SampleLevel::Zero => ShaderStages::all(), + crate::SampleLevel::Exact(expr) => { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. 
+ }) => {} + _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)), + } + ShaderStages::all() + } + crate::SampleLevel::Bias(expr) => { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) => {} + _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)), + } + ShaderStages::FRAGMENT + } + crate::SampleLevel::Gradient { x, y } => { + match resolver[x] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) + } + } + match resolver[y] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) if num_components == 1 => {} + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + } if size as u32 == num_components => {} + _ => { + return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y)) + } + } + ShaderStages::all() + } + } + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + let ty = Self::global_var_ty(module, function, image)?; + match module.types[ty].inner { + Ti::Image { + class, + arrayed, + dim, + } => { + match resolver[coordinate].image_storage_coordinates() { + Some(coord_dim) if coord_dim == dim => {} + _ => { + return Err(ExpressionError::InvalidImageCoordinateType( + dim, coordinate, + )) + } + }; + if arrayed != array_index.is_some() { + return Err(ExpressionError::InvalidImageArrayIndex); + } + if let Some(expr) = array_index { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + width: _, + }) => {} + _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)), + } + } + + match (sample, class.is_multisampled()) { + (None, false) => {} + (Some(sample), true) => { + if resolver[sample].scalar_kind() != Some(Sk::Sint) { + return Err(ExpressionError::InvalidImageOtherIndexType( + sample, + )); + } + } + _ => { + return 
Err(ExpressionError::InvalidImageOtherIndex); + } + } + + match (level, class.is_mipmapped()) { + (None, false) => {} + (Some(level), true) => { + if resolver[level].scalar_kind() != Some(Sk::Sint) { + return Err(ExpressionError::InvalidImageOtherIndexType(level)); + } + } + _ => { + return Err(ExpressionError::InvalidImageOtherIndex); + } + } + } + _ => return Err(ExpressionError::ExpectedImageType(ty)), + } + ShaderStages::all() + } + E::ImageQuery { image, query } => { + let ty = Self::global_var_ty(module, function, image)?; + match module.types[ty].inner { + Ti::Image { class, arrayed, .. } => { + let good = match query { + crate::ImageQuery::NumLayers => arrayed, + crate::ImageQuery::Size { level: None } => true, + crate::ImageQuery::Size { level: Some(_) } + | crate::ImageQuery::NumLevels => class.is_mipmapped(), + crate::ImageQuery::NumSamples => class.is_multisampled(), + }; + if !good { + return Err(ExpressionError::InvalidImageClass(class)); + } + } + _ => return Err(ExpressionError::ExpectedImageType(ty)), + } + ShaderStages::all() + } + E::Unary { op, expr } => { + use crate::UnaryOperator as Uo; + let inner = &resolver[expr]; + match (op, inner.scalar_kind()) { + (Uo::Negate, Some(Sk::Float | Sk::Sint)) + | (Uo::LogicalNot, Some(Sk::Bool)) + | (Uo::BitwiseNot, Some(Sk::Sint | Sk::Uint)) => {} + other => { + log::error!("Op {:?} kind {:?}", op, other); + return Err(ExpressionError::InvalidUnaryOperandType(op, expr)); + } + } + ShaderStages::all() + } + E::Binary { op, left, right } => { + use crate::BinaryOperator as Bo; + let left_inner = &resolver[left]; + let right_inner = &resolver[right]; + let good = match op { + Bo::Add | Bo::Subtract => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + Ti::Matrix { .. 
} => left_inner == right_inner, + _ => false, + }, + Bo::Divide | Bo::Modulo => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + _ => false, + }, + Bo::Multiply => { + let kind_allowed = match left_inner.scalar_kind() { + Some(Sk::Uint | Sk::Sint | Sk::Float) => true, + Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false, + }; + let types_match = match (left_inner, right_inner) { + // Straight scalar and mixed scalar/vector. + (&Ti::Scalar(scalar1), &Ti::Scalar(scalar2)) + | ( + &Ti::Vector { + scalar: scalar1, .. + }, + &Ti::Scalar(scalar2), + ) + | ( + &Ti::Scalar(scalar1), + &Ti::Vector { + scalar: scalar2, .. + }, + ) => scalar1 == scalar2, + // Scalar/matrix. + ( + &Ti::Scalar(Sc { + kind: Sk::Float, .. + }), + &Ti::Matrix { .. }, + ) + | ( + &Ti::Matrix { .. }, + &Ti::Scalar(Sc { + kind: Sk::Float, .. + }), + ) => true, + // Vector/vector. + ( + &Ti::Vector { + size: size1, + scalar: scalar1, + }, + &Ti::Vector { + size: size2, + scalar: scalar2, + }, + ) => scalar1 == scalar2 && size1 == size2, + // Matrix * vector. + ( + &Ti::Matrix { columns, .. }, + &Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + }, + ) => columns == size, + // Vector * matrix. + ( + &Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Float, .. + }, + }, + &Ti::Matrix { rows, .. }, + ) => size == rows, + (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. 
}) => { + columns == rows + } + _ => false, + }; + let left_width = left_inner.scalar_width().unwrap_or(0); + let right_width = right_inner.scalar_width().unwrap_or(0); + kind_allowed && types_match && left_width == right_width + } + Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner, + Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { + match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + } + } + Bo::LogicalAnd | Bo::LogicalOr => match *left_inner { + Ti::Scalar(Sc { kind: Sk::Bool, .. }) + | Ti::Vector { + scalar: Sc { kind: Sk::Bool, .. }, + .. + } => left_inner == right_inner, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::And | Bo::InclusiveOr => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::ExclusiveOr => match *left_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { + Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, + }, + ref other => { + log::error!("Op {:?} left type {:?}", op, other); + false + } + }, + Bo::ShiftLeft | Bo::ShiftRight => { + let (base_size, base_scalar) = match *left_inner { + Ti::Scalar(scalar) => (Ok(None), scalar), + Ti::Vector { size, scalar } => (Ok(Some(size)), scalar), + ref other => { + log::error!("Op {:?} base type {:?}", op, other); + (Err(()), Sc::BOOL) + } + }; + let shift_size = match *right_inner { + Ti::Scalar(Sc { kind: Sk::Uint, .. 
}) => Ok(None), + Ti::Vector { + size, + scalar: Sc { kind: Sk::Uint, .. }, + } => Ok(Some(size)), + ref other => { + log::error!("Op {:?} shift type {:?}", op, other); + Err(()) + } + }; + match base_scalar.kind { + Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false, + } + } + }; + if !good { + log::error!( + "Left: {:?} of type {:?}", + function.expressions[left], + left_inner + ); + log::error!( + "Right: {:?} of type {:?}", + function.expressions[right], + right_inner + ); + return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right)); + } + ShaderStages::all() + } + E::Select { + condition, + accept, + reject, + } => { + let accept_inner = &resolver[accept]; + let reject_inner = &resolver[reject]; + let condition_good = match resolver[condition] { + Ti::Scalar(Sc { + kind: Sk::Bool, + width: _, + }) => { + // When `condition` is a single boolean, `accept` and + // `reject` can be vectors or scalars. + match *accept_inner { + Ti::Scalar { .. } | Ti::Vector { .. } => true, + _ => false, + } + } + Ti::Vector { + size, + scalar: + Sc { + kind: Sk::Bool, + width: _, + }, + } => match *accept_inner { + Ti::Vector { + size: other_size, .. + } => size == other_size, + _ => false, + }, + _ => false, + }; + if !condition_good || accept_inner != reject_inner { + return Err(ExpressionError::InvalidSelectTypes); + } + ShaderStages::all() + } + E::Derivative { expr, .. } => { + match resolver[expr] { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidDerivative), + } + ShaderStages::FRAGMENT + } + E::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + let argument_inner = &resolver[argument]; + match fun { + Rf::All | Rf::Any => match *argument_inner { + Ti::Vector { + scalar: Sc { kind: Sk::Bool, .. }, + .. 
+ } => {} + ref other => { + log::error!("All/Any of type {:?}", other); + return Err(ExpressionError::InvalidBooleanVector(argument)); + } + }, + Rf::IsNan | Rf::IsInf => match *argument_inner { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } + if scalar.kind == Sk::Float => {} + ref other => { + log::error!("Float test of type {:?}", other); + return Err(ExpressionError::InvalidFloatArgument(argument)); + } + }, + } + ShaderStages::all() + } + E::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + + let resolve = |arg| &resolver[arg]; + let arg_ty = resolve(arg); + let arg1_ty = arg1.map(resolve); + let arg2_ty = arg2.map(resolve); + let arg3_ty = arg3.map(resolve); + match fun { + Mf::Abs => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + let good = match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { + scalar.kind != Sk::Bool + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + } + Mf::Min | Mf::Max => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let good = match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { + scalar.kind != Sk::Bool + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Clamp => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let good = match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. 
} => { + scalar.kind != Sk::Bool + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + if arg2_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + Mf::Saturate + | Mf::Cos + | Mf::Cosh + | Mf::Sin + | Mf::Sinh + | Mf::Tan + | Mf::Tanh + | Mf::Acos + | Mf::Asin + | Mf::Atan + | Mf::Asinh + | Mf::Acosh + | Mf::Atanh + | Mf::Radians + | Mf::Degrees + | Mf::Ceil + | Mf::Floor + | Mf::Round + | Mf::Fract + | Mf::Trunc + | Mf::Exp + | Mf::Exp2 + | Mf::Log + | Mf::Log2 + | Mf::Length + | Mf::Sqrt + | Mf::InverseSqrt => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } + if scalar.kind == Sk::Float => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Sign => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float | Sk::Sint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float | Sk::Sint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. 
} + if scalar.kind == Sk::Float => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Modf | Mf::Frexp => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + if !matches!(*arg_ty, + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } + if scalar.kind == Sk::Float) + { + return Err(ExpressionError::InvalidArgumentType(fun, 1, arg)); + } + } + Mf::Ldexp => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let size0 = match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) => None, + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + size, + } => Some(size), + _ => { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + }; + let good = match *arg1_ty { + Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true, + Ti::Vector { + size, + scalar: Sc { kind: Sk::Sint, .. }, + } if Some(size) == size0 => true, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Dot => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float | Sk::Sint | Sk::Uint, + .. + }, + .. 
+ } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Outer | Mf::Cross | Mf::Reflect => { + let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), None, None) => ty1, + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + } + Mf::Refract => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + + match (arg_ty, arg2_ty) { + ( + &Ti::Vector { + scalar: + Sc { + width: vector_width, + .. + }, + .. + }, + &Ti::Scalar(Sc { + width: scalar_width, + kind: Sk::Float, + }), + ) if vector_width == scalar_width => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + } + Mf::Normalize => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. 
+ } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::FaceForward | Mf::Fma | Mf::SmoothStep => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float, .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float, .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + if arg2_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + Mf::Mix => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let arg_width = match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Float, + width, + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Float, + width, + }, + .. + } => width, + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + }; + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + // the last argument can always be a scalar + match *arg2_ty { + Ti::Scalar(Sc { + kind: Sk::Float, + width, + }) if width == arg_width => {} + _ if arg2_ty == arg_ty => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )); + } + } + } + Mf::Inverse | Mf::Determinant => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + let good = match *arg_ty { + Ti::Matrix { columns, rows, .. 
} => columns == rows, + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); + } + } + Mf::Transpose => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Matrix { .. } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::CountTrailingZeros + | Mf::CountLeadingZeros + | Mf::CountOneBits + | Mf::ReverseBits + | Mf::FindLsb + | Mf::FindMsb => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Sint | Sk::Uint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::InsertBits => { + let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Sint | Sk::Uint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + if arg1_ty != arg_ty { + return Err(ExpressionError::InvalidArgumentType( + fun, + 1, + arg1.unwrap(), + )); + } + match *arg2_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + match *arg3_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. 
}) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg3.unwrap(), + )) + } + } + } + Mf::ExtractBits => { + let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { + (Some(ty1), Some(ty2), None) => (ty1, ty2), + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + match *arg_ty { + Ti::Scalar(Sc { + kind: Sk::Sint | Sk::Uint, + .. + }) + | Ti::Vector { + scalar: + Sc { + kind: Sk::Sint | Sk::Uint, + .. + }, + .. + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + match *arg1_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg1.unwrap(), + )) + } + } + match *arg2_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} + _ => { + return Err(ExpressionError::InvalidArgumentType( + fun, + 2, + arg2.unwrap(), + )) + } + } + } + Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + size: crate::VectorSize::Bi, + scalar: + Sc { + kind: Sk::Float, .. + }, + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Pack4x8snorm | Mf::Pack4x8unorm => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Vector { + size: crate::VectorSize::Quad, + scalar: + Sc { + kind: Sk::Float, .. + }, + } => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + Mf::Unpack2x16float + | Mf::Unpack2x16snorm + | Mf::Unpack2x16unorm + | Mf::Unpack4x8snorm + | Mf::Unpack4x8unorm => { + if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { + return Err(ExpressionError::WrongArgumentCount(fun)); + } + match *arg_ty { + Ti::Scalar(Sc { kind: Sk::Uint, .. 
}) => {} + _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), + } + } + } + ShaderStages::all() + } + E::As { + expr, + kind, + convert, + } => { + let mut base_scalar = match resolver[expr] { + crate::TypeInner::Scalar(scalar) | crate::TypeInner::Vector { scalar, .. } => { + scalar + } + crate::TypeInner::Matrix { scalar, .. } => scalar, + _ => return Err(ExpressionError::InvalidCastArgument), + }; + base_scalar.kind = kind; + if let Some(width) = convert { + base_scalar.width = width; + } + if self.check_width(base_scalar).is_err() { + return Err(ExpressionError::InvalidCastArgument); + } + ShaderStages::all() + } + E::CallResult(function) => mod_info.functions[function.index()].available_stages, + E::AtomicResult { ty, comparison } => { + let scalar_predicate = |ty: &crate::TypeInner| match ty { + &crate::TypeInner::Scalar( + scalar @ Sc { + kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint, + .. + }, + ) => self.check_width(scalar).is_ok(), + _ => false, + }; + let good = match &module.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + &module.types, + members, + scalar_predicate, + ) + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidAtomicResultType(ty)); + } + ShaderStages::all() + } + E::WorkGroupUniformLoadResult { ty } => { + if self.types[ty.index()] + .flags + // Sized | Constructible is exactly the types currently supported by + // WorkGroupUniformLoad + .contains(TypeFlags::SIZED | TypeFlags::CONSTRUCTIBLE) + { + ShaderStages::COMPUTE + } else { + return Err(ExpressionError::InvalidWorkGroupUniformLoadResultType(ty)); + } + } + E::ArrayLength(expr) => match resolver[expr] { + Ti::Pointer { base, .. } => { + let base_ty = &resolver.types[base]; + if let Ti::Array { + size: crate::ArraySize::Dynamic, + .. 
                    } = base_ty.inner
                    {
                        ShaderStages::all()
                    } else {
                        // `arrayLength` only applies to runtime-sized arrays.
                        return Err(ExpressionError::InvalidArrayType(expr));
                    }
                }
                ref other => {
                    log::error!("Array length of {:?}", other);
                    return Err(ExpressionError::InvalidArrayType(expr));
                }
            },
            E::RayQueryProceedResult => ShaderStages::all(),
            E::RayQueryGetIntersection {
                query,
                committed: _,
            } => match resolver[query] {
                // The query operand must be a function-space pointer to a `RayQuery`.
                Ti::Pointer {
                    base,
                    space: crate::AddressSpace::Function,
                } => match resolver.types[base].inner {
                    Ti::RayQuery => ShaderStages::all(),
                    ref other => {
                        log::error!("Intersection result of a pointer to {:?}", other);
                        return Err(ExpressionError::InvalidRayQueryType(query));
                    }
                },
                ref other => {
                    log::error!("Intersection result of {:?}", other);
                    return Err(ExpressionError::InvalidRayQueryType(query));
                }
            },
        };
        Ok(stages)
    }

    /// Resolve the type handle of the global variable underlying `expr`.
    ///
    /// Accepts a `GlobalVariable` or `FunctionArgument` expression directly,
    /// or an `Access`/`AccessIndex` whose base is a global variable of
    /// `BindingArray` type — in that case the binding array's element type is
    /// returned.
    fn global_var_ty(
        module: &crate::Module,
        function: &crate::Function,
        expr: Handle<crate::Expression>,
    ) -> Result<Handle<crate::Type>, ExpressionError> {
        use crate::Expression as Ex;

        match function.expressions[expr] {
            Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
            Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
            Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
                match function.expressions[base] {
                    Ex::GlobalVariable(var_handle) => {
                        let array_ty = module.global_variables[var_handle].ty;

                        match module.types[array_ty].inner {
                            crate::TypeInner::BindingArray { base, .. } => Ok(base),
                            _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
                        }
                    }
                    _ => Err(ExpressionError::ExpectedGlobalVariable),
                }
            }
            _ => Err(ExpressionError::ExpectedGlobalVariable),
        }
    }

    /// Check that `literal` has a scalar width permitted by the enabled
    /// capabilities, and a value that is neither NaN nor infinite.
    pub fn validate_literal(&self, literal: crate::Literal) -> Result<(), LiteralError> {
        self.check_width(literal.scalar())?;
        check_literal_value(literal)?;

        Ok(())
    }
}

/// Reject float literals whose value is NaN or infinite.
///
/// Literals of non-float kinds always pass this check.
pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> {
    let is_nan = match literal {
        crate::Literal::F64(v) => v.is_nan(),
        crate::Literal::F32(v) => v.is_nan(),
        _ => false,
    };
    if is_nan {
        return Err(LiteralError::NaN);
    }

    let is_infinite = match literal {
        crate::Literal::F64(v) => v.is_infinite(),
        crate::Literal::F32(v) => v.is_infinite(),
        _ => false,
    };
    if is_infinite {
        return Err(LiteralError::Infinity);
    }

    Ok(())
}

#[cfg(all(test, feature = "validate"))]
/// Validate a module containing the given expression, expecting an error.
fn validate_with_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    // Build a minimal module: a single function whose body just emits `expr`.
    let mut function = crate::Function::default();
    function.expressions.append(expr, Span::default());
    function.body.push(
        crate::Statement::Emit(function.expressions.range_from(0)),
        Span::default(),
    );

    let mut module = crate::Module::default();
    module.functions.append(function, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);

    validator.validate(&module)
}

#[cfg(all(test, feature = "validate"))]
/// Validate a module containing the given constant expression, expecting an error.
fn validate_with_const_expression(
    expr: crate::Expression,
    caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
    use crate::span::Span;

    // Only the module-level constant-expression arena is populated; there are
    // no functions, so validation exercises the constant path for `expr`.
    let mut module = crate::Module::default();
    module.const_expressions.append(expr, Span::default());

    let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);

    validator.validate(&module)
}

/// Using F64 in a function's expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn f64_runtime_literals() {
    // Without the FLOAT64 capability, an `f64` literal must be rejected.
    let result = validate_with_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: super::ExpressionError::Literal(super::LiteralError::Width(
                    super::r#type::WidthError::MissingCapability {
                        name: "f64",
                        flag: "FLOAT64",
                    }
                ),),
                ..
            },
            ..
        }
    ));

    // With FLOAT64 enabled, the same literal validates.
    let result = validate_with_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default() | super::Capabilities::FLOAT64,
    );
    assert!(result.is_ok());
}

/// Using F64 in a module's constant expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn f64_const_literals() {
    // Without the FLOAT64 capability, an `f64` constant must be rejected.
    let result = validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: super::ConstExpressionError::Literal(super::LiteralError::Width(
                super::r#type::WidthError::MissingCapability {
                    name: "f64",
                    flag: "FLOAT64",
                }
            )),
            ..
        }
    ));

    // With FLOAT64 enabled, the same constant validates.
    let result = validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
        super::Capabilities::default() | super::Capabilities::FLOAT64,
    );
    assert!(result.is_ok());
}

/// Using I64 in a function's expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn i64_runtime_literals() {
    let result = validate_with_expression(
        crate::Expression::Literal(crate::Literal::I64(1729)),
        // There is no capability that enables this.
        super::Capabilities::all(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::Function {
            source: super::FunctionError::Expression {
                source: super::ExpressionError::Literal(super::LiteralError::Width(
                    super::r#type::WidthError::Unsupported64Bit
                ),),
                ..
            },
            ..
        }
    ));
}

/// Using I64 in a module's constant expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn i64_const_literals() {
    let result = validate_with_const_expression(
        crate::Expression::Literal(crate::Literal::I64(1729)),
        // There is no capability that enables this.
        super::Capabilities::all(),
    );
    let error = result.unwrap_err().into_inner();
    assert!(matches!(
        error,
        crate::valid::ValidationError::ConstExpression {
            source: super::ConstExpressionError::Literal(super::LiteralError::Width(
                super::r#type::WidthError::Unsupported64Bit,
            ),),
            ..
        }
    ));
}
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
new file mode 100644
index 0000000000..f0ca22cbda
--- /dev/null
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -0,0 +1,1056 @@
use crate::arena::Handle;
use crate::arena::{Arena, UniqueArena};

use super::validate_atomic_compare_exchange_struct;

use super::{
    analyzer::{UniformityDisruptor, UniformityRequirements},
    ExpressionError, FunctionInfo, ModuleInfo,
};
use crate::span::WithSpan;
use crate::span::{AddSpan as _, MapErrWithSpan as _};

use bit_set::BitSet;

/// Ways that validating a function call can fail.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
    #[error("Argument {index} expression is invalid")]
    Argument {
        index: usize,
        source: ExpressionError,
    },
    #[error("Result expression {0:?} has already been introduced earlier")]
    ResultAlreadyInScope(Handle<crate::Expression>),
    #[error("Result value is invalid")]
    ResultValue(#[source] ExpressionError),
    #[error("Requires {required} arguments, but {seen} are provided")]
    ArgumentCount { required: usize, seen: usize },
    #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
    ArgumentType {
        index: usize,
        required: Handle<crate::Type>,
        seen_expression: Handle<crate::Expression>,
    },
    #[error("The emitted expression doesn't match the call")]
    ExpressionMismatch(Option<Handle<crate::Expression>>),
}

/// Ways that validating an atomic operation can fail.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AtomicError {
    #[error("Pointer {0:?} to atomic is invalid.")]
    InvalidPointer(Handle<crate::Expression>),
    #[error("Operand {0:?} has invalid type.")]
    InvalidOperand(Handle<crate::Expression>),
    #[error("Result type for {0:?} doesn't match the statement")]
    ResultTypeMismatch(Handle<crate::Expression>),
}

/// Ways that validating a local variable declaration can fail.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
    #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
    InvalidType(Handle<crate::Type>),
    #[error("Initializer doesn't match the variable type")]
    InitializerType,
    #[error("Initializer is not const")]
    NonConstInitializer,
}

/// Errors that can occur while validating a function.
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
    #[error("Expression {handle:?} is invalid")]
    Expression {
        handle: Handle<crate::Expression>,
        source: ExpressionError,
    },
    #[error("Expression {0:?} can't be introduced - it's already in scope")]
    ExpressionAlreadyInScope(Handle<crate::Expression>),
    #[error("Local variable {handle:?} '{name}' is invalid")]
    LocalVariable {
        handle: Handle<crate::LocalVariable>,
        name: String,
        source: LocalVariableError,
    },
    #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
    InvalidArgumentType { index: usize, name: String },
    #[error("The function's given return type cannot be returned from functions")]
    NonConstructibleReturnType,
    #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
    InvalidArgumentPointerSpace {
        index: usize,
        name: String,
        space: crate::AddressSpace,
    },
    #[error("There are instructions after `return`/`break`/`continue`")]
    InstructionsAfterReturn,
    #[error("The `break` is used outside of a `loop` or `switch` context")]
    BreakOutsideOfLoopOrSwitch,
    #[error("The `continue` is used outside of a `loop` context")]
    ContinueOutsideOfLoop,
    #[error("The `return` is called within a `continuing` block")]
    InvalidReturnSpot,
    #[error("The `return` value {0:?} does not match the function return value")]
    InvalidReturnType(Option<Handle<crate::Expression>>),
    #[error("The `if` condition {0:?} is not a boolean scalar")]
    InvalidIfType(Handle<crate::Expression>),
    #[error("The `switch` value {0:?} is not an integer scalar")]
    InvalidSwitchType(Handle<crate::Expression>),
    #[error("Multiple `switch` cases for {0:?} are present")]
    ConflictingSwitchCase(crate::SwitchValue),
    #[error("The `switch` contains cases with conflicting types")]
    ConflictingCaseType,
    #[error("The `switch` is missing a `default` case")]
    MissingDefaultCase,
    #[error("Multiple `default` cases are present")]
    MultipleDefaultCases,
    #[error("The last `switch` case contains a `fallthrough`")]
    LastCaseFallTrough,
    #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
    InvalidStorePointer(Handle<crate::Expression>),
    #[error("The value {0:?} can not be stored")]
    InvalidStoreValue(Handle<crate::Expression>),
    #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")]
    InvalidStoreTypes {
        pointer: Handle<crate::Expression>,
        value: Handle<crate::Expression>,
    },
    #[error("Image store parameters are invalid")]
    InvalidImageStore(#[source] ExpressionError),
    #[error("Call to {function:?} is invalid")]
    InvalidCall {
        function: Handle<crate::Function>,
        #[source]
        error: CallError,
    },
    #[error("Atomic operation is invalid")]
    InvalidAtomic(#[from] AtomicError),
    #[error("Ray Query {0:?} is not a local variable")]
    InvalidRayQueryExpression(Handle<crate::Expression>),
    #[error("Acceleration structure {0:?} is not a matching expression")]
    InvalidAccelerationStructure(Handle<crate::Expression>),
    #[error("Ray descriptor {0:?} is not a matching expression")]
    InvalidRayDescriptor(Handle<crate::Expression>),
    #[error("Ray Query {0:?} does not have a matching type")]
    InvalidRayQueryType(Handle<crate::Type>),
    #[error(
        "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
    )]
    NonUniformControlFlow(
        UniformityRequirements,
        Handle<crate::Expression>,
        UniformityDisruptor,
    ),
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
    PipelineInputRegularFunction { name: String },
    #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
    PipelineOutputRegularFunction,
    #[error("Required uniformity for WorkGroupUniformLoad is not fulfilled because of {0:?}")]
    // The actual load statement will be "pointed to" by the span
    NonUniformWorkgroupUniformLoad(UniformityDisruptor),
    // This is only possible with a misbehaving frontend
    #[error("The expression {0:?} for a WorkGroupUniformLoad isn't a WorkgroupUniformLoadResult")]
    WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
    #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
    WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
}

bitflags::bitflags! {
    /// Ways that control flow may legally leave the block being validated.
    #[repr(transparent)]
    #[derive(Clone, Copy)]
    struct ControlFlowAbility: u8 {
        /// The control can return out of this block.
        const RETURN = 0x1;
        /// The control can break.
        const BREAK = 0x2;
        /// The control can continue.
+ const CONTINUE = 0x4;
+ }
+}
+
+/// Outcome of validating one block: which shader stages its statements are
+/// compatible with, and whether control flow always leaves the block early
+/// (`return`/`break`/`continue`/`kill`) before reaching its end.
+struct BlockInfo {
+ stages: super::ShaderStages,
+ finished: bool,
+}
+
+/// Borrowed state needed to validate the statements of one function body.
+struct BlockContext<'a> {
+ // Which of `return`/`break`/`continue` are legal in the block being validated.
+ abilities: ControlFlowAbility,
+ info: &'a FunctionInfo,
+ expressions: &'a Arena<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ global_vars: &'a Arena<crate::GlobalVariable>,
+ functions: &'a Arena<crate::Function>,
+ special_types: &'a crate::SpecialTypes,
+ // Infos of already-validated functions, indexed by function handle;
+ // consulted when validating calls.
+ prev_infos: &'a [FunctionInfo],
+ return_type: Option<Handle<crate::Type>>,
+}
+
+impl<'a> BlockContext<'a> {
+ /// Builds a context for `fun`'s body; only `return` is permitted initially.
+ fn new(
+ fun: &'a crate::Function,
+ module: &'a crate::Module,
+ info: &'a FunctionInfo,
+ prev_infos: &'a [FunctionInfo],
+ ) -> Self {
+ Self {
+ abilities: ControlFlowAbility::RETURN,
+ info,
+ expressions: &fun.expressions,
+ types: &module.types,
+ local_vars: &fun.local_variables,
+ global_vars: &module.global_variables,
+ functions: &module.functions,
+ special_types: &module.special_types,
+ prev_infos,
+ return_type: fun.result.as_ref().map(|fr| fr.ty),
+ }
+ }
+
+ /// Copy of `self` with a different set of control-flow abilities
+ /// (used when descending into loop/switch sub-blocks).
+ const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
+ BlockContext { abilities, ..*self }
+ }
+
+ fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
+ &self.expressions[handle]
+ }
+
+ /// Resolves `handle`'s type, requiring the expression both to exist and to
+ /// have already been brought into scope (present in `valid_expressions`).
+ fn resolve_type_impl(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
+ if handle.index() >= self.expressions.len() {
+ Err(ExpressionError::DoesntExist.with_span())
+ } else if !valid_expressions.contains(handle.index()) {
+ Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+
+ /// Like [`Self::resolve_type_impl`], but wraps any failure as a
+ /// `FunctionError::Expression` for the caller.
+ fn resolve_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
+ self.resolve_type_impl(handle, valid_expressions)
+ .map_err_inner(|source| 
FunctionError::Expression { handle, source }.with_span())
+ }
+
+ /// Resolves `handle`'s type for use as a store/atomic destination.
+ /// Unlike `resolve_type`, this does not consult `valid_expression_set`,
+ /// so the pointer expression need not have been emitted into scope;
+ /// only the handle's range is checked.
+ // NOTE(review): presumably pointer chains used by `Store` are exempt from
+ // the Emit-before-use rule — confirm against naga's expression scoping.
+ fn resolve_pointer_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&crate::TypeInner, FunctionError> {
+ if handle.index() >= self.expressions.len() {
+ Err(FunctionError::Expression {
+ handle,
+ source: ExpressionError::DoesntExist,
+ })
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+}
+
+impl super::Validator {
+ /// Validates a `Statement::Call`: argument count and per-argument type
+ /// equivalence, and that `result` (if any) is the matching `CallResult`
+ /// expression for `function`. On success, returns the callee's available
+ /// shader stages so the caller can intersect them into its own.
+ fn validate_call(
+ &mut self,
+ function: Handle<crate::Function>,
+ arguments: &[Handle<crate::Expression>],
+ result: Option<Handle<crate::Expression>>,
+ context: &BlockContext,
+ ) -> Result<super::ShaderStages, WithSpan<CallError>> {
+ let fun = &context.functions[function];
+ if fun.arguments.len() != arguments.len() {
+ return Err(CallError::ArgumentCount {
+ required: fun.arguments.len(),
+ seen: arguments.len(),
+ }
+ .with_span());
+ }
+ for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
+ let ty = context
+ .resolve_type_impl(expr, &self.valid_expression_set)
+ .map_err_inner(|source| {
+ CallError::Argument { index, source }
+ .with_span_handle(expr, context.expressions)
+ })?;
+ let arg_inner = &context.types[arg.ty].inner;
+ if !ty.equivalent(arg_inner, context.types) {
+ return Err(CallError::ArgumentType {
+ index,
+ required: arg.ty,
+ seen_expression: expr,
+ }
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+
+ if let Some(expr) = result {
+ // The call's result expression enters scope here, not via `Emit`.
+ if self.valid_expression_set.insert(expr.index()) {
+ self.valid_expression_list.push(expr);
+ } else {
+ return Err(CallError::ResultAlreadyInScope(expr)
+ .with_span_handle(expr, context.expressions));
+ }
+ match context.expressions[expr] {
+ crate::Expression::CallResult(callee)
+ if fun.result.is_some() && callee == function => {}
+ _ => {
+ return Err(CallError::ExpressionMismatch(result)
+ .with_span_handle(expr, context.expressions))
+ }
+ }
+ } else if fun.result.is_some() {
+ // Callee returns a value but the call statement discards none — mismatch.
+ return Err(CallError::ExpressionMismatch(result).with_span());
+ }
+
+ 
let callee_info = &context.prev_infos[function.index()];
+ Ok(callee_info.available_stages)
+ }
+
+ /// Brings `handle` into scope. `valid_expression_list` mirrors the set so
+ /// that `validate_block` can roll scope back when a block ends; emitting
+ /// the same expression twice is an error.
+ fn emit_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ if self.valid_expression_set.insert(handle.index()) {
+ self.valid_expression_list.push(handle);
+ Ok(())
+ } else {
+ Err(FunctionError::ExpressionAlreadyInScope(handle)
+ .with_span_handle(handle, context.expressions))
+ }
+ }
+
+ /// Validates a `Statement::Atomic`: `pointer` must point to an
+ /// `Atomic` scalar, `value` (and a compare operand, if any) must be a
+ /// scalar of the same kind/width, and `result` must be the matching
+ /// `AtomicResult` expression.
+ fn validate_atomic(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ fun: &crate::AtomicFunction,
+ value: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
+ // Extract the scalar behind the pointer; anything but `ptr<_, atomic<T>>` is invalid.
+ let ptr_scalar = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
+ crate::TypeInner::Atomic(scalar) => scalar,
+ ref other => {
+ log::error!("Atomic pointer to type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ },
+ ref other => {
+ log::error!("Atomic on type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ };
+
+ // The operand must be exactly the pointed-to scalar type.
+ let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_inner {
+ crate::TypeInner::Scalar(scalar) if scalar == ptr_scalar => {}
+ ref other => {
+ log::error!("Atomic operand type {:?}", other);
+ return Err(AtomicError::InvalidOperand(value)
+ .with_span_handle(value, context.expressions)
+ .into_other());
+ }
+ }
+
+ // Compare-exchange: the comparison operand must match the value's type too.
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ if context.resolve_type(cmp, &self.valid_expression_set)? 
!= value_inner { + log::error!("Atomic exchange comparison has a different type from the value"); + return Err(AtomicError::InvalidOperand(cmp) + .with_span_handle(cmp, context.expressions) + .into_other()); + } + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::AtomicResult { ty, comparison } + if { + let scalar_predicate = + |ty: &crate::TypeInner| *ty == crate::TypeInner::Scalar(ptr_scalar); + match &context.types[ty].inner { + ty if !comparison => scalar_predicate(ty), + &crate::TypeInner::Struct { ref members, .. } if comparison => { + validate_atomic_compare_exchange_struct( + context.types, + members, + scalar_predicate, + ) + } + _ => false, + } + } => {} + _ => { + return Err(AtomicError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + + fn validate_block_impl( + &mut self, + statements: &crate::Block, + context: &BlockContext, + ) -> Result<BlockInfo, WithSpan<FunctionError>> { + use crate::{AddressSpace, Statement as S, TypeInner as Ti}; + let mut finished = false; + let mut stages = super::ShaderStages::all(); + for (statement, &span) in statements.span_iter() { + if finished { + return Err(FunctionError::InstructionsAfterReturn + .with_span_static(span, "instructions after return")); + } + match *statement { + S::Emit(ref range) => { + for handle in range.clone() { + self.emit_expression(handle, context)?; + } + } + S::Block(ref block) => { + let info = self.validate_block(block, context)?; + stages &= info.stages; + finished = info.finished; + } + S::If { + condition, + ref accept, + ref reject, + } => { + match *context.resolve_type(condition, &self.valid_expression_set)? 
{ + Ti::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + }) => {} + _ => { + return Err(FunctionError::InvalidIfType(condition) + .with_span_handle(condition, context.expressions)) + } + } + stages &= self.validate_block(accept, context)?.stages; + stages &= self.validate_block(reject, context)?.stages; + } + S::Switch { + selector, + ref cases, + } => { + let uint = match context + .resolve_type(selector, &self.valid_expression_set)? + .scalar_kind() + { + Some(crate::ScalarKind::Uint) => true, + Some(crate::ScalarKind::Sint) => false, + _ => { + return Err(FunctionError::InvalidSwitchType(selector) + .with_span_handle(selector, context.expressions)) + } + }; + self.switch_values.clear(); + for case in cases { + match case.value { + crate::SwitchValue::I32(_) if !uint => {} + crate::SwitchValue::U32(_) if uint => {} + crate::SwitchValue::Default => {} + _ => { + return Err(FunctionError::ConflictingCaseType.with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "conflicting switch arm here", + )); + } + }; + if !self.switch_values.insert(case.value) { + return Err(match case.value { + crate::SwitchValue::Default => FunctionError::MultipleDefaultCases + .with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "duplicated switch arm here", + ), + _ => FunctionError::ConflictingSwitchCase(case.value) + .with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "conflicting switch arm here", + ), + }); + } + } + if !self.switch_values.contains(&crate::SwitchValue::Default) { + return Err(FunctionError::MissingDefaultCase + .with_span_static(span, "missing default case")); + } + if let Some(case) = cases.last() { + if case.fall_through { + return Err(FunctionError::LastCaseFallTrough.with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "bad switch arm here", + )); + } + 
} + let pass_through_abilities = context.abilities + & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE); + let sub_context = + context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK); + for case in cases { + stages &= self.validate_block(&case.body, &sub_context)?.stages; + } + } + S::Loop { + ref body, + ref continuing, + break_if, + } => { + // special handling for block scoping is needed here, + // because the continuing{} block inherits the scope + let base_expression_count = self.valid_expression_list.len(); + let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN; + stages &= self + .validate_block_impl( + body, + &context.with_abilities( + pass_through_abilities + | ControlFlowAbility::BREAK + | ControlFlowAbility::CONTINUE, + ), + )? + .stages; + stages &= self + .validate_block_impl( + continuing, + &context.with_abilities(ControlFlowAbility::empty()), + )? + .stages; + + if let Some(condition) = break_if { + match *context.resolve_type(condition, &self.valid_expression_set)? { + Ti::Scalar(crate::Scalar { + kind: crate::ScalarKind::Bool, + width: _, + }) => {} + _ => { + return Err(FunctionError::InvalidIfType(condition) + .with_span_handle(condition, context.expressions)) + } + } + } + + for handle in self.valid_expression_list.drain(base_expression_count..) 
{ + self.valid_expression_set.remove(handle.index()); + } + } + S::Break => { + if !context.abilities.contains(ControlFlowAbility::BREAK) { + return Err(FunctionError::BreakOutsideOfLoopOrSwitch + .with_span_static(span, "invalid break")); + } + finished = true; + } + S::Continue => { + if !context.abilities.contains(ControlFlowAbility::CONTINUE) { + return Err(FunctionError::ContinueOutsideOfLoop + .with_span_static(span, "invalid continue")); + } + finished = true; + } + S::Return { value } => { + if !context.abilities.contains(ControlFlowAbility::RETURN) { + return Err(FunctionError::InvalidReturnSpot + .with_span_static(span, "invalid return")); + } + let value_ty = value + .map(|expr| context.resolve_type(expr, &self.valid_expression_set)) + .transpose()?; + let expected_ty = context.return_type.map(|ty| &context.types[ty].inner); + // We can't return pointers, but it seems best not to embed that + // assumption here, so use `TypeInner::equivalent` for comparison. + let okay = match (value_ty, expected_ty) { + (None, None) => true, + (Some(value_inner), Some(expected_inner)) => { + value_inner.equivalent(expected_inner, context.types) + } + (_, _) => false, + }; + + if !okay { + log::error!( + "Returning {:?} where {:?} is expected", + value_ty, + expected_ty + ); + if let Some(handle) = value { + return Err(FunctionError::InvalidReturnType(value) + .with_span_handle(handle, context.expressions)); + } else { + return Err(FunctionError::InvalidReturnType(value) + .with_span_static(span, "invalid return")); + } + } + finished = true; + } + S::Kill => { + stages &= super::ShaderStages::FRAGMENT; + finished = true; + } + S::Barrier(_) => { + stages &= super::ShaderStages::COMPUTE; + } + S::Store { pointer, value } => { + let mut current = pointer; + loop { + let _ = context + .resolve_pointer_type(current) + .map_err(|e| e.with_span())?; + match context.expressions[current] { + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. 
} => current = base, + crate::Expression::LocalVariable(_) + | crate::Expression::GlobalVariable(_) + | crate::Expression::FunctionArgument(_) => break, + _ => { + return Err(FunctionError::InvalidStorePointer(current) + .with_span_handle(pointer, context.expressions)) + } + } + } + + let value_ty = context.resolve_type(value, &self.valid_expression_set)?; + match *value_ty { + Ti::Image { .. } | Ti::Sampler { .. } => { + return Err(FunctionError::InvalidStoreValue(value) + .with_span_handle(value, context.expressions)); + } + _ => {} + } + + let pointer_ty = context + .resolve_pointer_type(pointer) + .map_err(|e| e.with_span())?; + + let good = match *pointer_ty { + Ti::Pointer { base, space: _ } => match context.types[base].inner { + Ti::Atomic(scalar) => *value_ty == Ti::Scalar(scalar), + ref other => value_ty == other, + }, + Ti::ValuePointer { + size: Some(size), + scalar, + space: _, + } => *value_ty == Ti::Vector { size, scalar }, + Ti::ValuePointer { + size: None, + scalar, + space: _, + } => *value_ty == Ti::Scalar(scalar), + _ => false, + }; + if !good { + return Err(FunctionError::InvalidStoreTypes { pointer, value } + .with_span() + .with_handle(pointer, context.expressions) + .with_handle(value, context.expressions)); + } + + if let Some(space) = pointer_ty.pointer_space() { + if !space.access().contains(crate::StorageAccess::STORE) { + return Err(FunctionError::InvalidStorePointer(pointer) + .with_span_static( + context.expressions.get_span(pointer), + "writing to this location is not permitted", + )); + } + } + } + S::ImageStore { + image, + coordinate, + array_index, + value, + } => { + //Note: this code uses a lot of `FunctionError::InvalidImageStore`, + // and could probably be refactored. + let var = match *context.get_expression(image) { + crate::Expression::GlobalVariable(var_handle) => { + &context.global_vars[var_handle] + } + // We're looking at a binding index situation, so punch through the index and look at the global behind it. 
+ crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => { + match *context.get_expression(base) { + crate::Expression::GlobalVariable(var_handle) => { + &context.global_vars[var_handle] + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedGlobalVariable, + ) + .with_span_handle(image, context.expressions)) + } + } + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedGlobalVariable, + ) + .with_span_handle(image, context.expressions)) + } + }; + + // Punch through a binding array to get the underlying type + let global_ty = match context.types[var.ty].inner { + Ti::BindingArray { base, .. } => &context.types[base].inner, + ref inner => inner, + }; + + let value_ty = match *global_ty { + Ti::Image { + class, + arrayed, + dim, + } => { + match context + .resolve_type(coordinate, &self.valid_expression_set)? + .image_storage_coordinates() + { + Some(coord_dim) if coord_dim == dim => {} + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageCoordinateType( + dim, coordinate, + ), + ) + .with_span_handle(coordinate, context.expressions)); + } + }; + if arrayed != array_index.is_some() { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageArrayIndex, + ) + .with_span_handle(coordinate, context.expressions)); + } + if let Some(expr) = array_index { + match *context.resolve_type(expr, &self.valid_expression_set)? { + Ti::Scalar(crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + width: _, + }) => {} + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageArrayIndexType(expr), + ) + .with_span_handle(expr, context.expressions)); + } + } + } + match class { + crate::ImageClass::Storage { format, .. 
} => { + crate::TypeInner::Vector { + size: crate::VectorSize::Quad, + scalar: crate::Scalar { + kind: format.into(), + width: 4, + }, + } + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::InvalidImageClass(class), + ) + .with_span_handle(image, context.expressions)); + } + } + } + _ => { + return Err(FunctionError::InvalidImageStore( + ExpressionError::ExpectedImageType(var.ty), + ) + .with_span() + .with_handle(var.ty, context.types) + .with_handle(image, context.expressions)) + } + }; + + if *context.resolve_type(value, &self.valid_expression_set)? != value_ty { + return Err(FunctionError::InvalidStoreValue(value) + .with_span_handle(value, context.expressions)); + } + } + S::Call { + function, + ref arguments, + result, + } => match self.validate_call(function, arguments, result, context) { + Ok(callee_stages) => stages &= callee_stages, + Err(error) => { + return Err(error.and_then(|error| { + FunctionError::InvalidCall { function, error } + .with_span_static(span, "invalid function call") + })) + } + }, + S::Atomic { + pointer, + ref fun, + value, + result, + } => { + self.validate_atomic(pointer, fun, value, result, context)?; + } + S::WorkGroupUniformLoad { pointer, result } => { + stages &= super::ShaderStages::COMPUTE; + let pointer_inner = + context.resolve_type(pointer, &self.valid_expression_set)?; + match *pointer_inner { + Ti::Pointer { + space: AddressSpace::WorkGroup, + .. + } => {} + Ti::ValuePointer { + space: AddressSpace::WorkGroup, + .. 
+ } => {} + _ => { + return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer) + .with_span_static(span, "WorkGroupUniformLoad")) + } + } + self.emit_expression(result, context)?; + let ty = match &context.expressions[result] { + &crate::Expression::WorkGroupUniformLoadResult { ty } => ty, + _ => { + return Err(FunctionError::WorkgroupUniformLoadExpressionMismatch( + result, + ) + .with_span_static(span, "WorkGroupUniformLoad")); + } + }; + let expected_pointer_inner = Ti::Pointer { + base: ty, + space: AddressSpace::WorkGroup, + }; + if !expected_pointer_inner.equivalent(pointer_inner, context.types) { + return Err(FunctionError::WorkgroupUniformLoadInvalidPointer(pointer) + .with_span_static(span, "WorkGroupUniformLoad")); + } + } + S::RayQuery { query, ref fun } => { + let query_var = match *context.get_expression(query) { + crate::Expression::LocalVariable(var) => &context.local_vars[var], + ref other => { + log::error!("Unexpected ray query expression {other:?}"); + return Err(FunctionError::InvalidRayQueryExpression(query) + .with_span_static(span, "invalid query expression")); + } + }; + match context.types[query_var.ty].inner { + Ti::RayQuery => {} + ref other => { + log::error!("Unexpected ray query type {other:?}"); + return Err(FunctionError::InvalidRayQueryType(query_var.ty) + .with_span_static(span, "invalid query type")); + } + } + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + match *context + .resolve_type(acceleration_structure, &self.valid_expression_set)? 
+ { + Ti::AccelerationStructure => {} + _ => { + return Err(FunctionError::InvalidAccelerationStructure( + acceleration_structure, + ) + .with_span_static(span, "invalid acceleration structure")) + } + } + let desc_ty_given = + context.resolve_type(descriptor, &self.valid_expression_set)?; + let desc_ty_expected = context + .special_types + .ray_desc + .map(|handle| &context.types[handle].inner); + if Some(desc_ty_given) != desc_ty_expected { + return Err(FunctionError::InvalidRayDescriptor(descriptor) + .with_span_static(span, "invalid ray descriptor")); + } + } + crate::RayQueryFunction::Proceed { result } => { + self.emit_expression(result, context)?; + } + crate::RayQueryFunction::Terminate => {} + } + } + } + } + Ok(BlockInfo { stages, finished }) + } + + fn validate_block( + &mut self, + statements: &crate::Block, + context: &BlockContext, + ) -> Result<BlockInfo, WithSpan<FunctionError>> { + let base_expression_count = self.valid_expression_list.len(); + let info = self.validate_block_impl(statements, context)?; + for handle in self.valid_expression_list.drain(base_expression_count..) 
{
+ self.valid_expression_set.remove(handle.index());
+ }
+ Ok(info)
+ }
+
+ /// Validates one local variable declaration: its type must be valid and
+ /// CONSTRUCTIBLE, and any initializer must be a const expression whose
+ /// type is equivalent to the declared type.
+ fn validate_local_var(
+ &self,
+ var: &crate::LocalVariable,
+ gctx: crate::proc::GlobalCtx,
+ fun_info: &FunctionInfo,
+ expression_constness: &crate::proc::ExpressionConstnessTracker,
+ ) -> Result<(), LocalVariableError> {
+ log::debug!("var {:?}", var);
+ // Look up the per-type info gathered during type validation; a missing
+ // entry means the handle never passed type validation.
+ let type_info = self
+ .types
+ .get(var.ty.index())
+ .ok_or(LocalVariableError::InvalidType(var.ty))?;
+ if !type_info.flags.contains(super::TypeFlags::CONSTRUCTIBLE) {
+ return Err(LocalVariableError::InvalidType(var.ty));
+ }
+
+ if let Some(init) = var.init {
+ let decl_ty = &gctx.types[var.ty].inner;
+ let init_ty = fun_info[init].ty.inner_with(gctx.types);
+ if !decl_ty.equivalent(init_ty, gctx.types) {
+ return Err(LocalVariableError::InitializerType);
+ }
+
+ if !expression_constness.is_const(init) {
+ return Err(LocalVariableError::NonConstInitializer);
+ }
+ }
+
+ Ok(())
+ }
+
+ /// Validates a whole function (or, when `entry_point` is true, an entry
+ /// point's function): locals, argument types/bindings, result type, then
+ /// each expression (if `ValidationFlags::EXPRESSIONS` is set) and the body
+ /// blocks (if `ValidationFlags::BLOCKS` is set). Returns the accumulated
+ /// `FunctionInfo`, with `available_stages` narrowed by everything seen.
+ pub(super) fn validate_function(
+ &mut self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ entry_point: bool,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
+
+ let expression_constness =
+ crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
+
+ for (var_handle, var) in fun.local_variables.iter() {
+ self.validate_local_var(var, module.to_ctx(), &info, &expression_constness)
+ .map_err(|source| {
+ FunctionError::LocalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var.ty, &module.types)
+ .with_handle(var_handle, &fun.local_variables)
+ })?;
+ }
+
+ for (index, argument) in fun.arguments.iter().enumerate() {
+ // Pointer arguments may only live in function-local spaces.
+ match module.types[argument.ty].inner.pointer_space() {
+ Some(crate::AddressSpace::Private | crate::AddressSpace::Function) | None => {}
+ Some(other) => {
+ return Err(FunctionError::InvalidArgumentPointerSpace {
+ index,
+ name: 
argument.name.clone().unwrap_or_default(), + space: other, + } + .with_span_handle(argument.ty, &module.types)) + } + } + // Check for the least informative error last. + if !self.types[argument.ty.index()] + .flags + .contains(super::TypeFlags::ARGUMENT) + { + return Err(FunctionError::InvalidArgumentType { + index, + name: argument.name.clone().unwrap_or_default(), + } + .with_span_handle(argument.ty, &module.types)); + } + + if !entry_point && argument.binding.is_some() { + return Err(FunctionError::PipelineInputRegularFunction { + name: argument.name.clone().unwrap_or_default(), + } + .with_span_handle(argument.ty, &module.types)); + } + } + + if let Some(ref result) = fun.result { + if !self.types[result.ty.index()] + .flags + .contains(super::TypeFlags::CONSTRUCTIBLE) + { + return Err(FunctionError::NonConstructibleReturnType + .with_span_handle(result.ty, &module.types)); + } + + if !entry_point && result.binding.is_some() { + return Err(FunctionError::PipelineOutputRegularFunction + .with_span_handle(result.ty, &module.types)); + } + } + + self.valid_expression_set.clear(); + self.valid_expression_list.clear(); + for (handle, expr) in fun.expressions.iter() { + if expr.needs_pre_emit() { + self.valid_expression_set.insert(handle.index()); + } + if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { + match self.validate_expression(handle, expr, fun, module, &info, mod_info) { + Ok(stages) => info.available_stages &= stages, + Err(source) => { + return Err(FunctionError::Expression { handle, source } + .with_span_handle(handle, &fun.expressions)) + } + } + } + } + + if self.flags.contains(super::ValidationFlags::BLOCKS) { + let stages = self + .validate_block( + &fun.body, + &BlockContext::new(fun, module, &info, &mod_info.functions), + )? 
+ .stages; + info.available_stages &= stages; + } + Ok(info) + } +} diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs new file mode 100644 index 0000000000..e482f293bb --- /dev/null +++ b/third_party/rust/naga/src/valid/handles.rs @@ -0,0 +1,699 @@ +//! Implementation of `Validator::validate_module_handles`. + +use crate::{ + arena::{BadHandle, BadRangeError}, + Handle, +}; + +use crate::{Arena, UniqueArena}; + +use super::ValidationError; + +use std::{convert::TryInto, hash::Hash, num::NonZeroU32}; + +impl super::Validator { + /// Validates that all handles within `module` are: + /// + /// * Valid, in the sense that they contain indices within each arena structure inside the + /// [`crate::Module`] type. + /// * No arena contents contain any items that have forward dependencies; that is, the value + /// associated with a handle only may contain references to handles in the same arena that + /// were constructed before it. + /// + /// By validating the above conditions, we free up subsequent logic to assume that handle + /// accesses are infallible. + /// + /// # Errors + /// + /// Errors returned by this method are intentionally sparse, for simplicity of implementation. + /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this + /// validation pass. + pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { + let &crate::Module { + ref constants, + ref entry_points, + ref functions, + ref global_variables, + ref types, + ref special_types, + ref const_expressions, + } = module; + + // NOTE: Types being first is important. All other forms of validation depend on this. + for (this_handle, ty) in types.iter() { + match ty.inner { + crate::TypeInner::Scalar { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::ValuePointer { .. } + | crate::TypeInner::Atomic { .. } + | crate::TypeInner::Image { .. 
}
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => (),
+ // Composite types: every referenced type must precede this one in the
+ // arena (`check_dep`), ruling out forward references and cycles.
+ crate::TypeInner::Pointer { base, space: _ } => {
+ this_handle.check_dep(base)?;
+ }
+ crate::TypeInner::Array { base, .. }
+ | crate::TypeInner::BindingArray { base, .. } => {
+ this_handle.check_dep(base)?;
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ this_handle.check_dep_iter(members.iter().map(|m| m.ty))?;
+ }
+ }
+ }
+
+ for handle_and_expr in const_expressions.iter() {
+ Self::validate_const_expression_handles(handle_and_expr, constants, types)?;
+ }
+
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+ let validate_const_expr =
+ |handle| Self::validate_expression_handle(handle, const_expressions);
+
+ // Exhaustive destructuring (no `..`) below forces this code to be
+ // revisited whenever a field is added to these IR structs.
+ for (_handle, constant) in constants.iter() {
+ let &crate::Constant {
+ name: _,
+ r#override: _,
+ ty,
+ init,
+ } = constant;
+ validate_type(ty)?;
+ validate_const_expr(init)?;
+ }
+
+ for (_handle, global_variable) in global_variables.iter() {
+ let &crate::GlobalVariable {
+ name: _,
+ space: _,
+ binding: _,
+ ty,
+ init,
+ } = global_variable;
+ validate_type(ty)?;
+ if let Some(init_expr) = init {
+ validate_const_expr(init_expr)?;
+ }
+ }
+
+ // Shared between entry points (`function_handle == None`) and module
+ // functions; checks every handle a function stores.
+ let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
+ let &crate::Function {
+ name: _,
+ ref arguments,
+ ref result,
+ ref local_variables,
+ ref expressions,
+ ref named_expressions,
+ ref body,
+ } = function;
+
+ for arg in arguments.iter() {
+ let &crate::FunctionArgument {
+ name: _,
+ ty,
+ binding: _,
+ } = arg;
+ validate_type(ty)?;
+ }
+
+ if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
+ validate_type(ty)?;
+ }
+
+ for (_handle, local_variable) in local_variables.iter() {
+ let &crate::LocalVariable { name: _, ty, init } = local_variable;
+ validate_type(ty)?;
+ if let Some(init) = init {
+ Self::validate_expression_handle(init, expressions)?;
+ }
+ }
+
+ for handle in 
named_expressions.keys().copied() { + Self::validate_expression_handle(handle, expressions)?; + } + + for handle_and_expr in expressions.iter() { + Self::validate_expression_handles( + handle_and_expr, + constants, + const_expressions, + types, + local_variables, + global_variables, + functions, + function_handle, + )?; + } + + Self::validate_block_handles(body, expressions, functions)?; + + Ok(()) + }; + + for entry_point in entry_points.iter() { + validate_function(None, &entry_point.function)?; + } + + for (function_handle, function) in functions.iter() { + validate_function(Some(function_handle), function)?; + } + + if let Some(ty) = special_types.ray_desc { + validate_type(ty)?; + } + if let Some(ty) = special_types.ray_intersection { + validate_type(ty)?; + } + + Ok(()) + } + + fn validate_type_handle( + handle: Handle<crate::Type>, + types: &UniqueArena<crate::Type>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for_uniq(types).map(|_| ()) + } + + fn validate_constant_handle( + handle: Handle<crate::Constant>, + constants: &Arena<crate::Constant>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(constants).map(|_| ()) + } + + fn validate_expression_handle( + handle: Handle<crate::Expression>, + expressions: &Arena<crate::Expression>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(expressions).map(|_| ()) + } + + fn validate_function_handle( + handle: Handle<crate::Function>, + functions: &Arena<crate::Function>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(functions).map(|_| ()) + } + + fn validate_const_expression_handles( + (handle, expression): (Handle<crate::Expression>, &crate::Expression), + constants: &Arena<crate::Constant>, + types: &UniqueArena<crate::Type>, + ) -> Result<(), InvalidHandleError> { + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_type = |handle| Self::validate_type_handle(handle, types); + + match *expression { + 
crate::Expression::Literal(_) => {} + crate::Expression::Constant(constant) => { + validate_constant(constant)?; + handle.check_dep(constants[constant].init)?; + } + crate::Expression::ZeroValue(ty) => { + validate_type(ty)?; + } + crate::Expression::Compose { ty, ref components } => { + validate_type(ty)?; + handle.check_dep_iter(components.iter().copied())?; + } + _ => {} + } + Ok(()) + } + + #[allow(clippy::too_many_arguments)] + fn validate_expression_handles( + (handle, expression): (Handle<crate::Expression>, &crate::Expression), + constants: &Arena<crate::Constant>, + const_expressions: &Arena<crate::Expression>, + types: &UniqueArena<crate::Type>, + local_variables: &Arena<crate::LocalVariable>, + global_variables: &Arena<crate::GlobalVariable>, + functions: &Arena<crate::Function>, + // The handle of the current function or `None` if it's an entry point + current_function: Option<Handle<crate::Function>>, + ) -> Result<(), InvalidHandleError> { + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_const_expr = + |handle| Self::validate_expression_handle(handle, const_expressions); + let validate_type = |handle| Self::validate_type_handle(handle, types); + + match *expression { + crate::Expression::Access { base, index } => { + handle.check_dep(base)?.check_dep(index)?; + } + crate::Expression::AccessIndex { base, .. } => { + handle.check_dep(base)?; + } + crate::Expression::Splat { value, .. } => { + handle.check_dep(value)?; + } + crate::Expression::Swizzle { vector, .. 
} => { + handle.check_dep(vector)?; + } + crate::Expression::Literal(_) => {} + crate::Expression::Constant(constant) => { + validate_constant(constant)?; + } + crate::Expression::ZeroValue(ty) => { + validate_type(ty)?; + } + crate::Expression::Compose { ty, ref components } => { + validate_type(ty)?; + handle.check_dep_iter(components.iter().copied())?; + } + crate::Expression::FunctionArgument(_arg_idx) => (), + crate::Expression::GlobalVariable(global_variable) => { + global_variable.check_valid_for(global_variables)?; + } + crate::Expression::LocalVariable(local_variable) => { + local_variable.check_valid_for(local_variables)?; + } + crate::Expression::Load { pointer } => { + handle.check_dep(pointer)?; + } + crate::Expression::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + if let Some(offset) = offset { + validate_const_expr(offset)?; + } + + handle + .check_dep(image)? + .check_dep(sampler)? + .check_dep(coordinate)? + .check_dep_opt(array_index)?; + + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => (), + crate::SampleLevel::Exact(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Bias(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Gradient { x, y } => { + handle.check_dep(x)?.check_dep(y)?; + } + }; + + handle.check_dep_opt(depth_ref)?; + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + handle + .check_dep(image)? + .check_dep(coordinate)? + .check_dep_opt(array_index)? + .check_dep_opt(sample)? 
+ .check_dep_opt(level)?; + } + crate::Expression::ImageQuery { image, query } => { + handle.check_dep(image)?; + match query { + crate::ImageQuery::Size { level } => { + handle.check_dep_opt(level)?; + } + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => (), + }; + } + crate::Expression::Unary { + op: _, + expr: operand, + } => { + handle.check_dep(operand)?; + } + crate::Expression::Binary { op: _, left, right } => { + handle.check_dep(left)?.check_dep(right)?; + } + crate::Expression::Select { + condition, + accept, + reject, + } => { + handle + .check_dep(condition)? + .check_dep(accept)? + .check_dep(reject)?; + } + crate::Expression::Derivative { expr: argument, .. } => { + handle.check_dep(argument)?; + } + crate::Expression::Relational { fun: _, argument } => { + handle.check_dep(argument)?; + } + crate::Expression::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + handle + .check_dep(arg)? + .check_dep_opt(arg1)? + .check_dep_opt(arg2)? + .check_dep_opt(arg3)?; + } + crate::Expression::As { + expr: input, + kind: _, + convert: _, + } => { + handle.check_dep(input)?; + } + crate::Expression::CallResult(function) => { + Self::validate_function_handle(function, functions)?; + if let Some(handle) = current_function { + handle.check_dep(function)?; + } + } + crate::Expression::AtomicResult { .. } + | crate::Expression::RayQueryProceedResult + | crate::Expression::WorkGroupUniformLoadResult { .. 
} => (), + crate::Expression::ArrayLength(array) => { + handle.check_dep(array)?; + } + crate::Expression::RayQueryGetIntersection { + query, + committed: _, + } => { + handle.check_dep(query)?; + } + } + Ok(()) + } + + fn validate_block_handles( + block: &crate::Block, + expressions: &Arena<crate::Expression>, + functions: &Arena<crate::Function>, + ) -> Result<(), InvalidHandleError> { + let validate_block = |block| Self::validate_block_handles(block, expressions, functions); + let validate_expr = |handle| Self::validate_expression_handle(handle, expressions); + let validate_expr_opt = |handle_opt| { + if let Some(handle) = handle_opt { + validate_expr(handle)?; + } + Ok(()) + }; + + block.iter().try_for_each(|stmt| match *stmt { + crate::Statement::Emit(ref expr_range) => { + expr_range.check_valid_for(expressions)?; + Ok(()) + } + crate::Statement::Block(ref block) => { + validate_block(block)?; + Ok(()) + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + validate_expr(condition)?; + validate_block(accept)?; + validate_block(reject)?; + Ok(()) + } + crate::Statement::Switch { + selector, + ref cases, + } => { + validate_expr(selector)?; + for &crate::SwitchCase { + value: _, + ref body, + fall_through: _, + } in cases + { + validate_block(body)?; + } + Ok(()) + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + validate_block(body)?; + validate_block(continuing)?; + validate_expr_opt(break_if)?; + Ok(()) + } + crate::Statement::Return { value } => validate_expr_opt(value), + crate::Statement::Store { pointer, value } => { + validate_expr(pointer)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + validate_expr(image)?; + validate_expr(coordinate)?; + validate_expr_opt(array_index)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::Atomic { + pointer, + fun, + value, + result, + } => { + validate_expr(pointer)?; + 
match fun { + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max => (), + crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?, + }; + validate_expr(value)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::WorkGroupUniformLoad { pointer, result } => { + validate_expr(pointer)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::Call { + function, + ref arguments, + result, + } => { + Self::validate_function_handle(function, functions)?; + for arg in arguments.iter().copied() { + validate_expr(arg)?; + } + validate_expr_opt(result)?; + Ok(()) + } + crate::Statement::RayQuery { query, ref fun } => { + validate_expr(query)?; + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + validate_expr(acceleration_structure)?; + validate_expr(descriptor)?; + } + crate::RayQueryFunction::Proceed { result } => { + validate_expr(result)?; + } + crate::RayQueryFunction::Terminate => {} + } + Ok(()) + } + crate::Statement::Break + | crate::Statement::Continue + | crate::Statement::Kill + | crate::Statement::Barrier(_) => Ok(()), + }) + } +} + +impl From<BadHandle> for ValidationError { + fn from(source: BadHandle) -> Self { + Self::InvalidHandle(source.into()) + } +} + +impl From<FwdDepError> for ValidationError { + fn from(source: FwdDepError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +impl From<BadRangeError> for ValidationError { + fn from(source: BadRangeError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum InvalidHandleError { + #[error(transparent)] + BadHandle(#[from] BadHandle), + #[error(transparent)] + ForwardDependency(#[from] FwdDepError), + #[error(transparent)] + BadRange(#[from] BadRangeError), +} + +#[derive(Clone, Debug, 
thiserror::Error)]
+#[error(
+    "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
+     which has not been processed yet"
+)]
+pub struct FwdDepError {
+    // This error is used for many `Handle` types, but there's no point in making this generic, so
+    // we just flatten them all to `Handle<()>` here.
+    subject: Handle<()>,
+    subject_kind: &'static str,
+    depends_on: Handle<()>,
+    depends_on_kind: &'static str,
+}
+
+impl<T> Handle<T> {
+    /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
+    pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
+        arena.check_contains_handle(self)?;
+        Ok(())
+    }
+
+    /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
+    pub(self) fn check_valid_for_uniq(
+        self,
+        arena: &UniqueArena<T>,
+    ) -> Result<(), InvalidHandleError>
+    where
+        T: Eq + Hash,
+    {
+        arena.check_contains_handle(self)?;
+        Ok(())
+    }
+
+    /// Check that `depends_on` was constructed before `self` by comparing handle indices.
+    ///
+    /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
+    /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
+    /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
+    /// recursive definitions of arena-based values in linear time.
+    ///
+    /// # Errors
+    ///
+    /// If `depends_on`'s handle is from the same [`Arena`] as `self`'s, but not constructed earlier
+    /// than `self`'s, this function returns an error.
+ pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> { + if depends_on < self { + Ok(self) + } else { + let erase_handle_type = |handle: Handle<_>| { + Handle::new(NonZeroU32::new((handle.index() + 1).try_into().unwrap()).unwrap()) + }; + Err(FwdDepError { + subject: erase_handle_type(self), + subject_kind: std::any::type_name::<T>(), + depends_on: erase_handle_type(depends_on), + depends_on_kind: std::any::type_name::<T>(), + }) + } + } + + /// Like [`Self::check_dep`], except for [`Option`]al handle values. + pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> { + self.check_dep_iter(depends_on.into_iter()) + } + + /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values. + pub(self) fn check_dep_iter( + self, + depends_on: impl Iterator<Item = Self>, + ) -> Result<Self, FwdDepError> { + for handle in depends_on { + self.check_dep(handle)?; + } + Ok(self) + } +} + +impl<T> crate::arena::Range<T> { + pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> { + arena.check_contains_range(self) + } +} + +#[test] +fn constant_deps() { + use crate::{Constant, Expression, Literal, Span, Type, TypeInner}; + + let nowhere = Span::default(); + + let mut types = UniqueArena::new(); + let mut const_exprs = Arena::new(); + let mut fun_exprs = Arena::new(); + let mut constants = Arena::new(); + + let i32_handle = types.insert( + Type { + name: None, + inner: TypeInner::Scalar(crate::Scalar::I32), + }, + nowhere, + ); + + // Construct a self-referential constant by misusing a handle to + // fun_exprs as a constant initializer. 
+ let fun_expr = fun_exprs.append(Expression::Literal(Literal::I32(42)), nowhere); + let self_referential_const = constants.append( + Constant { + name: None, + r#override: crate::Override::None, + ty: i32_handle, + init: fun_expr, + }, + nowhere, + ); + let _self_referential_expr = + const_exprs.append(Expression::Constant(self_referential_const), nowhere); + + for handle_and_expr in const_exprs.iter() { + assert!(super::Validator::validate_const_expression_handles( + handle_and_expr, + &constants, + &types, + ) + .is_err()); + } +} diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs new file mode 100644 index 0000000000..84c8b09ddb --- /dev/null +++ b/third_party/rust/naga/src/valid/interface.rs @@ -0,0 +1,709 @@ +use super::{ + analyzer::{FunctionInfo, GlobalUse}, + Capabilities, Disalignment, FunctionError, ModuleInfo, +}; +use crate::arena::{Handle, UniqueArena}; + +use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan}; +use bit_set::BitSet; + +const MAX_WORKGROUP_SIZE: u32 = 0x4000; + +#[derive(Clone, Debug, thiserror::Error)] +pub enum GlobalVariableError { + #[error("Usage isn't compatible with address space {0:?}")] + InvalidUsage(crate::AddressSpace), + #[error("Type isn't compatible with address space {0:?}")] + InvalidType(crate::AddressSpace), + #[error("Type flags {seen:?} do not meet the required {required:?}")] + MissingTypeFlags { + required: super::TypeFlags, + seen: super::TypeFlags, + }, + #[error("Capability {0:?} is not supported")] + UnsupportedCapability(Capabilities), + #[error("Binding decoration is missing or not applicable")] + InvalidBinding, + #[error("Alignment requirements for address space {0:?} are not met by {1:?}")] + Alignment( + crate::AddressSpace, + Handle<crate::Type>, + #[source] Disalignment, + ), + #[error("Initializer doesn't match the variable type")] + InitializerType, + #[error("Initializer can't be used with address space {0:?}")] + 
InitializerNotAllowed(crate::AddressSpace), + #[error("Storage address space doesn't support write-only access")] + StorageAddressSpaceWriteOnlyNotSupported, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum VaryingError { + #[error("The type {0:?} does not match the varying")] + InvalidType(Handle<crate::Type>), + #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")] + NotIOShareableType(Handle<crate::Type>), + #[error("Interpolation is not valid")] + InvalidInterpolation, + #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")] + MissingInterpolation, + #[error("Built-in {0:?} is not available at this stage")] + InvalidBuiltInStage(crate::BuiltIn), + #[error("Built-in type for {0:?} is invalid")] + InvalidBuiltInType(crate::BuiltIn), + #[error("Entry point arguments and return values must all have bindings")] + MissingBinding, + #[error("Struct member {0} is missing a binding")] + MemberMissingBinding(u32), + #[error("Multiple bindings at location {location} are present")] + BindingCollision { location: u32 }, + #[error("Built-in {0:?} is present more than once")] + DuplicateBuiltIn(crate::BuiltIn), + #[error("Capability {0:?} is not supported")] + UnsupportedCapability(Capabilities), + #[error("The attribute {0:?} is only valid as an output for stage {1:?}")] + InvalidInputAttributeInStage(&'static str, crate::ShaderStage), + #[error("The attribute {0:?} is not valid for stage {1:?}")] + InvalidAttributeInStage(&'static str, crate::ShaderStage), + #[error( + "The location index {location} cannot be used together with the attribute {attribute:?}" + )] + InvalidLocationAttributeCombination { + location: u32, + attribute: &'static str, + }, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum EntryPointError { + #[error("Multiple conflicting entry points")] + Conflict, + #[error("Vertex shaders must return a `@builtin(position)` output value")] + MissingVertexOutputPosition, + 
#[error("Early depth test is not applicable")] + UnexpectedEarlyDepthTest, + #[error("Workgroup size is not applicable")] + UnexpectedWorkgroupSize, + #[error("Workgroup size is out of range")] + OutOfRangeWorkgroupSize, + #[error("Uses operations forbidden at this stage")] + ForbiddenStageOperations, + #[error("Global variable {0:?} is used incorrectly as {1:?}")] + InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse), + #[error("More than 1 push constant variable is used")] + MoreThanOnePushConstantUsed, + #[error("Bindings for {0:?} conflict with other resource")] + BindingCollision(Handle<crate::GlobalVariable>), + #[error("Argument {0} varying error")] + Argument(u32, #[source] VaryingError), + #[error(transparent)] + Result(#[from] VaryingError), + #[error("Location {location} interpolation of an integer has to be flat")] + InvalidIntegerInterpolation { location: u32 }, + #[error(transparent)] + Function(#[from] FunctionError), + #[error( + "Invalid locations {location_mask:?} are set while dual source blending. Only location 0 may be set." 
+ )] + InvalidLocationsWhileDualSourceBlending { location_mask: BitSet }, +} + +fn storage_usage(access: crate::StorageAccess) -> GlobalUse { + let mut storage_usage = GlobalUse::QUERY; + if access.contains(crate::StorageAccess::LOAD) { + storage_usage |= GlobalUse::READ; + } + if access.contains(crate::StorageAccess::STORE) { + storage_usage |= GlobalUse::WRITE; + } + storage_usage +} + +struct VaryingContext<'a> { + stage: crate::ShaderStage, + output: bool, + second_blend_source: bool, + types: &'a UniqueArena<crate::Type>, + type_info: &'a Vec<super::r#type::TypeInfo>, + location_mask: &'a mut BitSet, + built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>, + capabilities: Capabilities, + flags: super::ValidationFlags, +} + +impl VaryingContext<'_> { + fn validate_impl( + &mut self, + ty: Handle<crate::Type>, + binding: &crate::Binding, + ) -> Result<(), VaryingError> { + use crate::{BuiltIn as Bi, ShaderStage as St, TypeInner as Ti, VectorSize as Vs}; + + let ty_inner = &self.types[ty].inner; + match *binding { + crate::Binding::BuiltIn(built_in) => { + // Ignore the `invariant` field for the sake of duplicate checks, + // but use the original in error messages. + let canonical = if let crate::BuiltIn::Position { .. 
} = built_in { + crate::BuiltIn::Position { invariant: false } + } else { + built_in + }; + + if self.built_ins.contains(&canonical) { + return Err(VaryingError::DuplicateBuiltIn(built_in)); + } + self.built_ins.insert(canonical); + + let required = match built_in { + Bi::ClipDistance => Capabilities::CLIP_DISTANCE, + Bi::CullDistance => Capabilities::CULL_DISTANCE, + Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX, + Bi::ViewIndex => Capabilities::MULTIVIEW, + Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING, + _ => Capabilities::empty(), + }; + if !self.capabilities.contains(required) { + return Err(VaryingError::UnsupportedCapability(required)); + } + + let (visible, type_good) = match built_in { + Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( + self.stage == St::Vertex && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::ClipDistance | Bi::CullDistance => ( + self.stage == St::Vertex && self.output, + match *ty_inner { + Ti::Array { base, .. } => { + self.types[base].inner == Ti::Scalar(crate::Scalar::F32) + } + _ => false, + }, + ), + Bi::PointSize => ( + self.stage == St::Vertex && self.output, + *ty_inner == Ti::Scalar(crate::Scalar::F32), + ), + Bi::PointCoord => ( + self.stage == St::Fragment && !self.output, + *ty_inner + == Ti::Vector { + size: Vs::Bi, + scalar: crate::Scalar::F32, + }, + ), + Bi::Position { .. 
} => ( + match self.stage { + St::Vertex => self.output, + St::Fragment => !self.output, + St::Compute => false, + }, + *ty_inner + == Ti::Vector { + size: Vs::Quad, + scalar: crate::Scalar::F32, + }, + ), + Bi::ViewIndex => ( + match self.stage { + St::Vertex | St::Fragment => !self.output, + St::Compute => false, + }, + *ty_inner == Ti::Scalar(crate::Scalar::I32), + ), + Bi::FragDepth => ( + self.stage == St::Fragment && self.output, + *ty_inner == Ti::Scalar(crate::Scalar::F32), + ), + Bi::FrontFacing => ( + self.stage == St::Fragment && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::BOOL), + ), + Bi::PrimitiveIndex => ( + self.stage == St::Fragment && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::SampleIndex => ( + self.stage == St::Fragment && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::SampleMask => ( + self.stage == St::Fragment, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::LocalInvocationIndex => ( + self.stage == St::Compute && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::GlobalInvocationId + | Bi::LocalInvocationId + | Bi::WorkGroupId + | Bi::WorkGroupSize + | Bi::NumWorkGroups => ( + self.stage == St::Compute && !self.output, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), + }; + + if !visible { + return Err(VaryingError::InvalidBuiltInStage(built_in)); + } + if !type_good { + log::warn!("Wrong builtin type: {:?}", ty_inner); + return Err(VaryingError::InvalidBuiltInType(built_in)); + } + } + crate::Binding::Location { + location, + interpolation, + sampling, + second_blend_source, + } => { + // Only IO-shareable types may be stored in locations. 
+ if !self.type_info[ty.index()] + .flags + .contains(super::TypeFlags::IO_SHAREABLE) + { + return Err(VaryingError::NotIOShareableType(ty)); + } + + if second_blend_source { + if !self + .capabilities + .contains(Capabilities::DUAL_SOURCE_BLENDING) + { + return Err(VaryingError::UnsupportedCapability( + Capabilities::DUAL_SOURCE_BLENDING, + )); + } + if self.stage != crate::ShaderStage::Fragment { + return Err(VaryingError::InvalidAttributeInStage( + "second_blend_source", + self.stage, + )); + } + if !self.output { + return Err(VaryingError::InvalidInputAttributeInStage( + "second_blend_source", + self.stage, + )); + } + if location != 0 { + return Err(VaryingError::InvalidLocationAttributeCombination { + location, + attribute: "second_blend_source", + }); + } + + self.second_blend_source = true; + } else if !self.location_mask.insert(location as usize) { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(VaryingError::BindingCollision { location }); + } + } + + let needs_interpolation = match self.stage { + crate::ShaderStage::Vertex => self.output, + crate::ShaderStage::Fragment => !self.output, + crate::ShaderStage::Compute => false, + }; + + // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but + // SPIR-V and GLSL both explicitly tolerate such combinations of decorators / + // qualifiers, so we won't complain about that here. 
+ let _ = sampling; + + let required = match sampling { + Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING, + _ => Capabilities::empty(), + }; + if !self.capabilities.contains(required) { + return Err(VaryingError::UnsupportedCapability(required)); + } + + match ty_inner.scalar_kind() { + Some(crate::ScalarKind::Float) => { + if needs_interpolation && interpolation.is_none() { + return Err(VaryingError::MissingInterpolation); + } + } + Some(_) => { + if needs_interpolation && interpolation != Some(crate::Interpolation::Flat) + { + return Err(VaryingError::InvalidInterpolation); + } + } + None => return Err(VaryingError::InvalidType(ty)), + } + } + } + + Ok(()) + } + + fn validate( + &mut self, + ty: Handle<crate::Type>, + binding: Option<&crate::Binding>, + ) -> Result<(), WithSpan<VaryingError>> { + let span_context = self.types.get_span_context(ty); + match binding { + Some(binding) => self + .validate_impl(ty, binding) + .map_err(|e| e.with_span_context(span_context)), + None => { + match self.types[ty].inner { + crate::TypeInner::Struct { ref members, .. 
} => { + for (index, member) in members.iter().enumerate() { + let span_context = self.types.get_span_context(ty); + match member.binding { + None => { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(VaryingError::MemberMissingBinding( + index as u32, + ) + .with_span_context(span_context)); + } + } + Some(ref binding) => self + .validate_impl(member.ty, binding) + .map_err(|e| e.with_span_context(span_context))?, + } + } + } + _ => { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(VaryingError::MissingBinding.with_span()); + } + } + } + Ok(()) + } + } + } +} + +impl super::Validator { + pub(super) fn validate_global_var( + &self, + var: &crate::GlobalVariable, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), GlobalVariableError> { + use super::TypeFlags; + + log::debug!("var {:?}", var); + let inner_ty = match gctx.types[var.ty].inner { + // A binding array is (mostly) supposed to behave the same as a + // series of individually bound resources, so we can (mostly) + // validate a `binding_array<T>` as if it were just a plain `T`. + crate::TypeInner::BindingArray { base, .. } => match var.space { + crate::AddressSpace::Storage { .. 
} + | crate::AddressSpace::Uniform + | crate::AddressSpace::Handle => base, + _ => return Err(GlobalVariableError::InvalidUsage(var.space)), + }, + _ => var.ty, + }; + let type_info = &self.types[inner_ty.index()]; + + let (required_type_flags, is_resource) = match var.space { + crate::AddressSpace::Function => { + return Err(GlobalVariableError::InvalidUsage(var.space)) + } + crate::AddressSpace::Storage { access } => { + if let Err((ty_handle, disalignment)) = type_info.storage_layout { + if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) { + return Err(GlobalVariableError::Alignment( + var.space, + ty_handle, + disalignment, + )); + } + } + if access == crate::StorageAccess::STORE { + return Err(GlobalVariableError::StorageAddressSpaceWriteOnlyNotSupported); + } + (TypeFlags::DATA | TypeFlags::HOST_SHAREABLE, true) + } + crate::AddressSpace::Uniform => { + if let Err((ty_handle, disalignment)) = type_info.uniform_layout { + if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) { + return Err(GlobalVariableError::Alignment( + var.space, + ty_handle, + disalignment, + )); + } + } + ( + TypeFlags::DATA + | TypeFlags::COPY + | TypeFlags::SIZED + | TypeFlags::HOST_SHAREABLE, + true, + ) + } + crate::AddressSpace::Handle => { + match gctx.types[inner_ty].inner { + crate::TypeInner::Image { class, .. } => match class { + crate::ImageClass::Storage { + format: + crate::StorageFormat::R16Unorm + | crate::StorageFormat::R16Snorm + | crate::StorageFormat::Rg16Unorm + | crate::StorageFormat::Rg16Snorm + | crate::StorageFormat::Rgba16Unorm + | crate::StorageFormat::Rgba16Snorm, + .. + } => { + if !self + .capabilities + .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS) + { + return Err(GlobalVariableError::UnsupportedCapability( + Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS, + )); + } + } + _ => {} + }, + crate::TypeInner::Sampler { .. 
} + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => {} + _ => { + return Err(GlobalVariableError::InvalidType(var.space)); + } + } + + (TypeFlags::empty(), true) + } + crate::AddressSpace::Private => (TypeFlags::CONSTRUCTIBLE, false), + crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false), + crate::AddressSpace::PushConstant => { + if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) { + return Err(GlobalVariableError::UnsupportedCapability( + Capabilities::PUSH_CONSTANT, + )); + } + ( + TypeFlags::DATA + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::SIZED, + false, + ) + } + }; + + if !type_info.flags.contains(required_type_flags) { + return Err(GlobalVariableError::MissingTypeFlags { + seen: type_info.flags, + required: required_type_flags, + }); + } + + if is_resource != var.binding.is_some() { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(GlobalVariableError::InvalidBinding); + } + } + + if let Some(init) = var.init { + match var.space { + crate::AddressSpace::Private | crate::AddressSpace::Function => {} + _ => { + return Err(GlobalVariableError::InitializerNotAllowed(var.space)); + } + } + + let decl_ty = &gctx.types[var.ty].inner; + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(GlobalVariableError::InitializerType); + } + } + + Ok(()) + } + + pub(super) fn validate_entry_point( + &mut self, + ep: &crate::EntryPoint, + module: &crate::Module, + mod_info: &ModuleInfo, + ) -> Result<FunctionInfo, WithSpan<EntryPointError>> { + if ep.early_depth_test.is_some() { + let required = Capabilities::EARLY_DEPTH_TEST; + if !self.capabilities.contains(required) { + return Err( + EntryPointError::Result(VaryingError::UnsupportedCapability(required)) + .with_span(), + ); + } + + if ep.stage != crate::ShaderStage::Fragment { + return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span()); + } + } + + 
if ep.stage == crate::ShaderStage::Compute { + if ep + .workgroup_size + .iter() + .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE) + { + return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span()); + } + } else if ep.workgroup_size != [0; 3] { + return Err(EntryPointError::UnexpectedWorkgroupSize.with_span()); + } + + let mut info = self + .validate_function(&ep.function, module, mod_info, true) + .map_err(WithSpan::into_other)?; + + { + use super::ShaderStages; + + let stage_bit = match ep.stage { + crate::ShaderStage::Vertex => ShaderStages::VERTEX, + crate::ShaderStage::Fragment => ShaderStages::FRAGMENT, + crate::ShaderStage::Compute => ShaderStages::COMPUTE, + }; + + if !info.available_stages.contains(stage_bit) { + return Err(EntryPointError::ForbiddenStageOperations.with_span()); + } + } + + self.location_mask.clear(); + let mut argument_built_ins = crate::FastHashSet::default(); + // TODO: add span info to function arguments + for (index, fa) in ep.function.arguments.iter().enumerate() { + let mut ctx = VaryingContext { + stage: ep.stage, + output: false, + second_blend_source: false, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + built_ins: &mut argument_built_ins, + capabilities: self.capabilities, + flags: self.flags, + }; + ctx.validate(fa.ty, fa.binding.as_ref()) + .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; + } + + self.location_mask.clear(); + if let Some(ref fr) = ep.function.result { + let mut result_built_ins = crate::FastHashSet::default(); + let mut ctx = VaryingContext { + stage: ep.stage, + output: true, + second_blend_source: false, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + built_ins: &mut result_built_ins, + capabilities: self.capabilities, + flags: self.flags, + }; + ctx.validate(fr.ty, fr.binding.as_ref()) + .map_err_inner(|e| EntryPointError::Result(e).with_span())?; + if ctx.second_blend_source { + // Only 
the first location may be used when dual source blending + if ctx.location_mask.len() == 1 && ctx.location_mask.contains(0) { + info.dual_source_blending = true; + } else { + return Err(EntryPointError::InvalidLocationsWhileDualSourceBlending { + location_mask: self.location_mask.clone(), + } + .with_span()); + } + } + + if ep.stage == crate::ShaderStage::Vertex + && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false }) + { + return Err(EntryPointError::MissingVertexOutputPosition.with_span()); + } + } else if ep.stage == crate::ShaderStage::Vertex { + return Err(EntryPointError::MissingVertexOutputPosition.with_span()); + } + + { + let used_push_constants = module + .global_variables + .iter() + .filter(|&(_, var)| var.space == crate::AddressSpace::PushConstant) + .map(|(handle, _)| handle) + .filter(|&handle| !info[handle].is_empty()); + // Check if there is more than one push constant, and error if so. + // Use a loop for when returning multiple errors is supported. + #[allow(clippy::never_loop)] + for handle in used_push_constants.skip(1) { + return Err(EntryPointError::MoreThanOnePushConstantUsed + .with_span_handle(handle, &module.global_variables)); + } + } + + self.ep_resource_bindings.clear(); + for (var_handle, var) in module.global_variables.iter() { + let usage = info[var_handle]; + if usage.is_empty() { + continue; + } + + let allowed_usage = match var.space { + crate::AddressSpace::Function => unreachable!(), + crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY, + crate::AddressSpace::Storage { access } => storage_usage(access), + crate::AddressSpace::Handle => match module.types[var.ty].inner { + crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner { + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. }, + .. + } => storage_usage(access), + _ => GlobalUse::READ | GlobalUse::QUERY, + }, + crate::TypeInner::Image { + class: crate::ImageClass::Storage { access, .. 
}, + .. + } => storage_usage(access), + _ => GlobalUse::READ | GlobalUse::QUERY, + }, + crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => GlobalUse::all(), + crate::AddressSpace::PushConstant => GlobalUse::READ, + }; + if !allowed_usage.contains(usage) { + log::warn!("\tUsage error for: {:?}", var); + log::warn!( + "\tAllowed usage: {:?}, requested: {:?}", + allowed_usage, + usage + ); + return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage) + .with_span_handle(var_handle, &module.global_variables)); + } + + if let Some(ref bind) = var.binding { + if !self.ep_resource_bindings.insert(bind.clone()) { + if self.flags.contains(super::ValidationFlags::BINDINGS) { + return Err(EntryPointError::BindingCollision(var_handle) + .with_span_handle(var_handle, &module.global_variables)); + } + } + } + } + + Ok(info) + } +} diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs new file mode 100644 index 0000000000..388495a3ac --- /dev/null +++ b/third_party/rust/naga/src/valid/mod.rs @@ -0,0 +1,477 @@ +/*! +Shader validator. +*/ + +mod analyzer; +mod compose; +mod expression; +mod function; +mod handles; +mod interface; +mod r#type; + +use crate::{ + arena::Handle, + proc::{LayoutError, Layouter, TypeResolution}, + FastHashSet, +}; +use bit_set::BitSet; +use std::ops; + +//TODO: analyze the model at the same time as we validate it, +// merge the corresponding matches over expressions and statements. 
use crate::span::{AddSpan as _, WithSpan};
pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
pub use compose::ComposeError;
pub use expression::{check_literal_value, LiteralError};
pub use expression::{ConstExpressionError, ExpressionError};
pub use function::{CallError, FunctionError, LocalVariableError};
pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
pub use r#type::{Disalignment, TypeError, TypeFlags};

use self::handles::InvalidHandleError;

bitflags::bitflags! {
    /// Validation flags.
    ///
    /// If you are working with trusted shaders, then you may be able
    /// to save some time by skipping validation.
    ///
    /// If you do not perform full validation, invalid shaders may
    /// cause Naga to panic. If you do perform full validation and
    /// [`Validator::validate`] returns `Ok`, then Naga promises that
    /// code generation will either succeed or return an error; it
    /// should never panic.
    ///
    /// The default value for `ValidationFlags` is
    /// `ValidationFlags::all()`.
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ValidationFlags: u8 {
        /// Expressions.
        const EXPRESSIONS = 0x1;
        /// Statements and blocks of them.
        const BLOCKS = 0x2;
        /// Uniformity of control flow for operations that require it.
        const CONTROL_FLOW_UNIFORMITY = 0x4;
        /// Host-shareable structure layouts.
        const STRUCT_LAYOUTS = 0x8;
        /// Constants.
        const CONSTANTS = 0x10;
        /// Group, binding, and location attributes.
        const BINDINGS = 0x20;
    }
}

impl Default for ValidationFlags {
    fn default() -> Self {
        Self::all()
    }
}

bitflags::bitflags! {
    /// Allowed IR capabilities.
    #[must_use]
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct Capabilities: u16 {
        /// Support for [`AddressSpace::PushConstant`].
        const PUSH_CONSTANT = 0x1;
        /// Float values with width = 8, i.e. 64-bit floats (`width` is in bytes).
        const FLOAT64 = 0x2;
        /// Support for [`BuiltIn::PrimitiveIndex`].
        const PRIMITIVE_INDEX = 0x4;
        /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
        const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
        /// Support for non-uniform indexing of uniform buffers and storage texture arrays.
        const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
        /// Support for non-uniform indexing of samplers.
        const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
        /// Support for [`BuiltIn::ClipDistance`].
        const CLIP_DISTANCE = 0x40;
        /// Support for [`BuiltIn::CullDistance`].
        const CULL_DISTANCE = 0x80;
        /// Support for 16-bit normalized storage texture formats.
        const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
        /// Support for [`BuiltIn::ViewIndex`].
        const MULTIVIEW = 0x200;
        /// Support for `early_depth_test`.
        const EARLY_DEPTH_TEST = 0x400;
        /// Support for [`BuiltIn::SampleIndex`] and [`Sampling::Sample`].
        const MULTISAMPLED_SHADING = 0x800;
        /// Support for ray queries and acceleration structures.
        const RAY_QUERY = 0x1000;
        /// Support for generating two sources for blending from fragment shaders.
        const DUAL_SOURCE_BLENDING = 0x2000;
        /// Support for arrayed cube textures.
        const CUBE_ARRAY_TEXTURES = 0x4000;
    }
}

impl Default for Capabilities {
    fn default() -> Self {
        Self::MULTISAMPLED_SHADING | Self::CUBE_ARRAY_TEXTURES
    }
}

bitflags::bitflags! {
    /// Shader stages (vertex / fragment / compute).
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct ShaderStages: u8 {
        const VERTEX = 0x1;
        const FRAGMENT = 0x2;
        const COMPUTE = 0x4;
    }
}

/// Analysis results for a validated [`crate::Module`], indexable by the
/// module's type, function, and const-expression handles (see the
/// `ops::Index` impls below).
#[derive(Debug)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
    // One entry per module type, in handle order.
    type_flags: Vec<TypeFlags>,
    // One entry per module function, in handle order.
    functions: Vec<FunctionInfo>,
    // One entry per entry point, in declaration order.
    entry_points: Vec<FunctionInfo>,
    // Resolved type of each module-level const expression, in handle order.
    const_expression_types: Box<[TypeResolution]>,
}

impl ops::Index<Handle<crate::Type>> for ModuleInfo {
    type Output = TypeFlags;
    fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
        &self.type_flags[handle.index()]
    }
}

impl ops::Index<Handle<crate::Function>> for ModuleInfo {
    type Output = FunctionInfo;
    fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
        &self.functions[handle.index()]
    }
}

impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
    type Output = TypeResolution;
    fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
        &self.const_expression_types[handle.index()]
    }
}

/// Shader module validator; create with [`Validator::new`] and reuse across
/// modules via [`Validator::reset`] (internal scratch state is cleared).
#[derive(Debug)]
pub struct Validator {
    flags: ValidationFlags,
    capabilities: Capabilities,
    // Per-type info, filled in handle order during validation.
    types: Vec<r#type::TypeInfo>,
    layouter: Layouter,
    location_mask: BitSet,
    ep_resource_bindings: FastHashSet<crate::ResourceBinding>,
    #[allow(dead_code)]
    switch_values: FastHashSet<crate::SwitchValue>,
    valid_expression_list: Vec<Handle<crate::Expression>>,
    valid_expression_set: BitSet,
}

#[derive(Clone, Debug, thiserror::Error)]
pub enum ConstantError {
    #[error("The type doesn't match the constant")]
    InvalidType,
    #[error("The type is not constructible")]
    NonConstructibleType,
}

/// Top-level validation error; each variant wraps the error of the
/// corresponding validation pass together with the offending handle.
#[derive(Clone, Debug, thiserror::Error)]
pub enum ValidationError {
    #[error(transparent)]
    InvalidHandle(#[from] InvalidHandleError),
    #[error(transparent)]
    Layouter(#[from] LayoutError),
    #[error("Type {handle:?} '{name}' is invalid")]
    Type {
        handle: Handle<crate::Type>,
        name: String,
        source: TypeError,
    },
    #[error("Constant expression {handle:?} is invalid")]
    ConstExpression {
        handle: Handle<crate::Expression>,
        source: ConstExpressionError,
    },
    #[error("Constant {handle:?} '{name}' is invalid")]
    Constant {
        handle: Handle<crate::Constant>,
        name: String,
        source: ConstantError,
    },
    #[error("Global variable {handle:?} '{name}' is invalid")]
    GlobalVariable {
        handle: Handle<crate::GlobalVariable>,
        name: String,
        source: GlobalVariableError,
    },
    #[error("Function {handle:?} '{name}' is invalid")]
    Function {
        handle: Handle<crate::Function>,
        name: String,
        source: FunctionError,
    },
    #[error("Entry point {name} at {stage:?} is invalid")]
    EntryPoint {
        stage: crate::ShaderStage,
        name: String,
        source: EntryPointError,
    },
    #[error("Module is corrupted")]
    Corrupted,
}

impl crate::TypeInner {
    /// Whether this type has a size known at pipeline-creation time.
    /// Runtime-sized arrays (and non-data types) are unsized.
    const fn is_sized(&self) -> bool {
        match *self {
            Self::Scalar { .. }
            | Self::Vector { .. }
            | Self::Matrix { .. }
            | Self::Array {
                size: crate::ArraySize::Constant(_),
                ..
            }
            | Self::Atomic { .. }
            | Self::Pointer { .. }
            | Self::ValuePointer { .. }
            | Self::Struct { .. } => true,
            Self::Array { .. }
            | Self::Image { .. }
            | Self::Sampler { .. }
            | Self::AccelerationStructure
            | Self::RayQuery
            | Self::BindingArray { .. } => false,
        }
    }

    /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
    const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
        match *self {
            // A lone integer scalar addresses a 1D image.
            Self::Scalar(crate::Scalar {
                kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
                ..
            }) => Some(crate::ImageDimension::D1),
            Self::Vector {
                size: crate::VectorSize::Bi,
                scalar:
                    crate::Scalar {
                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
                        ..
                    },
            } => Some(crate::ImageDimension::D2),
            Self::Vector {
                size: crate::VectorSize::Tri,
                scalar:
                    crate::Scalar {
                        kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
                        ..
                    },
            } => Some(crate::ImageDimension::D3),
            _ => None,
        }
    }
}

impl Validator {
    /// Construct a new validator instance.
    pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
        Validator {
            flags,
            capabilities,
            types: Vec::new(),
            layouter: Layouter::default(),
            location_mask: BitSet::new(),
            ep_resource_bindings: FastHashSet::default(),
            switch_values: FastHashSet::default(),
            valid_expression_list: Vec::new(),
            valid_expression_set: BitSet::new(),
        }
    }

    /// Reset the validator internals
    pub fn reset(&mut self) {
        self.types.clear();
        self.layouter.clear();
        self.location_mask.clear();
        self.ep_resource_bindings.clear();
        self.switch_values.clear();
        self.valid_expression_list.clear();
        self.valid_expression_set.clear();
    }

    /// Check that constant `handle` has a constructible declared type and
    /// that its initializer's resolved type is equivalent to it.
    fn validate_constant(
        &self,
        handle: Handle<crate::Constant>,
        gctx: crate::proc::GlobalCtx,
        mod_info: &ModuleInfo,
    ) -> Result<(), ConstantError> {
        let con = &gctx.constants[handle];

        let type_info = &self.types[con.ty.index()];
        if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
            return Err(ConstantError::NonConstructibleType);
        }

        let decl_ty = &gctx.types[con.ty].inner;
        let init_ty = mod_info[con.init].inner_with(gctx.types);
        if !decl_ty.equivalent(init_ty, gctx.types) {
            return Err(ConstantError::InvalidType);
        }

        Ok(())
    }

    /// Check the given module to be valid.
    ///
    /// On success, returns a [`ModuleInfo`] holding the analysis results
    /// (type flags, per-function and per-entry-point info, resolved
    /// const-expression types) gathered during validation.
    pub fn validate(
        &mut self,
        module: &crate::Module,
    ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
        self.reset();
        self.reset_types(module.types.len());

        // Handles must be structurally valid before anything else dereferences them.
        Self::validate_module_handles(module).map_err(|e| e.with_span())?;

        self.layouter.update(module.to_ctx()).map_err(|e| {
            let handle = e.ty;
            ValidationError::from(e).with_span_handle(handle, &module.types)
        })?;

        // These should all get overwritten.
        let placeholder = TypeResolution::Value(crate::TypeInner::Scalar(crate::Scalar {
            kind: crate::ScalarKind::Bool,
            width: 0,
        }));

        let mut mod_info = ModuleInfo {
            type_flags: Vec::with_capacity(module.types.len()),
            functions: Vec::with_capacity(module.functions.len()),
            entry_points: Vec::with_capacity(module.entry_points.len()),
            const_expression_types: vec![placeholder; module.const_expressions.len()]
                .into_boxed_slice(),
        };

        // Pass 1: validate every type and record its flags.
        for (handle, ty) in module.types.iter() {
            let ty_info = self
                .validate_type(handle, module.to_ctx())
                .map_err(|source| {
                    ValidationError::Type {
                        handle,
                        name: ty.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(handle, &module.types)
                })?;
            mod_info.type_flags.push(ty_info.flags);
            self.types[handle.index()] = ty_info;
        }

        // Pass 2: resolve the type of every module-level const expression
        // (needed even when constant validation itself is skipped).
        {
            let t = crate::Arena::new();
            let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
            for (handle, _) in module.const_expressions.iter() {
                mod_info
                    .process_const_expression(handle, &resolve_context, module.to_ctx())
                    .map_err(|source| {
                        ValidationError::ConstExpression { handle, source }
                            .with_span_handle(handle, &module.const_expressions)
                    })?
            }
        }

        // Pass 3 (optional): validate const expressions and constants.
        if self.flags.contains(ValidationFlags::CONSTANTS) {
            for (handle, _) in module.const_expressions.iter() {
                self.validate_const_expression(handle, module.to_ctx(), &mod_info)
                    .map_err(|source| {
                        ValidationError::ConstExpression { handle, source }
                            .with_span_handle(handle, &module.const_expressions)
                    })?
            }

            for (handle, constant) in module.constants.iter() {
                self.validate_constant(handle, module.to_ctx(), &mod_info)
                    .map_err(|source| {
                        ValidationError::Constant {
                            handle,
                            name: constant.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.constants)
                    })?
            }
        }

        // Pass 4: global variables.
        for (var_handle, var) in module.global_variables.iter() {
            self.validate_global_var(var, module.to_ctx(), &mod_info)
                .map_err(|source| {
                    ValidationError::GlobalVariable {
                        handle: var_handle,
                        name: var.name.clone().unwrap_or_default(),
                        source,
                    }
                    .with_span_handle(var_handle, &module.global_variables)
                })?;
        }

        // Pass 5: ordinary functions.
        for (handle, fun) in module.functions.iter() {
            match self.validate_function(fun, module, &mod_info, false) {
                Ok(info) => mod_info.functions.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::Function {
                            handle,
                            name: fun.name.clone().unwrap_or_default(),
                            source,
                        }
                        .with_span_handle(handle, &module.functions)
                    }))
                }
            }
        }

        // Pass 6: entry points; (stage, name) pairs must be unique.
        let mut ep_map = FastHashSet::default();
        for ep in module.entry_points.iter() {
            if !ep_map.insert((ep.stage, &ep.name)) {
                return Err(ValidationError::EntryPoint {
                    stage: ep.stage,
                    name: ep.name.clone(),
                    source: EntryPointError::Conflict,
                }
                .with_span()); // TODO: keep some EP span information?
            }

            match self.validate_entry_point(ep, module, &mod_info) {
                Ok(info) => mod_info.entry_points.push(info),
                Err(error) => {
                    return Err(error.and_then(|source| {
                        ValidationError::EntryPoint {
                            stage: ep.stage,
                            name: ep.name.clone(),
                            source,
                        }
                        .with_span()
                    }));
                }
            }
        }

        Ok(mod_info)
    }
}

/// Return `true` if `members` has the shape of an atomic compare-exchange
/// result struct: exactly two members, `old_value` (whose type satisfies
/// `scalar_predicate`) followed by `exchanged: bool`.
fn validate_atomic_compare_exchange_struct(
    types: &crate::UniqueArena<crate::Type>,
    members: &[crate::StructMember],
    scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
) -> bool {
    members.len() == 2
        && members[0].name.as_deref() == Some("old_value")
        && scalar_predicate(&types[members[0].ty].inner)
        && members[1].name.as_deref() == Some("exchanged")
        && types[members[1].ty].inner == crate::TypeInner::Scalar(crate::Scalar::BOOL)
}
diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs
new file mode 100644
index 0000000000..1e3e03fe19
--- /dev/null
+++ b/third_party/rust/naga/src/valid/type.rs
@@ -0,0 +1,643 @@
use super::Capabilities;
use crate::{arena::Handle, proc::Alignment};

bitflags::bitflags! {
    /// Flags associated with [`Type`]s by [`Validator`].
    ///
    /// [`Type`]: crate::Type
    /// [`Validator`]: crate::valid::Validator
    #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
    #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
    #[repr(transparent)]
    #[derive(Clone, Copy, Debug, Eq, PartialEq)]
    pub struct TypeFlags: u8 {
        /// Can be used for data variables.
        ///
        /// This flag is required on types of local variables, function
        /// arguments, array elements, and struct members.
        ///
        /// This includes all types except `Image`, `Sampler`,
        /// and some `Pointer` types.
        const DATA = 0x1;

        /// The data type has a size known by pipeline creation time.
        ///
        /// Unsized types are quite restricted.
The only unsized types permitted
        /// by Naga, other than the non-[`DATA`] types like [`Image`] and
        /// [`Sampler`], are dynamically-sized [`Array`s], and [`Struct`s] whose
        /// last members are such arrays. See the documentation for those types
        /// for details.
        ///
        /// [`DATA`]: TypeFlags::DATA
        /// [`Image`]: crate::Type::Image
        /// [`Sampler`]: crate::Type::Sampler
        /// [`Array`]: crate::Type::Array
        /// [`Struct`]: crate::Type::Struct
        const SIZED = 0x2;

        /// The data can be copied around.
        const COPY = 0x4;

        /// Can be used for user-defined IO between pipeline stages.
        ///
        /// This covers anything that can be in [`Location`] binding:
        /// non-bool scalars and vectors, matrices, and structs and
        /// arrays containing only interface types.
        const IO_SHAREABLE = 0x8;

        /// Can be used for host-shareable structures.
        const HOST_SHAREABLE = 0x10;

        /// This type can be passed as a function argument.
        const ARGUMENT = 0x40;

        /// A WGSL [constructible] type.
        ///
        /// The constructible types are scalars, vectors, matrices, fixed-size
        /// arrays of constructible types, and structs whose members are all
        /// constructible.
+ /// + /// [constructible]: https://gpuweb.github.io/gpuweb/wgsl/#constructible + const CONSTRUCTIBLE = 0x80; + } +} + +#[derive(Clone, Copy, Debug, thiserror::Error)] +pub enum Disalignment { + #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")] + ArrayStride { stride: u32, alignment: Alignment }, + #[error("The struct span {span}, is not a multiple of the required alignment {alignment}")] + StructSpan { span: u32, alignment: Alignment }, + #[error("The struct member[{index}] offset {offset} is not a multiple of the required alignment {alignment}")] + MemberOffset { + index: u32, + offset: u32, + alignment: Alignment, + }, + #[error("The struct member[{index}] offset {offset} must be at least {expected}")] + MemberOffsetAfterStruct { + index: u32, + offset: u32, + expected: u32, + }, + #[error("The struct member[{index}] is not statically sized")] + UnsizedMember { index: u32 }, + #[error("The type is not host-shareable")] + NonHostShareable, +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum TypeError { + #[error("Capability {0:?} is required")] + MissingCapability(Capabilities), + #[error("The {0:?} scalar width {1} is not supported for an atomic")] + InvalidAtomicWidth(crate::ScalarKind, crate::Bytes), + #[error("Invalid type for pointer target {0:?}")] + InvalidPointerBase(Handle<crate::Type>), + #[error("Unsized types like {base:?} must be in the `Storage` address space, not `{space:?}`")] + InvalidPointerToUnsized { + base: Handle<crate::Type>, + space: crate::AddressSpace, + }, + #[error("Expected data type, found {0:?}")] + InvalidData(Handle<crate::Type>), + #[error("Base type {0:?} for the array is invalid")] + InvalidArrayBaseType(Handle<crate::Type>), + #[error("Matrix elements must always be floating-point types")] + MatrixElementNotFloat, + #[error("The constant {0:?} is specialized, and cannot be used as an array size")] + UnsupportedSpecializedArrayLength(Handle<crate::Constant>), + #[error("Array 
stride {stride} does not match the expected {expected}")] + InvalidArrayStride { stride: u32, expected: u32 }, + #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")] + InvalidDynamicArray(String, Handle<crate::Type>), + #[error("The base handle {0:?} has to be a struct")] + BindingArrayBaseTypeNotStruct(Handle<crate::Type>), + #[error("Structure member[{index}] at {offset} overlaps the previous member")] + MemberOverlap { index: u32, offset: u32 }, + #[error( + "Structure member[{index}] at {offset} and size {size} crosses the structure boundary of size {span}" + )] + MemberOutOfBounds { + index: u32, + offset: u32, + size: u32, + span: u32, + }, + #[error("Structure types must have at least one member")] + EmptyStruct, + #[error(transparent)] + WidthError(#[from] WidthError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum WidthError { + #[error("The {0:?} scalar width {1} is not supported")] + Invalid(crate::ScalarKind, crate::Bytes), + #[error("Using `{name}` values requires the `naga::valid::Capabilities::{flag}` flag")] + MissingCapability { + name: &'static str, + flag: &'static str, + }, + + #[error("64-bit integers are not yet supported")] + Unsupported64Bit, + + #[error("Abstract types may only appear in constant expressions")] + Abstract, +} + +// Only makes sense if `flags.contains(HOST_SHAREABLE)` +type LayoutCompatibility = Result<Alignment, (Handle<crate::Type>, Disalignment)>; + +fn check_member_layout( + accum: &mut LayoutCompatibility, + member: &crate::StructMember, + member_index: u32, + member_layout: LayoutCompatibility, + parent_handle: Handle<crate::Type>, +) { + *accum = match (*accum, member_layout) { + (Ok(cur_alignment), Ok(alignment)) => { + if alignment.is_aligned(member.offset) { + Ok(cur_alignment.max(alignment)) + } else { + Err(( + parent_handle, + Disalignment::MemberOffset { + index: member_index, + offset: member.offset, + alignment, + }, + )) + } + } + (Err(e), _) | 
(_, Err(e)) => Err(e), + }; +} + +/// Determine whether a pointer in `space` can be passed as an argument. +/// +/// If a pointer in `space` is permitted to be passed as an argument to a +/// user-defined function, return `TypeFlags::ARGUMENT`. Otherwise, return +/// `TypeFlags::empty()`. +/// +/// Pointers passed as arguments to user-defined functions must be in the +/// `Function` or `Private` address space. +const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags { + use crate::AddressSpace as As; + match space { + As::Function | As::Private => TypeFlags::ARGUMENT, + As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => { + TypeFlags::empty() + } + } +} + +#[derive(Clone, Debug)] +pub(super) struct TypeInfo { + pub flags: TypeFlags, + pub uniform_layout: LayoutCompatibility, + pub storage_layout: LayoutCompatibility, +} + +impl TypeInfo { + const fn dummy() -> Self { + TypeInfo { + flags: TypeFlags::empty(), + uniform_layout: Ok(Alignment::ONE), + storage_layout: Ok(Alignment::ONE), + } + } + + const fn new(flags: TypeFlags, alignment: Alignment) -> Self { + TypeInfo { + flags, + uniform_layout: Ok(alignment), + storage_layout: Ok(alignment), + } + } +} + +impl super::Validator { + const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { + if self.capabilities.contains(capability) { + Ok(()) + } else { + Err(TypeError::MissingCapability(capability)) + } + } + + pub(super) const fn check_width(&self, scalar: crate::Scalar) -> Result<(), WidthError> { + let good = match scalar.kind { + crate::ScalarKind::Bool => scalar.width == crate::BOOL_WIDTH, + crate::ScalarKind::Float => { + if scalar.width == 8 { + if !self.capabilities.contains(Capabilities::FLOAT64) { + return Err(WidthError::MissingCapability { + name: "f64", + flag: "FLOAT64", + }); + } + true + } else { + scalar.width == 4 + } + } + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + if scalar.width == 8 { + return 
Err(WidthError::Unsupported64Bit); + } + scalar.width == 4 + } + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(WidthError::Abstract); + } + }; + if good { + Ok(()) + } else { + Err(WidthError::Invalid(scalar.kind, scalar.width)) + } + } + + pub(super) fn reset_types(&mut self, size: usize) { + self.types.clear(); + self.types.resize(size, TypeInfo::dummy()); + self.layouter.clear(); + } + + pub(super) fn validate_type( + &self, + handle: Handle<crate::Type>, + gctx: crate::proc::GlobalCtx, + ) -> Result<TypeInfo, TypeError> { + use crate::TypeInner as Ti; + Ok(match gctx.types[handle].inner { + Ti::Scalar(scalar) => { + self.check_width(scalar)?; + let shareable = if scalar.kind.is_numeric() { + TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE + } else { + TypeFlags::empty() + }; + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + | shareable, + Alignment::from_width(scalar.width), + ) + } + Ti::Vector { size, scalar } => { + self.check_width(scalar)?; + let shareable = if scalar.kind.is_numeric() { + TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE + } else { + TypeFlags::empty() + }; + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + | shareable, + Alignment::from(size) * Alignment::from_width(scalar.width), + ) + } + Ti::Matrix { + columns: _, + rows, + scalar, + } => { + if scalar.kind != crate::ScalarKind::Float { + return Err(TypeError::MatrixElementNotFloat); + } + self.check_width(scalar)?; + TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE, + Alignment::from(rows) * Alignment::from_width(scalar.width), + ) + } + Ti::Atomic(crate::Scalar { kind, width }) => { + let good = match kind { + crate::ScalarKind::Bool + | 
crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => false, + crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, + }; + if !good { + return Err(TypeError::InvalidAtomicWidth(kind, width)); + } + TypeInfo::new( + TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE, + Alignment::from_width(width), + ) + } + Ti::Pointer { base, space } => { + use crate::AddressSpace as As; + + let base_info = &self.types[base.index()]; + if !base_info.flags.contains(TypeFlags::DATA) { + return Err(TypeError::InvalidPointerBase(base)); + } + + // Runtime-sized values can only live in the `Storage` address + // space, so it's useless to have a pointer to such a type in + // any other space. + // + // Detecting this problem here prevents the definition of + // functions like: + // + // fn f(p: ptr<workgroup, UnsizedType>) -> ... { ... } + // + // which would otherwise be permitted, but uncallable. (They + // may also present difficulties in code generation). + if !base_info.flags.contains(TypeFlags::SIZED) { + match space { + As::Storage { .. } => {} + _ => { + return Err(TypeError::InvalidPointerToUnsized { base, space }); + } + } + } + + // `Validator::validate_function` actually checks the address + // space of pointer arguments explicitly before checking the + // `ARGUMENT` flag, to give better error messages. But it seems + // best to set `ARGUMENT` accurately anyway. + let argument_flag = ptr_space_argument_flag(space); + + // Pointers cannot be stored in variables, structure members, or + // array elements, so we do not mark them as `DATA`. + TypeInfo::new( + argument_flag | TypeFlags::SIZED | TypeFlags::COPY, + Alignment::ONE, + ) + } + Ti::ValuePointer { + size: _, + scalar, + space, + } => { + // ValuePointer should be treated the same way as the equivalent + // Pointer / Scalar / Vector combination, so each step in those + // variants' match arms should have a counterpart here. 
+ // + // However, some cases are trivial: All our implicit base types + // are DATA and SIZED, so we can never return + // `InvalidPointerBase` or `InvalidPointerToUnsized`. + self.check_width(scalar)?; + + // `Validator::validate_function` actually checks the address + // space of pointer arguments explicitly before checking the + // `ARGUMENT` flag, to give better error messages. But it seems + // best to set `ARGUMENT` accurately anyway. + let argument_flag = ptr_space_argument_flag(space); + + // Pointers cannot be stored in variables, structure members, or + // array elements, so we do not mark them as `DATA`. + TypeInfo::new( + argument_flag | TypeFlags::SIZED | TypeFlags::COPY, + Alignment::ONE, + ) + } + Ti::Array { base, size, stride } => { + let base_info = &self.types[base.index()]; + if !base_info.flags.contains(TypeFlags::DATA | TypeFlags::SIZED) { + return Err(TypeError::InvalidArrayBaseType(base)); + } + + let base_layout = self.layouter[base]; + let general_alignment = base_layout.alignment; + let uniform_layout = match base_info.uniform_layout { + Ok(base_alignment) => { + let alignment = base_alignment + .max(general_alignment) + .max(Alignment::MIN_UNIFORM); + if alignment.is_aligned(stride) { + Ok(alignment) + } else { + Err((handle, Disalignment::ArrayStride { stride, alignment })) + } + } + Err(e) => Err(e), + }; + let storage_layout = match base_info.storage_layout { + Ok(base_alignment) => { + let alignment = base_alignment.max(general_alignment); + if alignment.is_aligned(stride) { + Ok(alignment) + } else { + Err((handle, Disalignment::ArrayStride { stride, alignment })) + } + } + Err(e) => Err(e), + }; + + let type_info_mask = match size { + crate::ArraySize::Constant(_) => { + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE + } + crate::ArraySize::Dynamic => { + // Non-SIZED types may only appear as the last element of a structure. 
+ // This is enforced by checks for SIZED-ness for all compound types, + // and a special case for structs. + TypeFlags::DATA | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE + } + }; + + TypeInfo { + flags: base_info.flags & type_info_mask, + uniform_layout, + storage_layout, + } + } + Ti::Struct { ref members, span } => { + if members.is_empty() { + return Err(TypeError::EmptyStruct); + } + + let mut ti = TypeInfo::new( + TypeFlags::DATA + | TypeFlags::SIZED + | TypeFlags::COPY + | TypeFlags::HOST_SHAREABLE + | TypeFlags::IO_SHAREABLE + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE, + Alignment::ONE, + ); + ti.uniform_layout = Ok(Alignment::MIN_UNIFORM); + + let mut min_offset = 0; + + let mut prev_struct_data: Option<(u32, u32)> = None; + + for (i, member) in members.iter().enumerate() { + let base_info = &self.types[member.ty.index()]; + if !base_info.flags.contains(TypeFlags::DATA) { + return Err(TypeError::InvalidData(member.ty)); + } + if !base_info.flags.contains(TypeFlags::HOST_SHAREABLE) { + if ti.uniform_layout.is_ok() { + ti.uniform_layout = Err((member.ty, Disalignment::NonHostShareable)); + } + if ti.storage_layout.is_ok() { + ti.storage_layout = Err((member.ty, Disalignment::NonHostShareable)); + } + } + ti.flags &= base_info.flags; + + if member.offset < min_offset { + // HACK: this could be nicer. We want to allow some structures + // to not bother with offsets/alignments if they are never + // used for host sharing. 
+ if member.offset == 0 { + ti.flags.set(TypeFlags::HOST_SHAREABLE, false); + } else { + return Err(TypeError::MemberOverlap { + index: i as u32, + offset: member.offset, + }); + } + } + + let base_size = gctx.types[member.ty].inner.size(gctx); + min_offset = member.offset + base_size; + if min_offset > span { + return Err(TypeError::MemberOutOfBounds { + index: i as u32, + offset: member.offset, + size: base_size, + span, + }); + } + + check_member_layout( + &mut ti.uniform_layout, + member, + i as u32, + base_info.uniform_layout, + handle, + ); + check_member_layout( + &mut ti.storage_layout, + member, + i as u32, + base_info.storage_layout, + handle, + ); + + // Validate rule: If a structure member itself has a structure type S, + // then the number of bytes between the start of that member and + // the start of any following member must be at least roundUp(16, SizeOf(S)). + if let Some((span, offset)) = prev_struct_data { + let diff = member.offset - offset; + let min = Alignment::MIN_UNIFORM.round_up(span); + if diff < min { + ti.uniform_layout = Err(( + handle, + Disalignment::MemberOffsetAfterStruct { + index: i as u32, + offset: member.offset, + expected: offset + min, + }, + )); + } + }; + + prev_struct_data = match gctx.types[member.ty].inner { + crate::TypeInner::Struct { span, .. } => Some((span, member.offset)), + _ => None, + }; + + // The last field may be an unsized array. + if !base_info.flags.contains(TypeFlags::SIZED) { + let is_array = match gctx.types[member.ty].inner { + crate::TypeInner::Array { .. 
} => true, + _ => false, + }; + if !is_array || i + 1 != members.len() { + let name = member.name.clone().unwrap_or_default(); + return Err(TypeError::InvalidDynamicArray(name, member.ty)); + } + if ti.uniform_layout.is_ok() { + ti.uniform_layout = + Err((handle, Disalignment::UnsizedMember { index: i as u32 })); + } + } + } + + let alignment = self.layouter[handle].alignment; + if !alignment.is_aligned(span) { + ti.uniform_layout = Err((handle, Disalignment::StructSpan { span, alignment })); + ti.storage_layout = Err((handle, Disalignment::StructSpan { span, alignment })); + } + + ti + } + Ti::Image { + dim, + arrayed, + class: _, + } => { + if arrayed && matches!(dim, crate::ImageDimension::Cube) { + self.require_type_capability(Capabilities::CUBE_ARRAY_TEXTURES)?; + } + TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) + } + Ti::Sampler { .. } => TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE), + Ti::AccelerationStructure => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) + } + Ti::RayQuery => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new( + TypeFlags::DATA | TypeFlags::CONSTRUCTIBLE | TypeFlags::SIZED, + Alignment::ONE, + ) + } + Ti::BindingArray { base, size } => { + if base >= handle { + return Err(TypeError::InvalidArrayBaseType(base)); + } + let type_info_mask = match size { + crate::ArraySize::Constant(_) => TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE, + crate::ArraySize::Dynamic => { + // Final type is non-sized + TypeFlags::HOST_SHAREABLE + } + }; + let base_info = &self.types[base.index()]; + + if base_info.flags.contains(TypeFlags::DATA) { + // Currently Naga only supports binding arrays of structs for non-handle types. + match gctx.types[base].inner { + crate::TypeInner::Struct { .. } => {} + _ => return Err(TypeError::BindingArrayBaseTypeNotStruct(base)), + }; + } + + TypeInfo::new(base_info.flags & type_info_mask, Alignment::ONE) + } + }) + } +} |