summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/valid
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/naga/src/valid
parentInitial commit. (diff)
downloadfirefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz
firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/valid')
-rw-r--r--third_party/rust/naga/src/valid/analyzer.rs1200
-rw-r--r--third_party/rust/naga/src/valid/compose.rs138
-rw-r--r--third_party/rust/naga/src/valid/expression.rs1482
-rw-r--r--third_party/rust/naga/src/valid/function.rs1048
-rw-r--r--third_party/rust/naga/src/valid/handles.rs652
-rw-r--r--third_party/rust/naga/src/valid/interface.rs676
-rw-r--r--third_party/rust/naga/src/valid/mod.rs467
-rw-r--r--third_party/rust/naga/src/valid/type.rs636
8 files changed, 6299 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs
new file mode 100644
index 0000000000..e9b155b6eb
--- /dev/null
+++ b/third_party/rust/naga/src/valid/analyzer.rs
@@ -0,0 +1,1200 @@
+/*! Module analyzer.
+
+Figures out the following properties:
+ - control flow uniformity
+ - texture/sampler pairs
+ - expression reference counts
+!*/
+
+use super::{ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
+use crate::span::{AddSpan as _, WithSpan};
+use crate::{
+ arena::{Arena, Handle},
+ proc::{ResolveContext, TypeResolution},
+};
+use std::ops;
+
+pub type NonUniformResult = Option<Handle<crate::Expression>>;
+
+bitflags::bitflags! {
+ /// Kinds of expressions that require uniform control flow.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct UniformityRequirements: u8 {
+ const WORK_GROUP_BARRIER = 0x1;
+ const DERIVATIVE = 0x2;
+ const IMPLICIT_LEVEL = 0x4;
+ }
+}
+
+/// Uniform control flow characteristics.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Uniformity {
+ /// A child expression with non-uniform result.
+ ///
+ /// This means, when the relevant invocations are scheduled on a compute unit,
+ /// they have to use vector registers to store an individual value
+ /// per invocation.
+ ///
+ /// Whenever the control flow is conditioned on such value,
+ /// the hardware needs to keep track of the mask of invocations,
+ /// and process all branches of the control flow.
+ ///
+ /// Any operations that depend on non-uniform results also produce non-uniform.
+ pub non_uniform_result: NonUniformResult,
+ /// If this expression requires uniform control flow, store the reason here.
+ pub requirements: UniformityRequirements,
+}
+
+impl Uniformity {
+ const fn new() -> Self {
+ Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ struct ExitFlags: u8 {
+ /// Control flow may return from the function, which makes all the
+ /// subsequent statements within the current function (only!)
+ /// to be executed in a non-uniform control flow.
+ const MAY_RETURN = 0x1;
+ /// Control flow may be killed. Anything after `Statement::Kill` is
+ /// considered inside non-uniform context.
+ const MAY_KILL = 0x2;
+ }
+}
+
+/// Uniformity characteristics of a function.
+#[cfg_attr(test, derive(Debug, PartialEq))]
+struct FunctionUniformity {
+ result: Uniformity,
+ exit: ExitFlags,
+}
+
+impl ops::BitOr for FunctionUniformity {
+ type Output = Self;
+ fn bitor(self, other: Self) -> Self {
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: self
+ .result
+ .non_uniform_result
+ .or(other.result.non_uniform_result),
+ requirements: self.result.requirements | other.result.requirements,
+ },
+ exit: self.exit | other.exit,
+ }
+ }
+}
+
+impl FunctionUniformity {
+ const fn new() -> Self {
+ FunctionUniformity {
+ result: Uniformity::new(),
+ exit: ExitFlags::empty(),
+ }
+ }
+
+ /// Returns a disruptor based on the stored exit flags, if any.
+ const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
+ if self.exit.contains(ExitFlags::MAY_RETURN) {
+ Some(UniformityDisruptor::Return)
+ } else if self.exit.contains(ExitFlags::MAY_KILL) {
+ Some(UniformityDisruptor::Discard)
+ } else {
+ None
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Indicates how a global variable is used.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct GlobalUse: u8 {
+ /// Data will be read from the variable.
+ const READ = 0x1;
+ /// Data will be written to the variable.
+ const WRITE = 0x2;
+ /// The information about the data is queried.
+ const QUERY = 0x4;
+ }
+}
+
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct SamplingKey {
+ pub image: Handle<crate::GlobalVariable>,
+ pub sampler: Handle<crate::GlobalVariable>,
+}
+
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct ExpressionInfo {
+ pub uniformity: Uniformity,
+ pub ref_count: usize,
+ assignable_global: Option<Handle<crate::GlobalVariable>>,
+ pub ty: TypeResolution,
+}
+
+impl ExpressionInfo {
+ const fn new() -> Self {
+ ExpressionInfo {
+ uniformity: Uniformity::new(),
+ ref_count: 0,
+ assignable_global: None,
+ // this doesn't matter at this point, will be overwritten
+ ty: TypeResolution::Value(crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 0,
+ }),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+enum GlobalOrArgument {
+ Global(Handle<crate::GlobalVariable>),
+ Argument(u32),
+}
+
+impl GlobalOrArgument {
+ fn from_expression(
+ expression_arena: &Arena<crate::Expression>,
+ expression: Handle<crate::Expression>,
+ ) -> Result<GlobalOrArgument, ExpressionError> {
+ Ok(match expression_arena[expression] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ },
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ })
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+struct Sampling {
+ image: GlobalOrArgument,
+ sampler: GlobalOrArgument,
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct FunctionInfo {
+ /// Validation flags.
+ #[allow(dead_code)]
+ flags: ValidationFlags,
+ /// Set of shader stages where calling this function is valid.
+ pub available_stages: ShaderStages,
+ /// Uniformity characteristics.
+ pub uniformity: Uniformity,
+ /// Function may kill the invocation.
+ pub may_kill: bool,
+
+ /// All pairs of (texture, sampler) globals that may be used together in
+ /// sampling operations by this function and its callees. This includes
+ /// pairings that arise when this function passes textures and samplers as
+ /// arguments to its callees.
+ ///
+ /// This table does not include uses of textures and samplers passed as
+ /// arguments to this function itself, since we do not know which globals
+ /// those will be. However, this table *is* exhaustive when computed for an
+ /// entry point function: entry points never receive textures or samplers as
+ /// arguments, so all an entry point's sampling can be reported in terms of
+ /// globals.
+ ///
+ /// The GLSL back end uses this table to construct reflection info that
+ /// clients need to construct texture-combined sampler values.
+ pub sampling_set: crate::FastHashSet<SamplingKey>,
+
+ /// How this function and its callees use this module's globals.
+ ///
+ /// This is indexed by `Handle<GlobalVariable>` indices. However,
+ /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`,
+ /// so you can simply index this struct with a global handle to retrieve
+ /// its usage information.
+ global_uses: Box<[GlobalUse]>,
+
+ /// Information about each expression in this function's body.
+ ///
+ /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
+ /// implements `std::ops::Index<Handle<Expression>>`, so you can simply
+ /// index this struct with an expression handle to retrieve its
+ /// `ExpressionInfo`.
+ expressions: Box<[ExpressionInfo]>,
+
+ /// All (texture, sampler) pairs that may be used together in sampling
+ /// operations by this function and its callees, whether they are accessed
+ /// as globals or passed as arguments.
+ ///
+ /// Participants are represented by [`GlobalVariable`] handles whenever
+ /// possible, and otherwise by indices of this function's arguments.
+ ///
+ /// When analyzing a function call, we combine this data about the callee
+ /// with the actual arguments being passed to produce the callers' own
+ /// `sampling_set` and `sampling` tables.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ sampling: crate::FastHashSet<Sampling>,
+}
+
+impl FunctionInfo {
+ pub const fn global_variable_count(&self) -> usize {
+ self.global_uses.len()
+ }
+ pub const fn expression_count(&self) -> usize {
+ self.expressions.len()
+ }
+ pub fn dominates_global_use(&self, other: &Self) -> bool {
+ for (self_global_uses, other_global_uses) in
+ self.global_uses.iter().zip(other.global_uses.iter())
+ {
+ if !self_global_uses.contains(*other_global_uses) {
+ return false;
+ }
+ }
+ true
+ }
+}
+
+impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
+ type Output = GlobalUse;
+ fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
+ &self.global_uses[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
+ type Output = ExpressionInfo;
+ fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
+ &self.expressions[handle.index()]
+ }
+}
+
+/// Disruptor of the uniform control flow.
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum UniformityDisruptor {
+ #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
+ Expression(Handle<crate::Expression>),
+ #[error("There is a Return earlier in the control flow of the function")]
+ Return,
+ #[error("There is a Discard earlier in the entry point across all called functions")]
+ Discard,
+}
+
+impl FunctionInfo {
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref_impl(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ global_use: GlobalUse,
+ ) -> NonUniformResult {
+ let info = &mut self.expressions[handle.index()];
+ info.ref_count += 1;
+ // mark the used global as read
+ if let Some(global) = info.assignable_global {
+ self.global_uses[global.index()] |= global_use;
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
+ self.add_ref_impl(handle, GlobalUse::READ)
+ }
+
+ /// Adds a potentially assignable reference to an expression.
+ /// These are destinations for `Store` and `ImageStore` statements,
+ /// which can transit through `Access` and `AccessIndex`.
+ #[must_use]
+ fn add_assignable_ref(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
+ ) -> NonUniformResult {
+ let info = &mut self.expressions[handle.index()];
+ info.ref_count += 1;
+ // propagate the assignable global up the chain, till it either hits
+ // a value-type expression, or the assignment statement.
+ if let Some(global) = info.assignable_global {
+ if let Some(_old) = assignable_global.replace(global) {
+ unreachable!()
+ }
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Inherit information from a called function.
+ fn process_call(
+ &mut self,
+ callee: &Self,
+ arguments: &[Handle<crate::Expression>],
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ self.sampling_set
+ .extend(callee.sampling_set.iter().cloned());
+ for sampling in callee.sampling.iter() {
+ // If the callee was passed the texture or sampler as an argument,
+ // we may now be able to determine which globals those referred to.
+ let image_storage = match sampling.image {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ let sampler_storage = match sampling.sampler {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ // If we've managed to pin both the image and sampler down to
+ // specific globals, record that in our `sampling_set`. Otherwise,
+ // record as much as we do know in our own `sampling` table, for our
+ // callers to sort out.
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ (image, sampler) => {
+ self.sampling.insert(Sampling { image, sampler });
+ }
+ }
+ }
+
+ // Inherit global use from our callees.
+ for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
+ *mine |= *other;
+ }
+
+ Ok(FunctionUniformity {
+ result: callee.uniformity.clone(),
+ exit: if callee.may_kill {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ })
+ }
+
+ /// Computes the expression info and stores it in `self.expressions`.
+ /// Also, bumps the reference counts on dependent expressions.
+ #[allow(clippy::or_fun_call)]
+ fn process_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ expression: &crate::Expression,
+ expression_arena: &Arena<crate::Expression>,
+ other_functions: &[FunctionInfo],
+ resolve_context: &ResolveContext,
+ capabilities: super::Capabilities,
+ ) -> Result<(), ExpressionError> {
+ use crate::{Expression as E, SampleLevel as Sl};
+
+ let mut assignable_global = None;
+ let uniformity = match *expression {
+ E::Access { base, index } => {
+ let base_ty = self[base].ty.inner_with(resolve_context.types);
+
+ // build up the caps needed if this is indexed non-uniformly
+ let mut needed_caps = super::Capabilities::empty();
+ let is_binding_array = match *base_ty {
+ crate::TypeInner::BindingArray {
+ base: array_element_ty_handle,
+ ..
+ } => {
+ // these are nasty aliases, but these idents are too long and break rustfmt
+ let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
+ let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
+ let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
+
+ // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
+ let array_element_ty =
+ &resolve_context.types[array_element_ty_handle].inner;
+
+ needed_caps |= match *array_element_ty {
+ // If we're an image, use the appropriate limit.
+ crate::TypeInner::Image { class, .. } => match class {
+ crate::ImageClass::Storage { .. } => ub_st,
+ _ => st_sb,
+ },
+ crate::TypeInner::Sampler { .. } => sampler,
+ // If we're anything but an image, assume we're a buffer and use the address space.
+ _ => {
+ if let E::GlobalVariable(global_handle) = expression_arena[base] {
+ let global = &resolve_context.global_vars[global_handle];
+ match global.space {
+ crate::AddressSpace::Uniform => ub_st,
+ crate::AddressSpace::Storage { .. } => st_sb,
+ _ => unreachable!(),
+ }
+ } else {
+ unreachable!()
+ }
+ }
+ };
+
+ true
+ }
+ _ => false,
+ };
+
+ if self[index].uniformity.non_uniform_result.is_some()
+ && !capabilities.contains(needed_caps)
+ && is_binding_array
+ {
+ return Err(ExpressionError::MissingCapabilities(needed_caps));
+ }
+
+ Uniformity {
+ non_uniform_result: self
+ .add_assignable_ref(base, &mut assignable_global)
+ .or(self.add_ref(index)),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::AccessIndex { base, .. } => Uniformity {
+ non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
+ requirements: UniformityRequirements::empty(),
+ },
+ // always uniform
+ E::Constant(_) => Uniformity::new(),
+ E::Splat { size: _, value } => Uniformity {
+ non_uniform_result: self.add_ref(value),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Swizzle { vector, .. } => Uniformity {
+ non_uniform_result: self.add_ref(vector),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Compose { ref components, .. } => {
+ let non_uniform_result = components
+ .iter()
+ .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
+ Uniformity {
+ non_uniform_result,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the builtin or interpolation
+ E::FunctionArgument(index) => {
+ let arg = &resolve_context.arguments[index as usize];
+ let uniform = match arg.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match built_in {
+ // per-polygon built-ins are uniform
+ crate::BuiltIn::FrontFacing
+ // per-work-group built-ins are uniform
+ | crate::BuiltIn::WorkGroupId
+ | crate::BuiltIn::WorkGroupSize
+ | crate::BuiltIn::NumWorkGroups => true,
+ _ => false,
+ },
+ // only flat inputs are uniform
+ Some(crate::Binding::Location {
+ interpolation: Some(crate::Interpolation::Flat),
+ ..
+ }) => true,
+ _ => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the address space
+ E::GlobalVariable(gh) => {
+ use crate::AddressSpace as As;
+ assignable_global = Some(gh);
+ let var = &resolve_context.global_vars[gh];
+ let uniform = match var.space {
+ // local data is non-uniform
+ As::Function | As::Private => false,
+ // workgroup memory is exclusively accessed by the group
+ As::WorkGroup => true,
+ // uniform data
+ As::Uniform | As::PushConstant => true,
+ // storage data is only uniform when read-only
+ As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
+ As::Handle => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::LocalVariable(_) => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Load { pointer } => Uniformity {
+ non_uniform_result: self.add_ref(pointer),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
+ let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
+
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ _ => {
+ self.sampling.insert(Sampling {
+ image: image_storage,
+ sampler: sampler_storage,
+ });
+ }
+ }
+
+ // "nur" == "Non-Uniform Result"
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let level_nur = match level {
+ Sl::Auto | Sl::Zero => None,
+ Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
+ Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
+ };
+ let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(sampler))
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(level_nur)
+ .or(dref_nur),
+ requirements: if level.implicit_derivatives() {
+ UniformityRequirements::IMPLICIT_LEVEL
+ } else {
+ UniformityRequirements::empty()
+ },
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let sample_nur = sample.and_then(|h| self.add_ref(h));
+ let level_nur = level.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(sample_nur)
+ .or(level_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::ImageQuery { image, query } => {
+ let query_nur = match query {
+ crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
+ _ => None,
+ };
+ Uniformity {
+ non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::Unary { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Binary { left, right, .. } => Uniformity {
+ non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => Uniformity {
+ non_uniform_result: self
+ .add_ref(condition)
+ .or(self.add_ref(accept))
+ .or(self.add_ref(reject)),
+ requirements: UniformityRequirements::empty(),
+ },
+ // explicit derivatives require uniform
+ E::Derivative { expr, .. } => Uniformity {
+ //Note: taking a derivative of a uniform doesn't make it non-uniform
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ E::Relational { argument, .. } => Uniformity {
+ non_uniform_result: self.add_ref(argument),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Math {
+ arg, arg1, arg2, ..
+ } => {
+ let arg1_nur = arg1.and_then(|h| self.add_ref(h));
+ let arg2_nur = arg2.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::As { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::CallResult(function) => other_functions[function.index()].uniformity.clone(),
+ E::AtomicResult { .. } | E::RayQueryProceedResult => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ArrayLength(expr) => Uniformity {
+ non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => Uniformity {
+ non_uniform_result: self.add_ref(query),
+ requirements: UniformityRequirements::empty(),
+ },
+ };
+
+ let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
+ self.expressions[handle.index()] = ExpressionInfo {
+ uniformity,
+ ref_count: 0,
+ assignable_global,
+ ty,
+ };
+ Ok(())
+ }
+
+ /// Analyzes the uniformity requirements of a block (as a sequence of statements).
+ /// Returns the uniformity characteristics at the *function* level, i.e.
+ /// whether or not the function requires to be called in uniform control flow,
+ /// and whether the produced result is not disrupting the control flow.
+ ///
+ /// The parent control flow is uniform if `disruptor.is_none()`.
+ ///
+ /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
+ /// require uniformity, but the current flow is non-uniform.
+ #[allow(clippy::or_fun_call)]
+ fn process_block(
+ &mut self,
+ statements: &crate::Block,
+ other_functions: &[FunctionInfo],
+ mut disruptor: Option<UniformityDisruptor>,
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ use crate::Statement as S;
+
+ let mut combined_uniformity = FunctionUniformity::new();
+ for statement in statements {
+ let uniformity = match *statement {
+ S::Emit(ref range) => {
+ let mut requirements = UniformityRequirements::empty();
+ for expr in range.clone() {
+ let req = self.expressions[expr.index()].uniformity.requirements;
+ #[cfg(feature = "validate")]
+ if self
+ .flags
+ .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
+ && !req.is_empty()
+ {
+ if let Some(cause) = disruptor {
+ return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
+ .with_span_handle(expr, expression_arena));
+ }
+ }
+ requirements |= req;
+ }
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements,
+ },
+ exit: ExitFlags::empty(),
+ }
+ }
+ S::Break | S::Continue => FunctionUniformity::new(),
+ S::Kill => FunctionUniformity {
+ result: Uniformity::new(),
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ S::Barrier(_) => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::WORK_GROUP_BARRIER,
+ },
+ exit: ExitFlags::empty(),
+ },
+ S::Block(ref b) => {
+ self.process_block(b, other_functions, disruptor, expression_arena)?
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_nur = self.add_ref(condition);
+ let branch_disruptor =
+ disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
+ let accept_uniformity = self.process_block(
+ accept,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ let reject_uniformity = self.process_block(
+ reject,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ accept_uniformity | reject_uniformity
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_nur = self.add_ref(selector);
+ let branch_disruptor =
+ disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
+ let mut uniformity = FunctionUniformity::new();
+ let mut case_disruptor = branch_disruptor;
+ for case in cases.iter() {
+ let case_uniformity = self.process_block(
+ &case.body,
+ other_functions,
+ case_disruptor,
+ expression_arena,
+ )?;
+ case_disruptor = if case.fall_through {
+ case_disruptor.or(case_uniformity.exit_disruptor())
+ } else {
+ branch_disruptor
+ };
+ uniformity = uniformity | case_uniformity;
+ }
+ uniformity
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if: _,
+ } => {
+ let body_uniformity =
+ self.process_block(body, other_functions, disruptor, expression_arena)?;
+ let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
+ let continuing_uniformity = self.process_block(
+ continuing,
+ other_functions,
+ continuing_disruptor,
+ expression_arena,
+ )?;
+ body_uniformity | continuing_uniformity
+ }
+ S::Return { value } => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_RETURN
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ // Here and below, the used expressions are already emitted,
+ // and their results do not affect the function return value,
+ // so we can ignore their non-uniformity.
+ S::Store { pointer, value } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let _ = self.add_ref_impl(image, GlobalUse::WRITE);
+ if let Some(expr) = array_index {
+ let _ = self.add_ref(expr);
+ }
+ let _ = self.add_ref(coordinate);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result: _,
+ } => {
+ for &argument in arguments {
+ let _ = self.add_ref(argument);
+ }
+ let info = &other_functions[function.index()];
+ //Note: the result is validated by the Validator, not here
+ self.process_call(info, arguments, expression_arena)?
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result: _,
+ } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ let _ = self.add_ref(cmp);
+ }
+ FunctionUniformity::new()
+ }
+ S::RayQuery { query, ref fun } => {
+ let _ = self.add_ref(query);
+ if let crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } = *fun
+ {
+ let _ = self.add_ref(acceleration_structure);
+ let _ = self.add_ref(descriptor);
+ }
+ FunctionUniformity::new()
+ }
+ };
+
+ disruptor = disruptor.or(uniformity.exit_disruptor());
+ combined_uniformity = combined_uniformity | uniformity;
+ }
+ Ok(combined_uniformity)
+ }
+}
+
+impl ModuleInfo {
+ /// Builds the `FunctionInfo` based on the function, and validates the
+ /// uniform control flow if required by the expressions of this function.
+ pub(super) fn process_function(
+ &self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ flags: ValidationFlags,
+ capabilities: super::Capabilities,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ let mut info = FunctionInfo {
+ flags,
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ };
+ let resolve_context =
+ ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments);
+
+ for (handle, expr) in fun.expressions.iter() {
+ if let Err(source) = info.process_expression(
+ handle,
+ expr,
+ &fun.expressions,
+ &self.functions,
+ &resolve_context,
+ capabilities,
+ ) {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions));
+ }
+ }
+
+ let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
+ info.uniformity = uniformity.result;
+ info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
+
+ Ok(info)
+ }
+
+ pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
+ &self.entry_points[index]
+ }
+}
+
+#[test]
+#[cfg(feature = "validate")]
+fn uniform_control_flow() {
+ use crate::{Expression as E, Statement as S};
+
+ let mut constant_arena = Arena::new();
+ let constant = constant_arena.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(0),
+ },
+ },
+ Default::default(),
+ );
+ let mut type_arena = crate::UniqueArena::new();
+ let ty = type_arena.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+ let mut global_var_arena = Arena::new();
+ let non_uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ space: crate::AddressSpace::Handle,
+ binding: None,
+ },
+ Default::default(),
+ );
+ let uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ binding: None,
+ space: crate::AddressSpace::Uniform,
+ },
+ Default::default(),
+ );
+
+ let mut expressions = Arena::new();
+ // checks the uniform control flow
+ let constant_expr = expressions.append(E::Constant(constant), Default::default());
+ // checks the non-uniform control flow
+ let derivative_expr = expressions.append(
+ E::Derivative {
+ axis: crate::DerivativeAxis::X,
+ ctrl: crate::DerivativeControl::None,
+ expr: constant_expr,
+ },
+ Default::default(),
+ );
+ let emit_range_constant_derivative = expressions.range_from(0);
+ let non_uniform_global_expr =
+ expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
+ let uniform_global_expr =
+ expressions.append(E::GlobalVariable(uniform_global), Default::default());
+ let emit_range_globals = expressions.range_from(2);
+
+ // checks the QUERY flag
+ let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
+ // checks the transitive WRITE flag
+ let access_expr = expressions.append(
+ E::AccessIndex {
+ base: non_uniform_global_expr,
+ index: 1,
+ },
+ Default::default(),
+ );
+ let emit_range_query_access_globals = expressions.range_from(2);
+
+ let mut info = FunctionInfo {
+ flags: ValidationFlags::all(),
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ };
+ let resolve_context = ResolveContext {
+ constants: &constant_arena,
+ types: &type_arena,
+ special_types: &crate::SpecialTypes::default(),
+ global_vars: &global_var_arena,
+ local_vars: &Arena::new(),
+ functions: &Arena::new(),
+ arguments: &[],
+ };
+ for (handle, expression) in expressions.iter() {
+ info.process_expression(
+ handle,
+ expression,
+ &expressions,
+ &[],
+ &resolve_context,
+ super::Capabilities::empty(),
+ )
+ .unwrap();
+ }
+ assert_eq!(info[non_uniform_global_expr].ref_count, 1);
+ assert_eq!(info[uniform_global_expr].ref_count, 1);
+ assert_eq!(info[query_expr].ref_count, 0);
+ assert_eq!(info[access_expr].ref_count, 0);
+ assert_eq!(info[non_uniform_global], GlobalUse::empty());
+ assert_eq!(info[uniform_global], GlobalUse::QUERY);
+
+ let stmt_emit1 = S::Emit(emit_range_globals.clone());
+ let stmt_if_uniform = S::If {
+ condition: uniform_global_expr,
+ accept: crate::Block::new(),
+ reject: vec![
+ S::Emit(emit_range_constant_derivative.clone()),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit1, stmt_if_uniform].into(),
+ &[],
+ None,
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ exit: ExitFlags::empty(),
+ }),
+ );
+ assert_eq!(info[constant_expr].ref_count, 2);
+ assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
+
+ let stmt_emit2 = S::Emit(emit_range_globals.clone());
+ let stmt_if_non_uniform = S::If {
+ condition: non_uniform_global_expr,
+ accept: vec![
+ S::Emit(emit_range_constant_derivative),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ reject: crate::Block::new(),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit2, stmt_if_non_uniform].into(),
+ &[],
+ None,
+ &expressions
+ ),
+ Err(FunctionError::NonUniformControlFlow(
+ UniformityRequirements::DERIVATIVE,
+ derivative_expr,
+ UniformityDisruptor::Expression(non_uniform_global_expr)
+ )
+ .with_span()),
+ );
+ assert_eq!(info[derivative_expr].ref_count, 1);
+ assert_eq!(info[non_uniform_global], GlobalUse::READ);
+
+ let stmt_emit3 = S::Emit(emit_range_globals);
+ let stmt_return_non_uniform = S::Return {
+ value: Some(non_uniform_global_expr),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit3, stmt_return_non_uniform].into(),
+ &[],
+ Some(UniformityDisruptor::Return),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::MAY_RETURN,
+ }),
+ );
+ assert_eq!(info[non_uniform_global_expr].ref_count, 3);
+
+ // Check that uniformity requirements reach through a pointer
+ let stmt_emit4 = S::Emit(emit_range_query_access_globals);
+ let stmt_assign = S::Store {
+ pointer: access_expr,
+ value: query_expr,
+ };
+ let stmt_return_pointer = S::Return {
+ value: Some(access_expr),
+ };
+ let stmt_kill = S::Kill;
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
+ &[],
+ Some(UniformityDisruptor::Discard),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::all(),
+ }),
+ );
+ assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
+}
diff --git a/third_party/rust/naga/src/valid/compose.rs b/third_party/rust/naga/src/valid/compose.rs
new file mode 100644
index 0000000000..e77d538255
--- /dev/null
+++ b/third_party/rust/naga/src/valid/compose.rs
@@ -0,0 +1,138 @@
+#[cfg(feature = "validate")]
+use crate::{
+ arena::{Arena, UniqueArena},
+ proc::TypeResolution,
+};
+
+use crate::arena::Handle;
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ComposeError {
+ #[error("Composing of type {0:?} can't be done")]
+ Type(Handle<crate::Type>),
+ #[error("Composing expects {expected} components but {given} were given")]
+ ComponentCount { given: u32, expected: u32 },
+ #[error("Composing {index}'s component type is not expected")]
+ ComponentType { index: u32 },
+}
+
+#[cfg(feature = "validate")]
+pub fn validate_compose(
+ self_ty_handle: Handle<crate::Type>,
+ constant_arena: &Arena<crate::Constant>,
+ type_arena: &UniqueArena<crate::Type>,
+ component_resolutions: impl ExactSizeIterator<Item = TypeResolution>,
+) -> Result<(), ComposeError> {
+ use crate::TypeInner as Ti;
+
+ match type_arena[self_ty_handle].inner {
+ // vectors are composed from scalars or other vectors
+ Ti::Vector { size, kind, width } => {
+ let mut total = 0;
+ for (index, comp_res) in component_resolutions.enumerate() {
+ total += match *comp_res.inner_with(type_arena) {
+ Ti::Scalar {
+ kind: comp_kind,
+ width: comp_width,
+ } if comp_kind == kind && comp_width == width => 1,
+ Ti::Vector {
+ size: comp_size,
+ kind: comp_kind,
+ width: comp_width,
+ } if comp_kind == kind && comp_width == width => comp_size as u32,
+ ref other => {
+ log::error!("Vector component[{}] type {:?}", index, other);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ };
+ }
+ if size as u32 != total {
+ return Err(ComposeError::ComponentCount {
+ expected: size as u32,
+ given: total,
+ });
+ }
+ }
+ // matrix are composed from column vectors
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ let inner = Ti::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ };
+ if columns as usize != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ expected: columns as u32,
+ given: component_resolutions.len() as u32,
+ });
+ }
+ for (index, comp_res) in component_resolutions.enumerate() {
+ if comp_res.inner_with(type_arena) != &inner {
+ log::error!("Matrix component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ Ti::Array {
+ base,
+ size: crate::ArraySize::Constant(handle),
+ stride: _,
+ } => {
+ let count = constant_arena[handle].to_array_length().unwrap();
+ if count as usize != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ expected: count,
+ given: component_resolutions.len() as u32,
+ });
+ }
+ for (index, comp_res) in component_resolutions.enumerate() {
+ let base_inner = &type_arena[base].inner;
+ let comp_res_inner = comp_res.inner_with(type_arena);
+ // We don't support arrays of pointers, but it seems best not to
+ // embed that assumption here, so use `TypeInner::equivalent`.
+ if !base_inner.equivalent(comp_res_inner, type_arena) {
+ log::error!("Array component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ Ti::Struct { ref members, .. } => {
+ if members.len() != component_resolutions.len() {
+ return Err(ComposeError::ComponentCount {
+ given: component_resolutions.len() as u32,
+ expected: members.len() as u32,
+ });
+ }
+ for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate()
+ {
+ let member_inner = &type_arena[member.ty].inner;
+ let comp_res_inner = comp_res.inner_with(type_arena);
+ // We don't support pointers in structs, but it seems best not to embed
+ // that assumption here, so use `TypeInner::equivalent`.
+ if !comp_res_inner.equivalent(member_inner, type_arena) {
+ log::error!("Struct component[{}] type {:?}", index, comp_res);
+ return Err(ComposeError::ComponentType {
+ index: index as u32,
+ });
+ }
+ }
+ }
+ ref other => {
+ log::error!("Composing of {:?}", other);
+ return Err(ComposeError::Type(self_ty_handle));
+ }
+ }
+
+ Ok(())
+}
diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs
new file mode 100644
index 0000000000..3a81707904
--- /dev/null
+++ b/third_party/rust/naga/src/valid/expression.rs
@@ -0,0 +1,1482 @@
+#[cfg(feature = "validate")]
+use std::ops::Index;
+
+#[cfg(feature = "validate")]
+use super::{
+ compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages,
+ TypeFlags,
+};
+#[cfg(feature = "validate")]
+use crate::arena::UniqueArena;
+
+use crate::{
+ arena::Handle,
+ proc::{IndexableLengthError, ResolveError},
+};
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum ExpressionError {
+ #[error("Doesn't exist")]
+ DoesntExist,
+ #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
+ NotInScope,
+ #[error("Base type {0:?} is not compatible with this expression")]
+ InvalidBaseType(Handle<crate::Expression>),
+ #[error("Accessing with index {0:?} can't be done")]
+ InvalidIndexType(Handle<crate::Expression>),
+ #[error("Accessing index {1:?} is out of {0:?} bounds")]
+ IndexOutOfBounds(Handle<crate::Expression>, crate::ScalarValue),
+ #[error("The expression {0:?} may only be indexed by a constant")]
+ IndexMustBeConstant(Handle<crate::Expression>),
+ #[error("Function argument {0:?} doesn't exist")]
+ FunctionArgumentDoesntExist(u32),
+ #[error("Loading of {0:?} can't be done")]
+ InvalidPointerType(Handle<crate::Expression>),
+ #[error("Array length of {0:?} can't be done")]
+ InvalidArrayType(Handle<crate::Expression>),
+ #[error("Get intersection of {0:?} can't be done")]
+ InvalidRayQueryType(Handle<crate::Expression>),
+ #[error("Splatting {0:?} can't be done")]
+ InvalidSplatType(Handle<crate::Expression>),
+ #[error("Swizzling {0:?} can't be done")]
+ InvalidVectorType(Handle<crate::Expression>),
+ #[error("Swizzle component {0:?} is outside of vector size {1:?}")]
+ InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize),
+ #[error(transparent)]
+ Compose(#[from] super::ComposeError),
+ #[error(transparent)]
+ IndexableLength(#[from] IndexableLengthError),
+ #[error("Operation {0:?} can't work with {1:?}")]
+ InvalidUnaryOperandType(crate::UnaryOperator, Handle<crate::Expression>),
+ #[error("Operation {0:?} can't work with {1:?} and {2:?}")]
+ InvalidBinaryOperandTypes(
+ crate::BinaryOperator,
+ Handle<crate::Expression>,
+ Handle<crate::Expression>,
+ ),
+ #[error("Selecting is not possible")]
+ InvalidSelectTypes,
+ #[error("Relational argument {0:?} is not a boolean vector")]
+ InvalidBooleanVector(Handle<crate::Expression>),
+ #[error("Relational argument {0:?} is not a float")]
+ InvalidFloatArgument(Handle<crate::Expression>),
+ #[error("Type resolution failed")]
+ Type(#[from] ResolveError),
+ #[error("Not a global variable")]
+ ExpectedGlobalVariable,
+ #[error("Not a global variable or a function argument")]
+ ExpectedGlobalOrArgument,
+ #[error("Needs to be an binding array instead of {0:?}")]
+ ExpectedBindingArrayType(Handle<crate::Type>),
+ #[error("Needs to be an image instead of {0:?}")]
+ ExpectedImageType(Handle<crate::Type>),
+ #[error("Needs to be an image instead of {0:?}")]
+ ExpectedSamplerType(Handle<crate::Type>),
+ #[error("Unable to operate on image class {0:?}")]
+ InvalidImageClass(crate::ImageClass),
+ #[error("Derivatives can only be taken from scalar and vector floats")]
+ InvalidDerivative,
+ #[error("Image array index parameter is misplaced")]
+ InvalidImageArrayIndex,
+ #[error("Inappropriate sample or level-of-detail index for texel access")]
+ InvalidImageOtherIndex,
+ #[error("Image array index type of {0:?} is not an integer scalar")]
+ InvalidImageArrayIndexType(Handle<crate::Expression>),
+ #[error("Image sample or level-of-detail index's type of {0:?} is not an integer scalar")]
+ InvalidImageOtherIndexType(Handle<crate::Expression>),
+ #[error("Image coordinate type of {1:?} does not match dimension {0:?}")]
+ InvalidImageCoordinateType(crate::ImageDimension, Handle<crate::Expression>),
+ #[error("Comparison sampling mismatch: image has class {image:?}, but the sampler is comparison={sampler}, and the reference was provided={has_ref}")]
+ ComparisonSamplingMismatch {
+ image: crate::ImageClass,
+ sampler: bool,
+ has_ref: bool,
+ },
+ #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
+ InvalidSampleOffset(crate::ImageDimension, Handle<crate::Constant>),
+ #[error("Depth reference {0:?} is not a scalar float")]
+ InvalidDepthReference(Handle<crate::Expression>),
+ #[error("Depth sample level can only be Auto or Zero")]
+ InvalidDepthSampleLevel,
+ #[error("Gather level can only be Zero")]
+ InvalidGatherLevel,
+ #[error("Gather component {0:?} doesn't exist in the image")]
+ InvalidGatherComponent(crate::SwizzleComponent),
+ #[error("Gather can't be done for image dimension {0:?}")]
+ InvalidGatherDimension(crate::ImageDimension),
+ #[error("Sample level (exact) type {0:?} is not a scalar float")]
+ InvalidSampleLevelExactType(Handle<crate::Expression>),
+ #[error("Sample level (bias) type {0:?} is not a scalar float")]
+ InvalidSampleLevelBiasType(Handle<crate::Expression>),
+ #[error("Sample level (gradient) of {1:?} doesn't match the image dimension {0:?}")]
+ InvalidSampleLevelGradientType(crate::ImageDimension, Handle<crate::Expression>),
+ #[error("Unable to cast")]
+ InvalidCastArgument,
+ #[error("Invalid argument count for {0:?}")]
+ WrongArgumentCount(crate::MathFunction),
+ #[error("Argument [{1}] to {0:?} as expression {2:?} has an invalid type.")]
+ InvalidArgumentType(crate::MathFunction, u32, Handle<crate::Expression>),
+ #[error("Atomic result type can't be {0:?}")]
+ InvalidAtomicResultType(Handle<crate::Type>),
+ #[error("Shader requires capability {0:?}")]
+ MissingCapabilities(super::Capabilities),
+}
+
+#[cfg(feature = "validate")]
+struct ExpressionTypeResolver<'a> {
+ root: Handle<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ info: &'a FunctionInfo,
+}
+
+#[cfg(feature = "validate")]
+impl<'a> Index<Handle<crate::Expression>> for ExpressionTypeResolver<'a> {
+ type Output = crate::TypeInner;
+
+ #[allow(clippy::panic)]
+ fn index(&self, handle: Handle<crate::Expression>) -> &Self::Output {
+ if handle < self.root {
+ self.info[handle].ty.inner_with(self.types)
+ } else {
+ // `Validator::validate_module_handles` should have caught this.
+ panic!(
+ "Depends on {:?}, which has not been processed yet",
+ self.root
+ )
+ }
+ }
+}
+
+#[cfg(feature = "validate")]
+impl super::Validator {
+ pub(super) fn validate_expression(
+ &self,
+ root: Handle<crate::Expression>,
+ expression: &crate::Expression,
+ function: &crate::Function,
+ module: &crate::Module,
+ info: &FunctionInfo,
+ other_infos: &[FunctionInfo],
+ ) -> Result<ShaderStages, ExpressionError> {
+ use crate::{Expression as E, ScalarKind as Sk, TypeInner as Ti};
+
+ let resolver = ExpressionTypeResolver {
+ root,
+ types: &module.types,
+ info,
+ };
+
+ let stages = match *expression {
+ E::Access { base, index } => {
+ let base_type = &resolver[base];
+ // See the documentation for `Expression::Access`.
+ let dynamic_indexing_restricted = match *base_type {
+ Ti::Vector { .. } => false,
+ Ti::Matrix { .. } | Ti::Array { .. } => true,
+ Ti::Pointer { .. }
+ | Ti::ValuePointer { size: Some(_), .. }
+ | Ti::BindingArray { .. } => false,
+ ref other => {
+ log::error!("Indexing of {:?}", other);
+ return Err(ExpressionError::InvalidBaseType(base));
+ }
+ };
+ match resolver[index] {
+ //TODO: only allow one of these
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ width: _,
+ } => {}
+ ref other => {
+ log::error!("Indexing by {:?}", other);
+ return Err(ExpressionError::InvalidIndexType(index));
+ }
+ }
+ if dynamic_indexing_restricted
+ && function.expressions[index].is_dynamic_index(module)
+ {
+ return Err(ExpressionError::IndexMustBeConstant(base));
+ }
+
+ // If we know both the length and the index, we can do the
+ // bounds check now.
+ if let crate::proc::IndexableLength::Known(known_length) =
+ base_type.indexable_length(module)?
+ {
+ if let E::Constant(k) = function.expressions[index] {
+ if let crate::Constant {
+ // We must treat specializable constants as unknown.
+ specialization: None,
+ // Non-scalar indices should have been caught above.
+ inner: crate::ConstantInner::Scalar { value, .. },
+ ..
+ } = module.constants[k]
+ {
+ match value {
+ crate::ScalarValue::Uint(u) if u >= known_length as u64 => {
+ return Err(ExpressionError::IndexOutOfBounds(base, value));
+ }
+ crate::ScalarValue::Sint(s)
+ if s < 0 || s >= known_length as i64 =>
+ {
+ return Err(ExpressionError::IndexOutOfBounds(base, value));
+ }
+ _ => (),
+ }
+ }
+ }
+ }
+
+ ShaderStages::all()
+ }
+ E::AccessIndex { base, index } => {
+ fn resolve_index_limit(
+ module: &crate::Module,
+ top: Handle<crate::Expression>,
+ ty: &crate::TypeInner,
+ top_level: bool,
+ ) -> Result<u32, ExpressionError> {
+ let limit = match *ty {
+ Ti::Vector { size, .. }
+ | Ti::ValuePointer {
+ size: Some(size), ..
+ } => size as u32,
+ Ti::Matrix { columns, .. } => columns as u32,
+ Ti::Array {
+ size: crate::ArraySize::Constant(handle),
+ ..
+ } => module.constants[handle].to_array_length().unwrap(),
+ Ti::Array { .. } | Ti::BindingArray { .. } => u32::MAX, // can't statically know, but need run-time checks
+ Ti::Pointer { base, .. } if top_level => {
+ resolve_index_limit(module, top, &module.types[base].inner, false)?
+ }
+ Ti::Struct { ref members, .. } => members.len() as u32,
+ ref other => {
+ log::error!("Indexing of {:?}", other);
+ return Err(ExpressionError::InvalidBaseType(top));
+ }
+ };
+ Ok(limit)
+ }
+
+ let limit = resolve_index_limit(module, base, &resolver[base], true)?;
+ if index >= limit {
+ return Err(ExpressionError::IndexOutOfBounds(
+ base,
+ crate::ScalarValue::Uint(limit as _),
+ ));
+ }
+ ShaderStages::all()
+ }
+ E::Constant(_handle) => ShaderStages::all(),
+ E::Splat { size: _, value } => match resolver[value] {
+ Ti::Scalar { .. } => ShaderStages::all(),
+ ref other => {
+ log::error!("Splat scalar type {:?}", other);
+ return Err(ExpressionError::InvalidSplatType(value));
+ }
+ },
+ E::Swizzle {
+ size,
+ vector,
+ pattern,
+ } => {
+ let vec_size = match resolver[vector] {
+ Ti::Vector { size: vec_size, .. } => vec_size,
+ ref other => {
+ log::error!("Swizzle vector type {:?}", other);
+ return Err(ExpressionError::InvalidVectorType(vector));
+ }
+ };
+ for &sc in pattern[..size as usize].iter() {
+ if sc as u8 >= vec_size as u8 {
+ return Err(ExpressionError::InvalidSwizzleComponent(sc, vec_size));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::Compose { ref components, ty } => {
+ validate_compose(
+ ty,
+ &module.constants,
+ &module.types,
+ components.iter().map(|&handle| info[handle].ty.clone()),
+ )?;
+ ShaderStages::all()
+ }
+ E::FunctionArgument(index) => {
+ if index >= function.arguments.len() as u32 {
+ return Err(ExpressionError::FunctionArgumentDoesntExist(index));
+ }
+ ShaderStages::all()
+ }
+ E::GlobalVariable(_handle) => ShaderStages::all(),
+ E::LocalVariable(_handle) => ShaderStages::all(),
+ E::Load { pointer } => {
+ match resolver[pointer] {
+ Ti::Pointer { base, .. }
+ if self.types[base.index()]
+ .flags
+ .contains(TypeFlags::SIZED | TypeFlags::DATA) => {}
+ Ti::ValuePointer { .. } => {}
+ ref other => {
+ log::error!("Loading {:?}", other);
+ return Err(ExpressionError::InvalidPointerType(pointer));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::ImageSample {
+ image,
+ sampler,
+ gather,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ // check the validity of expressions
+ let image_ty = Self::global_var_ty(module, function, image)?;
+ let sampler_ty = Self::global_var_ty(module, function, sampler)?;
+
+ let comparison = match module.types[sampler_ty].inner {
+ Ti::Sampler { comparison } => comparison,
+ _ => return Err(ExpressionError::ExpectedSamplerType(sampler_ty)),
+ };
+
+ let (class, dim) = match module.types[image_ty].inner {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ // check the array property
+ if arrayed != array_index.is_some() {
+ return Err(ExpressionError::InvalidImageArrayIndex);
+ }
+ if let Some(expr) = array_index {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ width: _,
+ } => {}
+ _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
+ }
+ }
+ (class, dim)
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(image_ty)),
+ };
+
+ // check sampling and comparison properties
+ let image_depth = match class {
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Float,
+ multi: false,
+ } => false,
+ crate::ImageClass::Sampled {
+ kind: crate::ScalarKind::Uint | crate::ScalarKind::Sint,
+ multi: false,
+ } if gather.is_some() => false,
+ crate::ImageClass::Depth { multi: false } => true,
+ _ => return Err(ExpressionError::InvalidImageClass(class)),
+ };
+ if comparison != depth_ref.is_some() || (comparison && !image_depth) {
+ return Err(ExpressionError::ComparisonSamplingMismatch {
+ image: class,
+ sampler: comparison,
+ has_ref: depth_ref.is_some(),
+ });
+ }
+
+ // check texture coordinates type
+ let num_components = match dim {
+ crate::ImageDimension::D1 => 1,
+ crate::ImageDimension::D2 => 2,
+ crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3,
+ };
+ match resolver[coordinate] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ kind: Sk::Float,
+ ..
+ } if size as u32 == num_components => {}
+ _ => return Err(ExpressionError::InvalidImageCoordinateType(dim, coordinate)),
+ }
+
+ // check constant offset
+ if let Some(const_handle) = offset {
+ let good = match module.constants[const_handle].inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(_),
+ } => num_components == 1,
+ crate::ConstantInner::Scalar { .. } => false,
+ crate::ConstantInner::Composite { ty, .. } => {
+ match module.types[ty].inner {
+ Ti::Vector {
+ size,
+ kind: Sk::Sint,
+ ..
+ } => size as u32 == num_components,
+ _ => false,
+ }
+ }
+ };
+ if !good {
+ return Err(ExpressionError::InvalidSampleOffset(dim, const_handle));
+ }
+ }
+
+ // check depth reference type
+ if let Some(expr) = depth_ref {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidDepthReference(expr)),
+ }
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => {}
+ _ => return Err(ExpressionError::InvalidDepthSampleLevel),
+ }
+ }
+
+ if let Some(component) = gather {
+ match dim {
+ crate::ImageDimension::D2 | crate::ImageDimension::Cube => {}
+ crate::ImageDimension::D1 | crate::ImageDimension::D3 => {
+ return Err(ExpressionError::InvalidGatherDimension(dim))
+ }
+ };
+ let max_component = match class {
+ crate::ImageClass::Depth { .. } => crate::SwizzleComponent::X,
+ _ => crate::SwizzleComponent::W,
+ };
+ if component > max_component {
+ return Err(ExpressionError::InvalidGatherComponent(component));
+ }
+ match level {
+ crate::SampleLevel::Zero => {}
+ _ => return Err(ExpressionError::InvalidGatherLevel),
+ }
+ }
+
+ // check level properties
+ match level {
+ crate::SampleLevel::Auto => ShaderStages::FRAGMENT,
+ crate::SampleLevel::Zero => ShaderStages::all(),
+ crate::SampleLevel::Exact(expr) => {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidSampleLevelExactType(expr)),
+ }
+ ShaderStages::all()
+ }
+ crate::SampleLevel::Bias(expr) => {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidSampleLevelBiasType(expr)),
+ }
+ ShaderStages::all()
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ match resolver[x] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ kind: Sk::Float,
+ ..
+ } if size as u32 == num_components => {}
+ _ => {
+ return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x))
+ }
+ }
+ match resolver[y] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ } if num_components == 1 => {}
+ Ti::Vector {
+ size,
+ kind: Sk::Float,
+ ..
+ } if size as u32 == num_components => {}
+ _ => {
+ return Err(ExpressionError::InvalidSampleLevelGradientType(dim, y))
+ }
+ }
+ ShaderStages::all()
+ }
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let ty = Self::global_var_ty(module, function, image)?;
+ match module.types[ty].inner {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ match resolver[coordinate].image_storage_coordinates() {
+ Some(coord_dim) if coord_dim == dim => {}
+ _ => {
+ return Err(ExpressionError::InvalidImageCoordinateType(
+ dim, coordinate,
+ ))
+ }
+ };
+ if arrayed != array_index.is_some() {
+ return Err(ExpressionError::InvalidImageArrayIndex);
+ }
+ if let Some(expr) = array_index {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ width: _,
+ } => {}
+ _ => return Err(ExpressionError::InvalidImageArrayIndexType(expr)),
+ }
+ }
+
+ match (sample, class.is_multisampled()) {
+ (None, false) => {}
+ (Some(sample), true) => {
+ if resolver[sample].scalar_kind() != Some(Sk::Sint) {
+ return Err(ExpressionError::InvalidImageOtherIndexType(
+ sample,
+ ));
+ }
+ }
+ _ => {
+ return Err(ExpressionError::InvalidImageOtherIndex);
+ }
+ }
+
+ match (level, class.is_mipmapped()) {
+ (None, false) => {}
+ (Some(level), true) => {
+ if resolver[level].scalar_kind() != Some(Sk::Sint) {
+ return Err(ExpressionError::InvalidImageOtherIndexType(level));
+ }
+ }
+ _ => {
+ return Err(ExpressionError::InvalidImageOtherIndex);
+ }
+ }
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(ty)),
+ }
+ ShaderStages::all()
+ }
+ E::ImageQuery { image, query } => {
+ let ty = Self::global_var_ty(module, function, image)?;
+ match module.types[ty].inner {
+ Ti::Image { class, arrayed, .. } => {
+ let good = match query {
+ crate::ImageQuery::NumLayers => arrayed,
+ crate::ImageQuery::Size { level: None } => true,
+ crate::ImageQuery::Size { level: Some(_) }
+ | crate::ImageQuery::NumLevels => class.is_mipmapped(),
+ crate::ImageQuery::NumSamples => class.is_multisampled(),
+ };
+ if !good {
+ return Err(ExpressionError::InvalidImageClass(class));
+ }
+ }
+ _ => return Err(ExpressionError::ExpectedImageType(ty)),
+ }
+ ShaderStages::all()
+ }
+ E::Unary { op, expr } => {
+ use crate::UnaryOperator as Uo;
+ let inner = &resolver[expr];
+ match (op, inner.scalar_kind()) {
+ (_, Some(Sk::Sint | Sk::Bool))
+ //TODO: restrict Negate for bools?
+ | (Uo::Negate, Some(Sk::Float))
+ | (Uo::Not, Some(Sk::Uint)) => {}
+ other => {
+ log::error!("Op {:?} kind {:?}", op, other);
+ return Err(ExpressionError::InvalidUnaryOperandType(op, expr));
+ }
+ }
+ ShaderStages::all()
+ }
+ E::Binary { op, left, right } => {
+ use crate::BinaryOperator as Bo;
+ let left_inner = &resolver[left];
+ let right_inner = &resolver[right];
+ let good = match op {
+ Bo::Add | Bo::Subtract => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool => false,
+ },
+ Ti::Matrix { .. } => left_inner == right_inner,
+ _ => false,
+ },
+ Bo::Divide | Bo::Modulo => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool => false,
+ },
+ _ => false,
+ },
+ Bo::Multiply => {
+ let kind_allowed = match left_inner.scalar_kind() {
+ Some(Sk::Uint | Sk::Sint | Sk::Float) => true,
+ Some(Sk::Bool) | None => false,
+ };
+ let types_match = match (left_inner, right_inner) {
+ // Straight scalar and mixed scalar/vector.
+ (&Ti::Scalar { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. })
+ | (&Ti::Vector { kind: kind1, .. }, &Ti::Scalar { kind: kind2, .. })
+ | (&Ti::Scalar { kind: kind1, .. }, &Ti::Vector { kind: kind2, .. }) => {
+ kind1 == kind2
+ }
+ // Scalar/matrix.
+ (
+ &Ti::Scalar {
+ kind: Sk::Float, ..
+ },
+ &Ti::Matrix { .. },
+ )
+ | (
+ &Ti::Matrix { .. },
+ &Ti::Scalar {
+ kind: Sk::Float, ..
+ },
+ ) => true,
+ // Vector/vector.
+ (
+ &Ti::Vector {
+ kind: kind1,
+ size: size1,
+ ..
+ },
+ &Ti::Vector {
+ kind: kind2,
+ size: size2,
+ ..
+ },
+ ) => kind1 == kind2 && size1 == size2,
+ // Matrix * vector.
+ (
+ &Ti::Matrix { columns, .. },
+ &Ti::Vector {
+ kind: Sk::Float,
+ size,
+ ..
+ },
+ ) => columns == size,
+ // Vector * matrix.
+ (
+ &Ti::Vector {
+ kind: Sk::Float,
+ size,
+ ..
+ },
+ &Ti::Matrix { rows, .. },
+ ) => size == rows,
+ (&Ti::Matrix { columns, .. }, &Ti::Matrix { rows, .. }) => {
+ columns == rows
+ }
+ _ => false,
+ };
+ let left_width = match *left_inner {
+ Ti::Scalar { width, .. }
+ | Ti::Vector { width, .. }
+ | Ti::Matrix { width, .. } => width,
+ _ => 0,
+ };
+ let right_width = match *right_inner {
+ Ti::Scalar { width, .. }
+ | Ti::Vector { width, .. }
+ | Ti::Matrix { width, .. } => width,
+ _ => 0,
+ };
+ kind_allowed && types_match && left_width == right_width
+ }
+ Bo::Equal | Bo::NotEqual => left_inner.is_sized() && left_inner == right_inner,
+ Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => {
+ match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner,
+ Sk::Bool => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ }
+ }
+ Bo::LogicalAnd | Bo::LogicalOr => match *left_inner {
+ Ti::Scalar { kind: Sk::Bool, .. } | Ti::Vector { kind: Sk::Bool, .. } => {
+ left_inner == right_inner
+ }
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::And | Bo::InclusiveOr => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner,
+ Sk::Float => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::ExclusiveOr => match *left_inner {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind {
+ Sk::Sint | Sk::Uint => left_inner == right_inner,
+ Sk::Bool | Sk::Float => false,
+ },
+ ref other => {
+ log::error!("Op {:?} left type {:?}", op, other);
+ false
+ }
+ },
+ Bo::ShiftLeft | Bo::ShiftRight => {
+ let (base_size, base_kind) = match *left_inner {
+ Ti::Scalar { kind, .. } => (Ok(None), kind),
+ Ti::Vector { size, kind, .. } => (Ok(Some(size)), kind),
+ ref other => {
+ log::error!("Op {:?} base type {:?}", op, other);
+ (Err(()), Sk::Bool)
+ }
+ };
+ let shift_size = match *right_inner {
+ Ti::Scalar { kind: Sk::Uint, .. } => Ok(None),
+ Ti::Vector {
+ size,
+ kind: Sk::Uint,
+ ..
+ } => Ok(Some(size)),
+ ref other => {
+ log::error!("Op {:?} shift type {:?}", op, other);
+ Err(())
+ }
+ };
+ match base_kind {
+ Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size,
+ Sk::Float | Sk::Bool => false,
+ }
+ }
+ };
+ if !good {
+ log::error!(
+ "Left: {:?} of type {:?}",
+ function.expressions[left],
+ left_inner
+ );
+ log::error!(
+ "Right: {:?} of type {:?}",
+ function.expressions[right],
+ right_inner
+ );
+ return Err(ExpressionError::InvalidBinaryOperandTypes(op, left, right));
+ }
+ ShaderStages::all()
+ }
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ let accept_inner = &resolver[accept];
+ let reject_inner = &resolver[reject];
+ let condition_good = match resolver[condition] {
+ Ti::Scalar {
+ kind: Sk::Bool,
+ width: _,
+ } => {
+ // When `condition` is a single boolean, `accept` and
+ // `reject` can be vectors or scalars.
+ match *accept_inner {
+ Ti::Scalar { .. } | Ti::Vector { .. } => true,
+ _ => false,
+ }
+ }
+ Ti::Vector {
+ size,
+ kind: Sk::Bool,
+ width: _,
+ } => match *accept_inner {
+ Ti::Vector {
+ size: other_size, ..
+ } => size == other_size,
+ _ => false,
+ },
+ _ => false,
+ };
+ if !condition_good || accept_inner != reject_inner {
+ return Err(ExpressionError::InvalidSelectTypes);
+ }
+ ShaderStages::all()
+ }
+ E::Derivative { expr, .. } => {
+ match resolver[expr] {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidDerivative),
+ }
+ ShaderStages::FRAGMENT
+ }
+ E::Relational { fun, argument } => {
+ use crate::RelationalFunction as Rf;
+ let argument_inner = &resolver[argument];
+ match fun {
+ Rf::All | Rf::Any => match *argument_inner {
+ Ti::Vector { kind: Sk::Bool, .. } => {}
+ ref other => {
+ log::error!("All/Any of type {:?}", other);
+ return Err(ExpressionError::InvalidBooleanVector(argument));
+ }
+ },
+ Rf::IsNan | Rf::IsInf | Rf::IsFinite | Rf::IsNormal => match *argument_inner {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ ref other => {
+ log::error!("Float test of type {:?}", other);
+ return Err(ExpressionError::InvalidFloatArgument(argument));
+ }
+ },
+ }
+ ShaderStages::all()
+ }
+ E::Math {
+ fun,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ use crate::MathFunction as Mf;
+
+ let resolve = |arg| &resolver[arg];
+ let arg_ty = resolve(arg);
+ let arg1_ty = arg1.map(resolve);
+ let arg2_ty = arg2.map(resolve);
+ let arg3_ty = arg3.map(resolve);
+ match fun {
+ Mf::Abs => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ let good = match *arg_ty {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ }
+ Mf::Min | Mf::Max => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let good = match *arg_ty {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Clamp => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let good = match *arg_ty {
+ Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => kind != Sk::Bool,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ if arg2_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ Mf::Saturate
+ | Mf::Cos
+ | Mf::Cosh
+ | Mf::Sin
+ | Mf::Sinh
+ | Mf::Tan
+ | Mf::Tanh
+ | Mf::Acos
+ | Mf::Asin
+ | Mf::Atan
+ | Mf::Asinh
+ | Mf::Acosh
+ | Mf::Atanh
+ | Mf::Radians
+ | Mf::Degrees
+ | Mf::Ceil
+ | Mf::Floor
+ | Mf::Round
+ | Mf::Fract
+ | Mf::Trunc
+ | Mf::Exp
+ | Mf::Exp2
+ | Mf::Log
+ | Mf::Log2
+ | Mf::Length
+ | Mf::Sign
+ | Mf::Sqrt
+ | Mf::InverseSqrt => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Modf | Mf::Frexp | Mf::Ldexp => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let (size0, width0) = match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ } => (None, width),
+ Ti::Vector {
+ kind: Sk::Float,
+ size,
+ width,
+ } => (Some(size), width),
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ };
+ let good = match *arg1_ty {
+ Ti::Pointer { base, space: _ } => module.types[base].inner == *arg_ty,
+ Ti::ValuePointer {
+ size,
+ kind: Sk::Float,
+ width,
+ space: _,
+ } => size == size0 && width == width0,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Dot => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float | Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Outer | Mf::Cross | Mf::Reflect => {
+ let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), None, None) => ty1,
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ }
+ Mf::Refract => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+
+ match (arg_ty, arg2_ty) {
+ (
+ &Ti::Vector {
+ width: vector_width,
+ ..
+ },
+ &Ti::Scalar {
+ width: scalar_width,
+ kind: Sk::Float,
+ },
+ ) if vector_width == scalar_width => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::Normalize => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::FaceForward | Mf::Fma | Mf::SmoothStep => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float, ..
+ }
+ | Ti::Vector {
+ kind: Sk::Float, ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ if arg2_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ Mf::Mix => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ let arg_width = match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ }
+ | Ti::Vector {
+ kind: Sk::Float,
+ width,
+ ..
+ } => width,
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ };
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ // the last argument can always be a scalar
+ match *arg2_ty {
+ Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ } if width == arg_width => {}
+ _ if arg2_ty == arg_ty => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ));
+ }
+ }
+ }
+ Mf::Inverse | Mf::Determinant => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ let good = match *arg_ty {
+ Ti::Matrix { columns, rows, .. } => columns == rows,
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidArgumentType(fun, 0, arg));
+ }
+ }
+ Mf::Transpose => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Matrix { .. } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::CountTrailingZeros
+ | Mf::CountLeadingZeros
+ | Mf::CountOneBits
+ | Mf::ReverseBits
+ | Mf::FindLsb
+ | Mf::FindMsb => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }
+ | Ti::Vector {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::InsertBits => {
+ let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }
+ | Ti::Vector {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ if arg1_ty != arg_ty {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 1,
+ arg1.unwrap(),
+ ));
+ }
+ match *arg2_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ match *arg3_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg3.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::ExtractBits => {
+ let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) {
+ (Some(ty1), Some(ty2), None) => (ty1, ty2),
+ _ => return Err(ExpressionError::WrongArgumentCount(fun)),
+ };
+ match *arg_ty {
+ Ti::Scalar {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ }
+ | Ti::Vector {
+ kind: Sk::Sint | Sk::Uint,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ match *arg1_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg1.unwrap(),
+ ))
+ }
+ }
+ match *arg2_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => {
+ return Err(ExpressionError::InvalidArgumentType(
+ fun,
+ 2,
+ arg2.unwrap(),
+ ))
+ }
+ }
+ }
+ Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ size: crate::VectorSize::Bi,
+ kind: Sk::Float,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Pack4x8snorm | Mf::Pack4x8unorm => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Vector {
+ size: crate::VectorSize::Quad,
+ kind: Sk::Float,
+ ..
+ } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ Mf::Unpack2x16float
+ | Mf::Unpack2x16snorm
+ | Mf::Unpack2x16unorm
+ | Mf::Unpack4x8snorm
+ | Mf::Unpack4x8unorm => {
+ if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() {
+ return Err(ExpressionError::WrongArgumentCount(fun));
+ }
+ match *arg_ty {
+ Ti::Scalar { kind: Sk::Uint, .. } => {}
+ _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
+ }
+ }
+ }
+ ShaderStages::all()
+ }
+ E::As {
+ expr,
+ kind,
+ convert,
+ } => {
+ let base_width = match resolver[expr] {
+ crate::TypeInner::Scalar { width, .. }
+ | crate::TypeInner::Vector { width, .. }
+ | crate::TypeInner::Matrix { width, .. } => width,
+ _ => return Err(ExpressionError::InvalidCastArgument),
+ };
+ let width = convert.unwrap_or(base_width);
+ if self.check_width(kind, width).is_err() {
+ return Err(ExpressionError::InvalidCastArgument);
+ }
+ ShaderStages::all()
+ }
+ E::CallResult(function) => other_infos[function.index()].available_stages,
+ E::AtomicResult { ty, comparison } => {
+ let scalar_predicate = |ty: &crate::TypeInner| match ty {
+ &crate::TypeInner::Scalar {
+ kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint),
+ width,
+ } => self.check_width(kind, width).is_ok(),
+ _ => false,
+ };
+ let good = match &module.types[ty].inner {
+ ty if !comparison => scalar_predicate(ty),
+ &crate::TypeInner::Struct { ref members, .. } if comparison => {
+ validate_atomic_compare_exchange_struct(
+ &module.types,
+ members,
+ scalar_predicate,
+ )
+ }
+ _ => false,
+ };
+ if !good {
+ return Err(ExpressionError::InvalidAtomicResultType(ty));
+ }
+ ShaderStages::all()
+ }
+ E::ArrayLength(expr) => match resolver[expr] {
+ Ti::Pointer { base, .. } => {
+ let base_ty = &resolver.types[base];
+ if let Ti::Array {
+ size: crate::ArraySize::Dynamic,
+ ..
+ } = base_ty.inner
+ {
+ ShaderStages::all()
+ } else {
+ return Err(ExpressionError::InvalidArrayType(expr));
+ }
+ }
+ ref other => {
+ log::error!("Array length of {:?}", other);
+ return Err(ExpressionError::InvalidArrayType(expr));
+ }
+ },
+ E::RayQueryProceedResult => ShaderStages::all(),
+ E::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => match resolver[query] {
+ Ti::Pointer {
+ base,
+ space: crate::AddressSpace::Function,
+ } => match resolver.types[base].inner {
+ Ti::RayQuery => ShaderStages::all(),
+ ref other => {
+ log::error!("Intersection result of a pointer to {:?}", other);
+ return Err(ExpressionError::InvalidRayQueryType(query));
+ }
+ },
+ ref other => {
+ log::error!("Intersection result of {:?}", other);
+ return Err(ExpressionError::InvalidRayQueryType(query));
+ }
+ },
+ };
+ Ok(stages)
+ }
+
+ fn global_var_ty(
+ module: &crate::Module,
+ function: &crate::Function,
+ expr: Handle<crate::Expression>,
+ ) -> Result<Handle<crate::Type>, ExpressionError> {
+ use crate::Expression as Ex;
+
+ match function.expressions[expr] {
+ Ex::GlobalVariable(var_handle) => Ok(module.global_variables[var_handle].ty),
+ Ex::FunctionArgument(i) => Ok(function.arguments[i as usize].ty),
+ Ex::Access { base, .. } | Ex::AccessIndex { base, .. } => {
+ match function.expressions[base] {
+ Ex::GlobalVariable(var_handle) => {
+ let array_ty = module.global_variables[var_handle].ty;
+
+ match module.types[array_ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => Ok(base),
+ _ => Err(ExpressionError::ExpectedBindingArrayType(array_ty)),
+ }
+ }
+ _ => Err(ExpressionError::ExpectedGlobalVariable),
+ }
+ }
+ _ => Err(ExpressionError::ExpectedGlobalVariable),
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
new file mode 100644
index 0000000000..8ebcddd9e3
--- /dev/null
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -0,0 +1,1048 @@
+use crate::arena::Handle;
+#[cfg(feature = "validate")]
+use crate::arena::{Arena, UniqueArena};
+
+#[cfg(feature = "validate")]
+use super::validate_atomic_compare_exchange_struct;
+
+use super::{
+ analyzer::{UniformityDisruptor, UniformityRequirements},
+ ExpressionError, FunctionInfo, ModuleInfo,
+};
+use crate::span::WithSpan;
+#[cfg(feature = "validate")]
+use crate::span::{AddSpan as _, MapErrWithSpan as _};
+
+#[cfg(feature = "validate")]
+use bit_set::BitSet;
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum CallError {
+ #[error("Argument {index} expression is invalid")]
+ Argument {
+ index: usize,
+ source: ExpressionError,
+ },
+ #[error("Result expression {0:?} has already been introduced earlier")]
+ ResultAlreadyInScope(Handle<crate::Expression>),
+ #[error("Result value is invalid")]
+ ResultValue(#[source] ExpressionError),
+ #[error("Requires {required} arguments, but {seen} are provided")]
+ ArgumentCount { required: usize, seen: usize },
+ #[error("Argument {index} value {seen_expression:?} doesn't match the type {required:?}")]
+ ArgumentType {
+ index: usize,
+ required: Handle<crate::Type>,
+ seen_expression: Handle<crate::Expression>,
+ },
+ #[error("The emitted expression doesn't match the call")]
+ ExpressionMismatch(Option<Handle<crate::Expression>>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum AtomicError {
+ #[error("Pointer {0:?} to atomic is invalid.")]
+ InvalidPointer(Handle<crate::Expression>),
+ #[error("Operand {0:?} has invalid type.")]
+ InvalidOperand(Handle<crate::Expression>),
+ #[error("Result type for {0:?} doesn't match the statement")]
+ ResultTypeMismatch(Handle<crate::Expression>),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum LocalVariableError {
+ #[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
+ InvalidType(Handle<crate::Type>),
+ #[error("Initializer doesn't match the variable type")]
+ InitializerType,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum FunctionError {
+ #[error("Expression {handle:?} is invalid")]
+ Expression {
+ handle: Handle<crate::Expression>,
+ source: ExpressionError,
+ },
+ #[error("Expression {0:?} can't be introduced - it's already in scope")]
+ ExpressionAlreadyInScope(Handle<crate::Expression>),
+ #[error("Local variable {handle:?} '{name}' is invalid")]
+ LocalVariable {
+ handle: Handle<crate::LocalVariable>,
+ name: String,
+ source: LocalVariableError,
+ },
+ #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")]
+ InvalidArgumentType { index: usize, name: String },
+ #[error("The function's given return type cannot be returned from functions")]
+ NonConstructibleReturnType,
+ #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")]
+ InvalidArgumentPointerSpace {
+ index: usize,
+ name: String,
+ space: crate::AddressSpace,
+ },
+ #[error("There are instructions after `return`/`break`/`continue`")]
+ InstructionsAfterReturn,
+ #[error("The `break` is used outside of a `loop` or `switch` context")]
+ BreakOutsideOfLoopOrSwitch,
+ #[error("The `continue` is used outside of a `loop` context")]
+ ContinueOutsideOfLoop,
+ #[error("The `return` is called within a `continuing` block")]
+ InvalidReturnSpot,
+ #[error("The `return` value {0:?} does not match the function return value")]
+ InvalidReturnType(Option<Handle<crate::Expression>>),
+ #[error("The `if` condition {0:?} is not a boolean scalar")]
+ InvalidIfType(Handle<crate::Expression>),
+ #[error("The `switch` value {0:?} is not an integer scalar")]
+ InvalidSwitchType(Handle<crate::Expression>),
+ #[error("Multiple `switch` cases for {0:?} are present")]
+ ConflictingSwitchCase(crate::SwitchValue),
+ #[error("The `switch` contains cases with conflicting types")]
+ ConflictingCaseType,
+ #[error("The `switch` is missing a `default` case")]
+ MissingDefaultCase,
+ #[error("Multiple `default` cases are present")]
+ MultipleDefaultCases,
+ #[error("The last `switch` case contains a `falltrough`")]
+ LastCaseFallTrough,
+ #[error("The pointer {0:?} doesn't relate to a valid destination for a store")]
+ InvalidStorePointer(Handle<crate::Expression>),
+ #[error("The value {0:?} can not be stored")]
+ InvalidStoreValue(Handle<crate::Expression>),
+ #[error("Store of {value:?} into {pointer:?} doesn't have matching types")]
+ InvalidStoreTypes {
+ pointer: Handle<crate::Expression>,
+ value: Handle<crate::Expression>,
+ },
+ #[error("Image store parameters are invalid")]
+ InvalidImageStore(#[source] ExpressionError),
+ #[error("Call to {function:?} is invalid")]
+ InvalidCall {
+ function: Handle<crate::Function>,
+ #[source]
+ error: CallError,
+ },
+ #[error("Atomic operation is invalid")]
+ InvalidAtomic(#[from] AtomicError),
+ #[error("Ray Query {0:?} is not a local variable")]
+ InvalidRayQueryExpression(Handle<crate::Expression>),
+ #[error("Acceleration structure {0:?} is not a matching expression")]
+ InvalidAccelerationStructure(Handle<crate::Expression>),
+ #[error("Ray descriptor {0:?} is not a matching expression")]
+ InvalidRayDescriptor(Handle<crate::Expression>),
+ #[error("Ray Query {0:?} does not have a matching type")]
+ InvalidRayQueryType(Handle<crate::Type>),
+ #[error(
+ "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
+ )]
+ NonUniformControlFlow(
+ UniformityRequirements,
+ Handle<crate::Expression>,
+ UniformityDisruptor,
+ ),
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their arguments: \"{name}\" has attributes")]
+ PipelineInputRegularFunction { name: String },
+ #[error("Functions that are not entry points cannot have `@location` or `@builtin` attributes on their return value types")]
+ PipelineOutputRegularFunction,
+}
+
+bitflags::bitflags! {
+ #[repr(transparent)]
+ struct ControlFlowAbility: u8 {
+ /// The control can return out of this block.
+ const RETURN = 0x1;
+ /// The control can break.
+ const BREAK = 0x2;
+ /// The control can continue.
+ const CONTINUE = 0x4;
+ }
+}
+
+#[cfg(feature = "validate")]
+struct BlockInfo {
+ stages: super::ShaderStages,
+ finished: bool,
+}
+
+#[cfg(feature = "validate")]
+struct BlockContext<'a> {
+ abilities: ControlFlowAbility,
+ info: &'a FunctionInfo,
+ expressions: &'a Arena<crate::Expression>,
+ types: &'a UniqueArena<crate::Type>,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ global_vars: &'a Arena<crate::GlobalVariable>,
+ functions: &'a Arena<crate::Function>,
+ special_types: &'a crate::SpecialTypes,
+ prev_infos: &'a [FunctionInfo],
+ return_type: Option<Handle<crate::Type>>,
+}
+
+#[cfg(feature = "validate")]
+impl<'a> BlockContext<'a> {
+ fn new(
+ fun: &'a crate::Function,
+ module: &'a crate::Module,
+ info: &'a FunctionInfo,
+ prev_infos: &'a [FunctionInfo],
+ ) -> Self {
+ Self {
+ abilities: ControlFlowAbility::RETURN,
+ info,
+ expressions: &fun.expressions,
+ types: &module.types,
+ local_vars: &fun.local_variables,
+ global_vars: &module.global_variables,
+ functions: &module.functions,
+ special_types: &module.special_types,
+ prev_infos,
+ return_type: fun.result.as_ref().map(|fr| fr.ty),
+ }
+ }
+
+ const fn with_abilities(&self, abilities: ControlFlowAbility) -> Self {
+ BlockContext { abilities, ..*self }
+ }
+
+ fn get_expression(&self, handle: Handle<crate::Expression>) -> &'a crate::Expression {
+ &self.expressions[handle]
+ }
+
+ fn resolve_type_impl(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<ExpressionError>> {
+ if handle.index() >= self.expressions.len() {
+ Err(ExpressionError::DoesntExist.with_span())
+ } else if !valid_expressions.contains(handle.index()) {
+ Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions))
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+
+ fn resolve_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ valid_expressions: &BitSet,
+ ) -> Result<&crate::TypeInner, WithSpan<FunctionError>> {
+ self.resolve_type_impl(handle, valid_expressions)
+ .map_err_inner(|source| FunctionError::Expression { handle, source }.with_span())
+ }
+
+ fn resolve_pointer_type(
+ &self,
+ handle: Handle<crate::Expression>,
+ ) -> Result<&crate::TypeInner, FunctionError> {
+ if handle.index() >= self.expressions.len() {
+ Err(FunctionError::Expression {
+ handle,
+ source: ExpressionError::DoesntExist,
+ })
+ } else {
+ Ok(self.info[handle].ty.inner_with(self.types))
+ }
+ }
+}
+
+impl super::Validator {
+ #[cfg(feature = "validate")]
+ fn validate_call(
+ &mut self,
+ function: Handle<crate::Function>,
+ arguments: &[Handle<crate::Expression>],
+ result: Option<Handle<crate::Expression>>,
+ context: &BlockContext,
+ ) -> Result<super::ShaderStages, WithSpan<CallError>> {
+ let fun = &context.functions[function];
+ if fun.arguments.len() != arguments.len() {
+ return Err(CallError::ArgumentCount {
+ required: fun.arguments.len(),
+ seen: arguments.len(),
+ }
+ .with_span());
+ }
+ for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
+ let ty = context
+ .resolve_type_impl(expr, &self.valid_expression_set)
+ .map_err_inner(|source| {
+ CallError::Argument { index, source }
+ .with_span_handle(expr, context.expressions)
+ })?;
+ let arg_inner = &context.types[arg.ty].inner;
+ if !ty.equivalent(arg_inner, context.types) {
+ return Err(CallError::ArgumentType {
+ index,
+ required: arg.ty,
+ seen_expression: expr,
+ }
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+
+ if let Some(expr) = result {
+ if self.valid_expression_set.insert(expr.index()) {
+ self.valid_expression_list.push(expr);
+ } else {
+ return Err(CallError::ResultAlreadyInScope(expr)
+ .with_span_handle(expr, context.expressions));
+ }
+ match context.expressions[expr] {
+ crate::Expression::CallResult(callee)
+ if fun.result.is_some() && callee == function => {}
+ _ => {
+ return Err(CallError::ExpressionMismatch(result)
+ .with_span_handle(expr, context.expressions))
+ }
+ }
+ } else if fun.result.is_some() {
+ return Err(CallError::ExpressionMismatch(result).with_span());
+ }
+
+ let callee_info = &context.prev_infos[function.index()];
+ Ok(callee_info.available_stages)
+ }
+
+ #[cfg(feature = "validate")]
+ fn emit_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ if self.valid_expression_set.insert(handle.index()) {
+ self.valid_expression_list.push(handle);
+ Ok(())
+ } else {
+ Err(FunctionError::ExpressionAlreadyInScope(handle)
+ .with_span_handle(handle, context.expressions))
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_atomic(
+ &mut self,
+ pointer: Handle<crate::Expression>,
+ fun: &crate::AtomicFunction,
+ value: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?;
+ let (ptr_kind, ptr_width) = match *pointer_inner {
+ crate::TypeInner::Pointer { base, .. } => match context.types[base].inner {
+ crate::TypeInner::Atomic { kind, width } => (kind, width),
+ ref other => {
+ log::error!("Atomic pointer to type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ },
+ ref other => {
+ log::error!("Atomic on type {:?}", other);
+ return Err(AtomicError::InvalidPointer(pointer)
+ .with_span_handle(pointer, context.expressions)
+ .into_other());
+ }
+ };
+
+ let value_inner = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_inner {
+ crate::TypeInner::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {}
+ ref other => {
+ log::error!("Atomic operand type {:?}", other);
+ return Err(AtomicError::InvalidOperand(value)
+ .with_span_handle(value, context.expressions)
+ .into_other());
+ }
+ }
+
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner {
+ log::error!("Atomic exchange comparison has a different type from the value");
+ return Err(AtomicError::InvalidOperand(cmp)
+ .with_span_handle(cmp, context.expressions)
+ .into_other());
+ }
+ }
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::AtomicResult { ty, comparison }
+ if {
+ let scalar_predicate = |ty: &crate::TypeInner| {
+ *ty == crate::TypeInner::Scalar {
+ kind: ptr_kind,
+ width: ptr_width,
+ }
+ };
+ match &context.types[ty].inner {
+ ty if !comparison => scalar_predicate(ty),
+ &crate::TypeInner::Struct { ref members, .. } if comparison => {
+ validate_atomic_compare_exchange_struct(
+ context.types,
+ members,
+ scalar_predicate,
+ )
+ }
+ _ => false,
+ }
+ } => {}
+ _ => {
+ return Err(AtomicError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_block_impl(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ use crate::{Statement as S, TypeInner as Ti};
+ let mut finished = false;
+ let mut stages = super::ShaderStages::all();
+ for (statement, &span) in statements.span_iter() {
+ if finished {
+ return Err(FunctionError::InstructionsAfterReturn
+ .with_span_static(span, "instructions after return"));
+ }
+ match *statement {
+ S::Emit(ref range) => {
+ for handle in range.clone() {
+ self.emit_expression(handle, context)?;
+ }
+ }
+ S::Block(ref block) => {
+ let info = self.validate_block(block, context)?;
+ stages &= info.stages;
+ finished = info.finished;
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ stages &= self.validate_block(accept, context)?.stages;
+ stages &= self.validate_block(reject, context)?.stages;
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let uint = match context
+ .resolve_type(selector, &self.valid_expression_set)?
+ .scalar_kind()
+ {
+ Some(crate::ScalarKind::Uint) => true,
+ Some(crate::ScalarKind::Sint) => false,
+ _ => {
+ return Err(FunctionError::InvalidSwitchType(selector)
+ .with_span_handle(selector, context.expressions))
+ }
+ };
+ self.switch_values.clear();
+ for case in cases {
+ match case.value {
+ crate::SwitchValue::I32(_) if !uint => {}
+ crate::SwitchValue::U32(_) if uint => {}
+ crate::SwitchValue::Default => {}
+ _ => {
+ return Err(FunctionError::ConflictingCaseType.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ));
+ }
+ };
+ if !self.switch_values.insert(case.value) {
+ return Err(match case.value {
+ crate::SwitchValue::Default => FunctionError::MultipleDefaultCases
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "duplicated switch arm here",
+ ),
+ _ => FunctionError::ConflictingSwitchCase(case.value)
+ .with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "conflicting switch arm here",
+ ),
+ });
+ }
+ }
+ if !self.switch_values.contains(&crate::SwitchValue::Default) {
+ return Err(FunctionError::MissingDefaultCase
+ .with_span_static(span, "missing default case"));
+ }
+ if let Some(case) = cases.last() {
+ if case.fall_through {
+ return Err(FunctionError::LastCaseFallTrough.with_span_static(
+ case.body
+ .span_iter()
+ .next()
+ .map_or(Default::default(), |(_, s)| *s),
+ "bad switch arm here",
+ ));
+ }
+ }
+ let pass_through_abilities = context.abilities
+ & (ControlFlowAbility::RETURN | ControlFlowAbility::CONTINUE);
+ let sub_context =
+ context.with_abilities(pass_through_abilities | ControlFlowAbility::BREAK);
+ for case in cases {
+ stages &= self.validate_block(&case.body, &sub_context)?.stages;
+ }
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ // special handling for block scoping is needed here,
+ // because the continuing{} block inherits the scope
+ let base_expression_count = self.valid_expression_list.len();
+ let pass_through_abilities = context.abilities & ControlFlowAbility::RETURN;
+ stages &= self
+ .validate_block_impl(
+ body,
+ &context.with_abilities(
+ pass_through_abilities
+ | ControlFlowAbility::BREAK
+ | ControlFlowAbility::CONTINUE,
+ ),
+ )?
+ .stages;
+ stages &= self
+ .validate_block_impl(
+ continuing,
+ &context.with_abilities(ControlFlowAbility::empty()),
+ )?
+ .stages;
+
+ if let Some(condition) = break_if {
+ match *context.resolve_type(condition, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidIfType(condition)
+ .with_span_handle(condition, context.expressions))
+ }
+ }
+ }
+
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ }
+ S::Break => {
+ if !context.abilities.contains(ControlFlowAbility::BREAK) {
+ return Err(FunctionError::BreakOutsideOfLoopOrSwitch
+ .with_span_static(span, "invalid break"));
+ }
+ finished = true;
+ }
+ S::Continue => {
+ if !context.abilities.contains(ControlFlowAbility::CONTINUE) {
+ return Err(FunctionError::ContinueOutsideOfLoop
+ .with_span_static(span, "invalid continue"));
+ }
+ finished = true;
+ }
+ S::Return { value } => {
+ if !context.abilities.contains(ControlFlowAbility::RETURN) {
+ return Err(FunctionError::InvalidReturnSpot
+ .with_span_static(span, "invalid return"));
+ }
+ let value_ty = value
+ .map(|expr| context.resolve_type(expr, &self.valid_expression_set))
+ .transpose()?;
+ let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
+ // We can't return pointers, but it seems best not to embed that
+ // assumption here, so use `TypeInner::equivalent` for comparison.
+ let okay = match (value_ty, expected_ty) {
+ (None, None) => true,
+ (Some(value_inner), Some(expected_inner)) => {
+ value_inner.equivalent(expected_inner, context.types)
+ }
+ (_, _) => false,
+ };
+
+ if !okay {
+ log::error!(
+ "Returning {:?} where {:?} is expected",
+ value_ty,
+ expected_ty
+ );
+ if let Some(handle) = value {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_handle(handle, context.expressions));
+ } else {
+ return Err(FunctionError::InvalidReturnType(value)
+ .with_span_static(span, "invalid return"));
+ }
+ }
+ finished = true;
+ }
+ S::Kill => {
+ stages &= super::ShaderStages::FRAGMENT;
+ finished = true;
+ }
+ S::Barrier(_) => {
+ stages &= super::ShaderStages::COMPUTE;
+ }
+ S::Store { pointer, value } => {
+ let mut current = pointer;
+ loop {
+ let _ = context
+ .resolve_pointer_type(current)
+ .map_err(|e| e.with_span())?;
+ match context.expressions[current] {
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => current = base,
+ crate::Expression::LocalVariable(_)
+ | crate::Expression::GlobalVariable(_)
+ | crate::Expression::FunctionArgument(_) => break,
+ _ => {
+ return Err(FunctionError::InvalidStorePointer(current)
+ .with_span_handle(pointer, context.expressions))
+ }
+ }
+ }
+
+ let value_ty = context.resolve_type(value, &self.valid_expression_set)?;
+ match *value_ty {
+ Ti::Image { .. } | Ti::Sampler { .. } => {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ _ => {}
+ }
+
+ let pointer_ty = context
+ .resolve_pointer_type(pointer)
+ .map_err(|e| e.with_span())?;
+
+ let good = match *pointer_ty {
+ Ti::Pointer { base, space: _ } => match context.types[base].inner {
+ Ti::Atomic { kind, width } => *value_ty == Ti::Scalar { kind, width },
+ ref other => value_ty == other,
+ },
+ Ti::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space: _,
+ } => *value_ty == Ti::Vector { size, kind, width },
+ Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space: _,
+ } => *value_ty == Ti::Scalar { kind, width },
+ _ => false,
+ };
+ if !good {
+ return Err(FunctionError::InvalidStoreTypes { pointer, value }
+ .with_span()
+ .with_handle(pointer, context.expressions)
+ .with_handle(value, context.expressions));
+ }
+
+ if let Some(space) = pointer_ty.pointer_space() {
+ if !space.access().contains(crate::StorageAccess::STORE) {
+ return Err(FunctionError::InvalidStorePointer(pointer)
+ .with_span_static(
+ context.expressions.get_span(pointer),
+ "writing to this location is not permitted",
+ ));
+ }
+ }
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ //Note: this code uses a lot of `FunctionError::InvalidImageStore`,
+ // and could probably be refactored.
+ let var = match *context.get_expression(image) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ // We're looking at a binding index situation, so punch through the index and look at the global behind it.
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => {
+ match *context.get_expression(base) {
+ crate::Expression::GlobalVariable(var_handle) => {
+ &context.global_vars[var_handle]
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedGlobalVariable,
+ )
+ .with_span_handle(image, context.expressions))
+ }
+ };
+
+ // Punch through a binding array to get the underlying type
+ let global_ty = match context.types[var.ty].inner {
+ Ti::BindingArray { base, .. } => &context.types[base].inner,
+ ref inner => inner,
+ };
+
+ let value_ty = match *global_ty {
+ Ti::Image {
+ class,
+ arrayed,
+ dim,
+ } => {
+ match context
+ .resolve_type(coordinate, &self.valid_expression_set)?
+ .image_storage_coordinates()
+ {
+ Some(coord_dim) if coord_dim == dim => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageCoordinateType(
+ dim, coordinate,
+ ),
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ };
+ if arrayed != array_index.is_some() {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndex,
+ )
+ .with_span_handle(coordinate, context.expressions));
+ }
+ if let Some(expr) = array_index {
+ match *context.resolve_type(expr, &self.valid_expression_set)? {
+ Ti::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ width: _,
+ } => {}
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageArrayIndexType(expr),
+ )
+ .with_span_handle(expr, context.expressions));
+ }
+ }
+ }
+ match class {
+ crate::ImageClass::Storage { format, .. } => {
+ crate::TypeInner::Vector {
+ kind: format.into(),
+ size: crate::VectorSize::Quad,
+ width: 4,
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::InvalidImageClass(class),
+ )
+ .with_span_handle(image, context.expressions));
+ }
+ }
+ }
+ _ => {
+ return Err(FunctionError::InvalidImageStore(
+ ExpressionError::ExpectedImageType(var.ty),
+ )
+ .with_span()
+ .with_handle(var.ty, context.types)
+ .with_handle(image, context.expressions))
+ }
+ };
+
+ if *context.resolve_type(value, &self.valid_expression_set)? != value_ty {
+ return Err(FunctionError::InvalidStoreValue(value)
+ .with_span_handle(value, context.expressions));
+ }
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result,
+ } => match self.validate_call(function, arguments, result, context) {
+ Ok(callee_stages) => stages &= callee_stages,
+ Err(error) => {
+ return Err(error.and_then(|error| {
+ FunctionError::InvalidCall { function, error }
+ .with_span_static(span, "invalid function call")
+ }))
+ }
+ },
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result,
+ } => {
+ self.validate_atomic(pointer, fun, value, result, context)?;
+ }
+ S::RayQuery { query, ref fun } => {
+ let query_var = match *context.get_expression(query) {
+ crate::Expression::LocalVariable(var) => &context.local_vars[var],
+ ref other => {
+ log::error!("Unexpected ray query expression {other:?}");
+ return Err(FunctionError::InvalidRayQueryExpression(query)
+ .with_span_static(span, "invalid query expression"));
+ }
+ };
+ match context.types[query_var.ty].inner {
+ Ti::RayQuery => {}
+ ref other => {
+ log::error!("Unexpected ray query type {other:?}");
+ return Err(FunctionError::InvalidRayQueryType(query_var.ty)
+ .with_span_static(span, "invalid query type"));
+ }
+ }
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ match *context
+ .resolve_type(acceleration_structure, &self.valid_expression_set)?
+ {
+ Ti::AccelerationStructure => {}
+ _ => {
+ return Err(FunctionError::InvalidAccelerationStructure(
+ acceleration_structure,
+ )
+ .with_span_static(span, "invalid acceleration structure"))
+ }
+ }
+ let desc_ty_given =
+ context.resolve_type(descriptor, &self.valid_expression_set)?;
+ let desc_ty_expected = context
+ .special_types
+ .ray_desc
+ .map(|handle| &context.types[handle].inner);
+ if Some(desc_ty_given) != desc_ty_expected {
+ return Err(FunctionError::InvalidRayDescriptor(descriptor)
+ .with_span_static(span, "invalid ray descriptor"));
+ }
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ self.emit_expression(result, context)?;
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+ }
+ }
+ Ok(BlockInfo { stages, finished })
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_block(
+ &mut self,
+ statements: &crate::Block,
+ context: &BlockContext,
+ ) -> Result<BlockInfo, WithSpan<FunctionError>> {
+ let base_expression_count = self.valid_expression_list.len();
+ let info = self.validate_block_impl(statements, context)?;
+ for handle in self.valid_expression_list.drain(base_expression_count..) {
+ self.valid_expression_set.remove(handle.index());
+ }
+ Ok(info)
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_local_var(
+ &self,
+ var: &crate::LocalVariable,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), LocalVariableError> {
+ log::debug!("var {:?}", var);
+ let type_info = self
+ .types
+ .get(var.ty.index())
+ .ok_or(LocalVariableError::InvalidType(var.ty))?;
+ if !type_info
+ .flags
+ .contains(super::TypeFlags::DATA | super::TypeFlags::SIZED)
+ {
+ return Err(LocalVariableError::InvalidType(var.ty));
+ }
+
+ if let Some(const_handle) = var.init {
+ match constants[const_handle].inner {
+ crate::ConstantInner::Scalar { width, ref value } => {
+ let ty_inner = crate::TypeInner::Scalar {
+ width,
+ kind: value.scalar_kind(),
+ };
+ if types[var.ty].inner != ty_inner {
+ return Err(LocalVariableError::InitializerType);
+ }
+ }
+ crate::ConstantInner::Composite { ty, components: _ } => {
+ if ty != var.ty {
+ return Err(LocalVariableError::InitializerType);
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
+ pub(super) fn validate_function(
+ &mut self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ #[cfg_attr(not(feature = "validate"), allow(unused))] entry_point: bool,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ #[cfg_attr(not(feature = "validate"), allow(unused_mut))]
+ let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
+
+ #[cfg(feature = "validate")]
+ for (var_handle, var) in fun.local_variables.iter() {
+ self.validate_local_var(var, &module.types, &module.constants)
+ .map_err(|source| {
+ FunctionError::LocalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var.ty, &module.types)
+ .with_handle(var_handle, &fun.local_variables)
+ })?;
+ }
+
+ #[cfg(feature = "validate")]
+ for (index, argument) in fun.arguments.iter().enumerate() {
+ match module.types[argument.ty].inner.pointer_space() {
+ Some(
+ crate::AddressSpace::Private
+ | crate::AddressSpace::Function
+ | crate::AddressSpace::WorkGroup,
+ )
+ | None => {}
+ Some(other) => {
+ return Err(FunctionError::InvalidArgumentPointerSpace {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ space: other,
+ }
+ .with_span_handle(argument.ty, &module.types))
+ }
+ }
+ // Check for the least informative error last.
+ if !self.types[argument.ty.index()]
+ .flags
+ .contains(super::TypeFlags::ARGUMENT)
+ {
+ return Err(FunctionError::InvalidArgumentType {
+ index,
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+
+ if !entry_point && argument.binding.is_some() {
+ return Err(FunctionError::PipelineInputRegularFunction {
+ name: argument.name.clone().unwrap_or_default(),
+ }
+ .with_span_handle(argument.ty, &module.types));
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if let Some(ref result) = fun.result {
+ if !self.types[result.ty.index()]
+ .flags
+ .contains(super::TypeFlags::CONSTRUCTIBLE)
+ {
+ return Err(FunctionError::NonConstructibleReturnType
+ .with_span_handle(result.ty, &module.types));
+ }
+
+ if !entry_point && result.binding.is_some() {
+ return Err(FunctionError::PipelineOutputRegularFunction
+ .with_span_handle(result.ty, &module.types));
+ }
+ }
+
+ self.valid_expression_set.clear();
+ self.valid_expression_list.clear();
+ for (handle, expr) in fun.expressions.iter() {
+ if expr.needs_pre_emit() {
+ self.valid_expression_set.insert(handle.index());
+ }
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
+ match self.validate_expression(
+ handle,
+ expr,
+ fun,
+ module,
+ &info,
+ &mod_info.functions,
+ ) {
+ Ok(stages) => info.available_stages &= stages,
+ Err(source) => {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions))
+ }
+ }
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BLOCKS) {
+ let stages = self
+ .validate_block(
+ &fun.body,
+ &BlockContext::new(fun, module, &info, &mod_info.functions),
+ )?
+ .stages;
+ info.available_stages &= stages;
+ }
+ Ok(info)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs
new file mode 100644
index 0000000000..fdd43cd585
--- /dev/null
+++ b/third_party/rust/naga/src/valid/handles.rs
@@ -0,0 +1,652 @@
+//! Implementation of `Validator::validate_module_handles`.
+
+use crate::{
+ arena::{BadHandle, BadRangeError},
+ Handle,
+};
+
+#[cfg(feature = "validate")]
+use crate::{Arena, UniqueArena};
+
+#[cfg(feature = "validate")]
+use super::{TypeError, ValidationError};
+
+#[cfg(feature = "validate")]
+use std::{convert::TryInto, hash::Hash, num::NonZeroU32};
+
+#[cfg(feature = "validate")]
+impl super::Validator {
+ /// Validates that all handles within `module` are:
+ ///
+ /// * Valid, in the sense that they contain indices within each arena structure inside the
+ /// [`crate::Module`] type.
+ /// * No arena contents contain any items that have forward dependencies; that is, the value
+ /// associated with a handle only may contain references to handles in the same arena that
+ /// were constructed before it.
+ ///
+ /// By validating the above conditions, we free up subsequent logic to assume that handle
+ /// accesses are infallible.
+ ///
+ /// # Errors
+ ///
+ /// Errors returned by this method are intentionally sparse, for simplicity of implementation.
+ /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this
+ /// validation pass.
+ pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
+ let &crate::Module {
+ ref constants,
+ ref entry_points,
+ ref functions,
+ ref global_variables,
+ ref types,
+ ref special_types,
+ } = module;
+
+ // NOTE: Types being first is important. All other forms of validation depend on this.
+ for (this_handle, ty) in types.iter() {
+ let &crate::Type {
+ ref name,
+ ref inner,
+ } = ty;
+
+ let validate_array_size = |size| {
+ match size {
+ crate::ArraySize::Constant(constant) => {
+ let &crate::Constant {
+ name: _,
+ specialization: _,
+ ref inner,
+ } = constants.try_get(constant)?;
+ if !matches!(inner, &crate::ConstantInner::Scalar { .. }) {
+ return Err(ValidationError::Type {
+ handle: this_handle,
+ name: name.clone().unwrap_or_default(),
+ source: TypeError::InvalidArraySizeConstant(constant),
+ });
+ }
+ }
+ crate::ArraySize::Dynamic => (),
+ };
+ Ok(this_handle)
+ };
+
+ match *inner {
+ crate::TypeInner::Scalar { .. }
+ | crate::TypeInner::Vector { .. }
+ | crate::TypeInner::Matrix { .. }
+ | crate::TypeInner::ValuePointer { .. }
+ | crate::TypeInner::Atomic { .. }
+ | crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => (),
+ crate::TypeInner::Pointer { base, space: _ } => {
+ this_handle.check_dep(base)?;
+ }
+ crate::TypeInner::Array {
+ base,
+ size,
+ stride: _,
+ }
+ | crate::TypeInner::BindingArray { base, size } => {
+ this_handle.check_dep(base)?;
+ validate_array_size(size)?;
+ }
+ crate::TypeInner::Struct {
+ ref members,
+ span: _,
+ } => {
+ this_handle.check_dep_iter(members.iter().map(|m| m.ty))?;
+ }
+ }
+ }
+
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+
+ for (this_handle, constant) in constants.iter() {
+ let &crate::Constant {
+ name: _,
+ specialization: _,
+ ref inner,
+ } = constant;
+ match *inner {
+ crate::ConstantInner::Scalar { .. } => (),
+ crate::ConstantInner::Composite { ty, ref components } => {
+ validate_type(ty)?;
+ this_handle.check_dep_iter(components.iter().copied())?;
+ }
+ }
+ }
+
+ let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+
+ for (_handle, global_variable) in global_variables.iter() {
+ let &crate::GlobalVariable {
+ name: _,
+ space: _,
+ binding: _,
+ ty,
+ init,
+ } = global_variable;
+ validate_type(ty)?;
+ if let Some(init_expr) = init {
+ validate_constant(init_expr)?;
+ }
+ }
+
+ let validate_function = |function_handle, function: &_| -> Result<_, InvalidHandleError> {
+ let &crate::Function {
+ name: _,
+ ref arguments,
+ ref result,
+ ref local_variables,
+ ref expressions,
+ ref named_expressions,
+ ref body,
+ } = function;
+
+ for arg in arguments.iter() {
+ let &crate::FunctionArgument {
+ name: _,
+ ty,
+ binding: _,
+ } = arg;
+ validate_type(ty)?;
+ }
+
+ if let &Some(crate::FunctionResult { ty, binding: _ }) = result {
+ validate_type(ty)?;
+ }
+
+ for (_handle, local_variable) in local_variables.iter() {
+ let &crate::LocalVariable { name: _, ty, init } = local_variable;
+ validate_type(ty)?;
+ if let Some(init_constant) = init {
+ validate_constant(init_constant)?;
+ }
+ }
+
+ for handle in named_expressions.keys().copied() {
+ Self::validate_expression_handle(handle, expressions)?;
+ }
+
+ for handle_and_expr in expressions.iter() {
+ Self::validate_expression_handles(
+ handle_and_expr,
+ constants,
+ types,
+ local_variables,
+ global_variables,
+ functions,
+ function_handle,
+ )?;
+ }
+
+ Self::validate_block_handles(body, expressions, functions)?;
+
+ Ok(())
+ };
+
+ for entry_point in entry_points.iter() {
+ validate_function(None, &entry_point.function)?;
+ }
+
+ for (function_handle, function) in functions.iter() {
+ validate_function(Some(function_handle), function)?;
+ }
+
+ if let Some(ty) = special_types.ray_desc {
+ validate_type(ty)?;
+ }
+ if let Some(ty) = special_types.ray_intersection {
+ validate_type(ty)?;
+ }
+
+ Ok(())
+ }
+
+ fn validate_type_handle(
+ handle: Handle<crate::Type>,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for_uniq(types).map(|_| ())
+ }
+
+ fn validate_constant_handle(
+ handle: Handle<crate::Constant>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(constants).map(|_| ())
+ }
+
+ fn validate_expression_handle(
+ handle: Handle<crate::Expression>,
+ expressions: &Arena<crate::Expression>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(expressions).map(|_| ())
+ }
+
+ fn validate_function_handle(
+ handle: Handle<crate::Function>,
+ functions: &Arena<crate::Function>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(functions).map(|_| ())
+ }
+
+ fn validate_expression_handles(
+ (handle, expression): (Handle<crate::Expression>, &crate::Expression),
+ constants: &Arena<crate::Constant>,
+ types: &UniqueArena<crate::Type>,
+ local_variables: &Arena<crate::LocalVariable>,
+ global_variables: &Arena<crate::GlobalVariable>,
+ functions: &Arena<crate::Function>,
+ // The handle of the current function or `None` if it's an entry point
+ current_function: Option<Handle<crate::Function>>,
+ ) -> Result<(), InvalidHandleError> {
+ let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_type = |handle| Self::validate_type_handle(handle, types);
+
+ match *expression {
+ crate::Expression::Access { base, index } => {
+ handle.check_dep(base)?.check_dep(index)?;
+ }
+ crate::Expression::AccessIndex { base, .. } => {
+ handle.check_dep(base)?;
+ }
+ crate::Expression::Constant(constant) => {
+ validate_constant(constant)?;
+ }
+ crate::Expression::Splat { value, .. } => {
+ handle.check_dep(value)?;
+ }
+ crate::Expression::Swizzle { vector, .. } => {
+ handle.check_dep(vector)?;
+ }
+ crate::Expression::Compose { ty, ref components } => {
+ validate_type(ty)?;
+ handle.check_dep_iter(components.iter().copied())?;
+ }
+ crate::Expression::FunctionArgument(_arg_idx) => (),
+ crate::Expression::GlobalVariable(global_variable) => {
+ global_variable.check_valid_for(global_variables)?;
+ }
+ crate::Expression::LocalVariable(local_variable) => {
+ local_variable.check_valid_for(local_variables)?;
+ }
+ crate::Expression::Load { pointer } => {
+ handle.check_dep(pointer)?;
+ }
+ crate::Expression::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset,
+ level,
+ depth_ref,
+ } => {
+ if let Some(offset) = offset {
+ validate_constant(offset)?;
+ }
+
+ handle
+ .check_dep(image)?
+ .check_dep(sampler)?
+ .check_dep(coordinate)?
+ .check_dep_opt(array_index)?;
+
+ match level {
+ crate::SampleLevel::Auto | crate::SampleLevel::Zero => (),
+ crate::SampleLevel::Exact(expr) => {
+ handle.check_dep(expr)?;
+ }
+ crate::SampleLevel::Bias(expr) => {
+ handle.check_dep(expr)?;
+ }
+ crate::SampleLevel::Gradient { x, y } => {
+ handle.check_dep(x)?.check_dep(y)?;
+ }
+ };
+
+ handle.check_dep_opt(depth_ref)?;
+ }
+ crate::Expression::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ handle
+ .check_dep(image)?
+ .check_dep(coordinate)?
+ .check_dep_opt(array_index)?
+ .check_dep_opt(sample)?
+ .check_dep_opt(level)?;
+ }
+ crate::Expression::ImageQuery { image, query } => {
+ handle.check_dep(image)?;
+ match query {
+ crate::ImageQuery::Size { level } => {
+ handle.check_dep_opt(level)?;
+ }
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => (),
+ };
+ }
+ crate::Expression::Unary {
+ op: _,
+ expr: operand,
+ } => {
+ handle.check_dep(operand)?;
+ }
+ crate::Expression::Binary { op: _, left, right } => {
+ handle.check_dep(left)?.check_dep(right)?;
+ }
+ crate::Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => {
+ handle
+ .check_dep(condition)?
+ .check_dep(accept)?
+ .check_dep(reject)?;
+ }
+ crate::Expression::Derivative { expr: argument, .. } => {
+ handle.check_dep(argument)?;
+ }
+ crate::Expression::Relational { fun: _, argument } => {
+ handle.check_dep(argument)?;
+ }
+ crate::Expression::Math {
+ fun: _,
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ } => {
+ handle
+ .check_dep(arg)?
+ .check_dep_opt(arg1)?
+ .check_dep_opt(arg2)?
+ .check_dep_opt(arg3)?;
+ }
+ crate::Expression::As {
+ expr: input,
+ kind: _,
+ convert: _,
+ } => {
+ handle.check_dep(input)?;
+ }
+ crate::Expression::CallResult(function) => {
+ Self::validate_function_handle(function, functions)?;
+ if let Some(handle) = current_function {
+ handle.check_dep(function)?;
+ }
+ }
+ crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult => (),
+ crate::Expression::ArrayLength(array) => {
+ handle.check_dep(array)?;
+ }
+ crate::Expression::RayQueryGetIntersection {
+ query,
+ committed: _,
+ } => {
+ handle.check_dep(query)?;
+ }
+ }
+ Ok(())
+ }
+
+ fn validate_block_handles(
+ block: &crate::Block,
+ expressions: &Arena<crate::Expression>,
+ functions: &Arena<crate::Function>,
+ ) -> Result<(), InvalidHandleError> {
+ let validate_block = |block| Self::validate_block_handles(block, expressions, functions);
+ let validate_expr = |handle| Self::validate_expression_handle(handle, expressions);
+ let validate_expr_opt = |handle_opt| {
+ if let Some(handle) = handle_opt {
+ validate_expr(handle)?;
+ }
+ Ok(())
+ };
+
+ block.iter().try_for_each(|stmt| match *stmt {
+ crate::Statement::Emit(ref expr_range) => {
+ expr_range.check_valid_for(expressions)?;
+ Ok(())
+ }
+ crate::Statement::Block(ref block) => {
+ validate_block(block)?;
+ Ok(())
+ }
+ crate::Statement::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ validate_expr(condition)?;
+ validate_block(accept)?;
+ validate_block(reject)?;
+ Ok(())
+ }
+ crate::Statement::Switch {
+ selector,
+ ref cases,
+ } => {
+ validate_expr(selector)?;
+ for &crate::SwitchCase {
+ value: _,
+ ref body,
+ fall_through: _,
+ } in cases
+ {
+ validate_block(body)?;
+ }
+ Ok(())
+ }
+ crate::Statement::Loop {
+ ref body,
+ ref continuing,
+ break_if,
+ } => {
+ validate_block(body)?;
+ validate_block(continuing)?;
+ validate_expr_opt(break_if)?;
+ Ok(())
+ }
+ crate::Statement::Return { value } => validate_expr_opt(value),
+ crate::Statement::Store { pointer, value } => {
+ validate_expr(pointer)?;
+ validate_expr(value)?;
+ Ok(())
+ }
+ crate::Statement::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ validate_expr(image)?;
+ validate_expr(coordinate)?;
+ validate_expr_opt(array_index)?;
+ validate_expr(value)?;
+ Ok(())
+ }
+ crate::Statement::Atomic {
+ pointer,
+ fun,
+ value,
+ result,
+ } => {
+ validate_expr(pointer)?;
+ match fun {
+ crate::AtomicFunction::Add
+ | crate::AtomicFunction::Subtract
+ | crate::AtomicFunction::And
+ | crate::AtomicFunction::ExclusiveOr
+ | crate::AtomicFunction::InclusiveOr
+ | crate::AtomicFunction::Min
+ | crate::AtomicFunction::Max => (),
+ crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?,
+ };
+ validate_expr(value)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::Call {
+ function,
+ ref arguments,
+ result,
+ } => {
+ Self::validate_function_handle(function, functions)?;
+ for arg in arguments.iter().copied() {
+ validate_expr(arg)?;
+ }
+ validate_expr_opt(result)?;
+ Ok(())
+ }
+ crate::Statement::RayQuery { query, ref fun } => {
+ validate_expr(query)?;
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ acceleration_structure,
+ descriptor,
+ } => {
+ validate_expr(acceleration_structure)?;
+ validate_expr(descriptor)?;
+ }
+ crate::RayQueryFunction::Proceed { result } => {
+ validate_expr(result)?;
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ Ok(())
+ }
+ crate::Statement::Break
+ | crate::Statement::Continue
+ | crate::Statement::Kill
+ | crate::Statement::Barrier(_) => Ok(()),
+ })
+ }
+}
+
+#[cfg(feature = "validate")]
+impl From<BadHandle> for ValidationError {
+ fn from(source: BadHandle) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+#[cfg(feature = "validate")]
+impl From<FwdDepError> for ValidationError {
+ fn from(source: FwdDepError) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+#[cfg(feature = "validate")]
+impl From<BadRangeError> for ValidationError {
+ fn from(source: BadRangeError) -> Self {
+ Self::InvalidHandle(source.into())
+ }
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum InvalidHandleError {
+ #[error(transparent)]
+ BadHandle(#[from] BadHandle),
+ #[error(transparent)]
+ ForwardDependency(#[from] FwdDepError),
+ #[error(transparent)]
+ BadRange(#[from] BadRangeError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[error(
+ "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
+ which has not been processed yet"
+)]
+pub struct FwdDepError {
+ // This error is used for many `Handle` types, but there's no point in making this generic, so
+ // we just flatten them all to `Handle<()>` here.
+ subject: Handle<()>,
+ subject_kind: &'static str,
+ depends_on: Handle<()>,
+ depends_on_kind: &'static str,
+}
+
+#[cfg(feature = "validate")]
+impl<T> Handle<T> {
+ /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`].
+ pub(self) fn check_valid_for(self, arena: &Arena<T>) -> Result<(), InvalidHandleError> {
+ arena.check_contains_handle(self)?;
+ Ok(())
+ }
+
+ /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`].
+ pub(self) fn check_valid_for_uniq(
+ self,
+ arena: &UniqueArena<T>,
+ ) -> Result<(), InvalidHandleError>
+ where
+ T: Eq + Hash,
+ {
+ arena.check_contains_handle(self)?;
+ Ok(())
+ }
+
+ /// Check that `depends_on` was constructed before `self` by comparing handle indices.
+ ///
+ /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`])
+ /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid.
+ /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating
+ /// recursive definitions of arena-based values in linear time.
+ ///
+ /// # Errors
+ ///
+ /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier
+ /// than `self`'s, this function returns an error.
+ pub(self) fn check_dep(self, depends_on: Self) -> Result<Self, FwdDepError> {
+ if depends_on < self {
+ Ok(self)
+ } else {
+ let erase_handle_type = |handle: Handle<_>| {
+ Handle::new(NonZeroU32::new((handle.index() + 1).try_into().unwrap()).unwrap())
+ };
+ Err(FwdDepError {
+ subject: erase_handle_type(self),
+ subject_kind: std::any::type_name::<T>(),
+ depends_on: erase_handle_type(depends_on),
+ depends_on_kind: std::any::type_name::<T>(),
+ })
+ }
+ }
+
+ /// Like [`Self::check_dep`], except for [`Option`]al handle values.
+ pub(self) fn check_dep_opt(self, depends_on: Option<Self>) -> Result<Self, FwdDepError> {
+ self.check_dep_iter(depends_on.into_iter())
+ }
+
+ /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values.
+ pub(self) fn check_dep_iter(
+ self,
+ depends_on: impl Iterator<Item = Self>,
+ ) -> Result<Self, FwdDepError> {
+ for handle in depends_on {
+ self.check_dep(handle)?;
+ }
+ Ok(self)
+ }
+}
+
+#[cfg(feature = "validate")]
+impl<T> crate::arena::Range<T> {
+ pub(self) fn check_valid_for(&self, arena: &Arena<T>) -> Result<(), BadRangeError> {
+ arena.check_contains_range(self)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs
new file mode 100644
index 0000000000..268490c52e
--- /dev/null
+++ b/third_party/rust/naga/src/valid/interface.rs
@@ -0,0 +1,676 @@
+use super::{
+ analyzer::{FunctionInfo, GlobalUse},
+ Capabilities, Disalignment, FunctionError, ModuleInfo,
+};
+use crate::arena::{Handle, UniqueArena};
+
+use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan};
+use bit_set::BitSet;
+
+#[cfg(feature = "validate")]
+const MAX_WORKGROUP_SIZE: u32 = 0x4000;
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum GlobalVariableError {
+ #[error("Usage isn't compatible with address space {0:?}")]
+ InvalidUsage(crate::AddressSpace),
+ #[error("Type isn't compatible with address space {0:?}")]
+ InvalidType(crate::AddressSpace),
+ #[error("Type flags {seen:?} do not meet the required {required:?}")]
+ MissingTypeFlags {
+ required: super::TypeFlags,
+ seen: super::TypeFlags,
+ },
+ #[error("Capability {0:?} is not supported")]
+ UnsupportedCapability(Capabilities),
+ #[error("Binding decoration is missing or not applicable")]
+ InvalidBinding,
+ #[error("Alignment requirements for address space {0:?} are not met by {1:?}")]
+ Alignment(
+ crate::AddressSpace,
+ Handle<crate::Type>,
+ #[source] Disalignment,
+ ),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum VaryingError {
+ #[error("The type {0:?} does not match the varying")]
+ InvalidType(Handle<crate::Type>),
+ #[error("The type {0:?} cannot be used for user-defined entry point inputs or outputs")]
+ NotIOShareableType(Handle<crate::Type>),
+ #[error("Interpolation is not valid")]
+ InvalidInterpolation,
+ #[error("Interpolation must be specified on vertex shader outputs and fragment shader inputs")]
+ MissingInterpolation,
+ #[error("Built-in {0:?} is not available at this stage")]
+ InvalidBuiltInStage(crate::BuiltIn),
+ #[error("Built-in type for {0:?} is invalid")]
+ InvalidBuiltInType(crate::BuiltIn),
+ #[error("Entry point arguments and return values must all have bindings")]
+ MissingBinding,
+ #[error("Struct member {0} is missing a binding")]
+ MemberMissingBinding(u32),
+ #[error("Multiple bindings at location {location} are present")]
+ BindingCollision { location: u32 },
+ #[error("Built-in {0:?} is present more than once")]
+ DuplicateBuiltIn(crate::BuiltIn),
+ #[error("Capability {0:?} is not supported")]
+ UnsupportedCapability(Capabilities),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum EntryPointError {
+ #[error("Multiple conflicting entry points")]
+ Conflict,
+ #[error("Vertex shaders must return a `@builtin(position)` output value")]
+ MissingVertexOutputPosition,
+ #[error("Early depth test is not applicable")]
+ UnexpectedEarlyDepthTest,
+ #[error("Workgroup size is not applicable")]
+ UnexpectedWorkgroupSize,
+ #[error("Workgroup size is out of range")]
+ OutOfRangeWorkgroupSize,
+ #[error("Uses operations forbidden at this stage")]
+ ForbiddenStageOperations,
+ #[error("Global variable {0:?} is used incorrectly as {1:?}")]
+ InvalidGlobalUsage(Handle<crate::GlobalVariable>, GlobalUse),
+ #[error("Bindings for {0:?} conflict with other resource")]
+ BindingCollision(Handle<crate::GlobalVariable>),
+ #[error("Argument {0} varying error")]
+ Argument(u32, #[source] VaryingError),
+ #[error(transparent)]
+ Result(#[from] VaryingError),
+ #[error("Location {location} interpolation of an integer has to be flat")]
+ InvalidIntegerInterpolation { location: u32 },
+ #[error(transparent)]
+ Function(#[from] FunctionError),
+}
+
+#[cfg(feature = "validate")]
+fn storage_usage(access: crate::StorageAccess) -> GlobalUse {
+ let mut storage_usage = GlobalUse::QUERY;
+ if access.contains(crate::StorageAccess::LOAD) {
+ storage_usage |= GlobalUse::READ;
+ }
+ if access.contains(crate::StorageAccess::STORE) {
+ storage_usage |= GlobalUse::WRITE;
+ }
+ storage_usage
+}
+
+struct VaryingContext<'a> {
+ stage: crate::ShaderStage,
+ output: bool,
+ types: &'a UniqueArena<crate::Type>,
+ type_info: &'a Vec<super::r#type::TypeInfo>,
+ location_mask: &'a mut BitSet,
+ built_ins: &'a mut crate::FastHashSet<crate::BuiltIn>,
+ capabilities: Capabilities,
+
+ #[cfg(feature = "validate")]
+ flags: super::ValidationFlags,
+}
+
+impl VaryingContext<'_> {
+ fn validate_impl(
+ &mut self,
+ ty: Handle<crate::Type>,
+ binding: &crate::Binding,
+ ) -> Result<(), VaryingError> {
+ use crate::{
+ BuiltIn as Bi, ScalarKind as Sk, ShaderStage as St, TypeInner as Ti, VectorSize as Vs,
+ };
+
+ let ty_inner = &self.types[ty].inner;
+ match *binding {
+ crate::Binding::BuiltIn(built_in) => {
+ // Ignore the `invariant` field for the sake of duplicate checks,
+ // but use the original in error messages.
+ let canonical = if let crate::BuiltIn::Position { .. } = built_in {
+ crate::BuiltIn::Position { invariant: false }
+ } else {
+ built_in
+ };
+
+ if self.built_ins.contains(&canonical) {
+ return Err(VaryingError::DuplicateBuiltIn(built_in));
+ }
+ self.built_ins.insert(canonical);
+
+ let required = match built_in {
+ Bi::ClipDistance => Capabilities::CLIP_DISTANCE,
+ Bi::CullDistance => Capabilities::CULL_DISTANCE,
+ Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
+ Bi::ViewIndex => Capabilities::MULTIVIEW,
+ Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
+ _ => Capabilities::empty(),
+ };
+ if !self.capabilities.contains(required) {
+ return Err(VaryingError::UnsupportedCapability(required));
+ }
+
+ let width = 4;
+ let (visible, type_good) = match built_in {
+ Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
+ self.stage == St::Vertex && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::ClipDistance | Bi::CullDistance => (
+ self.stage == St::Vertex && self.output,
+ match *ty_inner {
+ Ti::Array { base, .. } => {
+ self.types[base].inner
+ == Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ }
+ }
+ _ => false,
+ },
+ ),
+ Bi::PointSize => (
+ self.stage == St::Vertex && self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::PointCoord => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Bi,
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::Position { .. } => (
+ match self.stage {
+ St::Vertex => self.output,
+ St::Fragment => !self.output,
+ St::Compute => false,
+ },
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Quad,
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::ViewIndex => (
+ match self.stage {
+ St::Vertex | St::Fragment => !self.output,
+ St::Compute => false,
+ },
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Sint,
+ width,
+ },
+ ),
+ Bi::FragDepth => (
+ self.stage == St::Fragment && self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Float,
+ width,
+ },
+ ),
+ Bi::FrontFacing => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Bool,
+ width: crate::BOOL_WIDTH,
+ },
+ ),
+ Bi::PrimitiveIndex => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::SampleIndex => (
+ self.stage == St::Fragment && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::SampleMask => (
+ self.stage == St::Fragment,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::LocalInvocationIndex => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner
+ == Ti::Scalar {
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ Bi::GlobalInvocationId
+ | Bi::LocalInvocationId
+ | Bi::WorkGroupId
+ | Bi::WorkGroupSize
+ | Bi::NumWorkGroups => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner
+ == Ti::Vector {
+ size: Vs::Tri,
+ kind: Sk::Uint,
+ width,
+ },
+ ),
+ };
+
+ if !visible {
+ return Err(VaryingError::InvalidBuiltInStage(built_in));
+ }
+ if !type_good {
+ log::warn!("Wrong builtin type: {:?}", ty_inner);
+ return Err(VaryingError::InvalidBuiltInType(built_in));
+ }
+ }
+ crate::Binding::Location {
+ location,
+ interpolation,
+ sampling,
+ } => {
+ // Only IO-shareable types may be stored in locations.
+ if !self.type_info[ty.index()]
+ .flags
+ .contains(super::TypeFlags::IO_SHAREABLE)
+ {
+ return Err(VaryingError::NotIOShareableType(ty));
+ }
+ if !self.location_mask.insert(location as usize) {
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::BindingCollision { location });
+ }
+ }
+
+ let needs_interpolation = match self.stage {
+ crate::ShaderStage::Vertex => self.output,
+ crate::ShaderStage::Fragment => !self.output,
+ crate::ShaderStage::Compute => false,
+ };
+
+ // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but
+ // SPIR-V and GLSL both explicitly tolerate such combinations of decorators /
+ // qualifiers, so we won't complain about that here.
+ let _ = sampling;
+
+ let required = match sampling {
+ Some(crate::Sampling::Sample) => Capabilities::MULTISAMPLED_SHADING,
+ _ => Capabilities::empty(),
+ };
+ if !self.capabilities.contains(required) {
+ return Err(VaryingError::UnsupportedCapability(required));
+ }
+
+ match ty_inner.scalar_kind() {
+ Some(crate::ScalarKind::Float) => {
+ if needs_interpolation && interpolation.is_none() {
+ return Err(VaryingError::MissingInterpolation);
+ }
+ }
+ Some(_) => {
+ if needs_interpolation && interpolation != Some(crate::Interpolation::Flat)
+ {
+ return Err(VaryingError::InvalidInterpolation);
+ }
+ }
+ None => return Err(VaryingError::InvalidType(ty)),
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ fn validate(
+ &mut self,
+ ty: Handle<crate::Type>,
+ binding: Option<&crate::Binding>,
+ ) -> Result<(), WithSpan<VaryingError>> {
+ let span_context = self.types.get_span_context(ty);
+ match binding {
+ Some(binding) => self
+ .validate_impl(ty, binding)
+ .map_err(|e| e.with_span_context(span_context)),
+ None => {
+ match self.types[ty].inner {
+ crate::TypeInner::Struct { ref members, .. } => {
+ for (index, member) in members.iter().enumerate() {
+ let span_context = self.types.get_span_context(ty);
+ match member.binding {
+ None => {
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::MemberMissingBinding(
+ index as u32,
+ )
+ .with_span_context(span_context));
+ }
+ #[cfg(not(feature = "validate"))]
+ let _ = index;
+ }
+ Some(ref binding) => self
+ .validate_impl(member.ty, binding)
+ .map_err(|e| e.with_span_context(span_context))?,
+ }
+ }
+ }
+ _ =>
+ {
+ #[cfg(feature = "validate")]
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(VaryingError::MissingBinding.with_span());
+ }
+ }
+ }
+ Ok(())
+ }
+ }
+ }
+}
+
+impl super::Validator {
+ #[cfg(feature = "validate")]
+ pub(super) fn validate_global_var(
+ &self,
+ var: &crate::GlobalVariable,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), GlobalVariableError> {
+ use super::TypeFlags;
+
+ log::debug!("var {:?}", var);
+ let type_info = &self.types[var.ty.index()];
+
+ let (required_type_flags, is_resource) = match var.space {
+ crate::AddressSpace::Function => {
+ return Err(GlobalVariableError::InvalidUsage(var.space))
+ }
+ crate::AddressSpace::Storage { .. } => {
+ if let Err((ty_handle, disalignment)) = type_info.storage_layout {
+ if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
+ return Err(GlobalVariableError::Alignment(
+ var.space,
+ ty_handle,
+ disalignment,
+ ));
+ }
+ }
+ (TypeFlags::DATA | TypeFlags::HOST_SHAREABLE, true)
+ }
+ crate::AddressSpace::Uniform => {
+ if let Err((ty_handle, disalignment)) = type_info.uniform_layout {
+ if self.flags.contains(super::ValidationFlags::STRUCT_LAYOUTS) {
+ return Err(GlobalVariableError::Alignment(
+ var.space,
+ ty_handle,
+ disalignment,
+ ));
+ }
+ }
+ (
+ TypeFlags::DATA
+ | TypeFlags::COPY
+ | TypeFlags::SIZED
+ | TypeFlags::HOST_SHAREABLE,
+ true,
+ )
+ }
+ crate::AddressSpace::Handle => {
+ match types[var.ty].inner {
+ crate::TypeInner::Image { .. }
+ | crate::TypeInner::Sampler { .. }
+ | crate::TypeInner::BindingArray { .. }
+ | crate::TypeInner::AccelerationStructure
+ | crate::TypeInner::RayQuery => {}
+ _ => {
+ return Err(GlobalVariableError::InvalidType(var.space));
+ }
+ };
+ let inner_ty = match &types[var.ty].inner {
+ &crate::TypeInner::BindingArray { base, .. } => &types[base].inner,
+ ty => ty,
+ };
+ if let crate::TypeInner::Image {
+ class:
+ crate::ImageClass::Storage {
+ format:
+ crate::StorageFormat::R16Unorm
+ | crate::StorageFormat::R16Snorm
+ | crate::StorageFormat::Rg16Unorm
+ | crate::StorageFormat::Rg16Snorm
+ | crate::StorageFormat::Rgba16Unorm
+ | crate::StorageFormat::Rgba16Snorm,
+ ..
+ },
+ ..
+ } = *inner_ty
+ {
+ if !self
+ .capabilities
+ .contains(Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS)
+ {
+ return Err(GlobalVariableError::UnsupportedCapability(
+ Capabilities::STORAGE_TEXTURE_16BIT_NORM_FORMATS,
+ ));
+ }
+ }
+
+ (TypeFlags::empty(), true)
+ }
+ crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {
+ (TypeFlags::DATA | TypeFlags::SIZED, false)
+ }
+ crate::AddressSpace::PushConstant => {
+ if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) {
+ return Err(GlobalVariableError::UnsupportedCapability(
+ Capabilities::PUSH_CONSTANT,
+ ));
+ }
+ (
+ TypeFlags::DATA
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::SIZED,
+ false,
+ )
+ }
+ };
+
+ if !type_info.flags.contains(required_type_flags) {
+ return Err(GlobalVariableError::MissingTypeFlags {
+ seen: type_info.flags,
+ required: required_type_flags,
+ });
+ }
+
+ if is_resource != var.binding.is_some() {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(GlobalVariableError::InvalidBinding);
+ }
+ }
+
+ Ok(())
+ }
+
+ pub(super) fn validate_entry_point(
+ &mut self,
+ ep: &crate::EntryPoint,
+ module: &crate::Module,
+ mod_info: &ModuleInfo,
+ ) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
+ #[cfg(feature = "validate")]
+ if ep.early_depth_test.is_some() {
+ let required = Capabilities::EARLY_DEPTH_TEST;
+ if !self.capabilities.contains(required) {
+ return Err(
+ EntryPointError::Result(VaryingError::UnsupportedCapability(required))
+ .with_span(),
+ );
+ }
+
+ if ep.stage != crate::ShaderStage::Fragment {
+ return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span());
+ }
+ }
+
+ #[cfg(feature = "validate")]
+ if ep.stage == crate::ShaderStage::Compute {
+ if ep
+ .workgroup_size
+ .iter()
+ .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE)
+ {
+ return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span());
+ }
+ } else if ep.workgroup_size != [0; 3] {
+ return Err(EntryPointError::UnexpectedWorkgroupSize.with_span());
+ }
+
+ let info = self
+ .validate_function(&ep.function, module, mod_info, true)
+ .map_err(WithSpan::into_other)?;
+
+ #[cfg(feature = "validate")]
+ {
+ use super::ShaderStages;
+
+ let stage_bit = match ep.stage {
+ crate::ShaderStage::Vertex => ShaderStages::VERTEX,
+ crate::ShaderStage::Fragment => ShaderStages::FRAGMENT,
+ crate::ShaderStage::Compute => ShaderStages::COMPUTE,
+ };
+
+ if !info.available_stages.contains(stage_bit) {
+ return Err(EntryPointError::ForbiddenStageOperations.with_span());
+ }
+ }
+
+ self.location_mask.clear();
+ let mut argument_built_ins = crate::FastHashSet::default();
+ // TODO: add span info to function arguments
+ for (index, fa) in ep.function.arguments.iter().enumerate() {
+ let mut ctx = VaryingContext {
+ stage: ep.stage,
+ output: false,
+ types: &module.types,
+ type_info: &self.types,
+ location_mask: &mut self.location_mask,
+ built_ins: &mut argument_built_ins,
+ capabilities: self.capabilities,
+
+ #[cfg(feature = "validate")]
+ flags: self.flags,
+ };
+ ctx.validate(fa.ty, fa.binding.as_ref())
+ .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
+ }
+
+ self.location_mask.clear();
+ if let Some(ref fr) = ep.function.result {
+ let mut result_built_ins = crate::FastHashSet::default();
+ let mut ctx = VaryingContext {
+ stage: ep.stage,
+ output: true,
+ types: &module.types,
+ type_info: &self.types,
+ location_mask: &mut self.location_mask,
+ built_ins: &mut result_built_ins,
+ capabilities: self.capabilities,
+
+ #[cfg(feature = "validate")]
+ flags: self.flags,
+ };
+ ctx.validate(fr.ty, fr.binding.as_ref())
+ .map_err_inner(|e| EntryPointError::Result(e).with_span())?;
+
+ #[cfg(feature = "validate")]
+ if ep.stage == crate::ShaderStage::Vertex
+ && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false })
+ {
+ return Err(EntryPointError::MissingVertexOutputPosition.with_span());
+ }
+ } else if ep.stage == crate::ShaderStage::Vertex {
+ #[cfg(feature = "validate")]
+ return Err(EntryPointError::MissingVertexOutputPosition.with_span());
+ }
+
+ for bg in self.bind_group_masks.iter_mut() {
+ bg.clear();
+ }
+
+ #[cfg(feature = "validate")]
+ for (var_handle, var) in module.global_variables.iter() {
+ let usage = info[var_handle];
+ if usage.is_empty() {
+ continue;
+ }
+
+ let allowed_usage = match var.space {
+ crate::AddressSpace::Function => unreachable!(),
+ crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY,
+ crate::AddressSpace::Storage { access } => storage_usage(access),
+ crate::AddressSpace::Handle => match module.types[var.ty].inner {
+ crate::TypeInner::BindingArray { base, .. } => match module.types[base].inner {
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => storage_usage(access),
+ _ => GlobalUse::READ | GlobalUse::QUERY,
+ },
+ crate::TypeInner::Image {
+ class: crate::ImageClass::Storage { access, .. },
+ ..
+ } => storage_usage(access),
+ _ => GlobalUse::READ | GlobalUse::QUERY,
+ },
+ crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => GlobalUse::all(),
+ crate::AddressSpace::PushConstant => GlobalUse::READ,
+ };
+ if !allowed_usage.contains(usage) {
+ log::warn!("\tUsage error for: {:?}", var);
+ log::warn!(
+ "\tAllowed usage: {:?}, requested: {:?}",
+ allowed_usage,
+ usage
+ );
+ return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)
+ .with_span_handle(var_handle, &module.global_variables));
+ }
+
+ if let Some(ref bind) = var.binding {
+ while self.bind_group_masks.len() <= bind.group as usize {
+ self.bind_group_masks.push(BitSet::new());
+ }
+ if !self.bind_group_masks[bind.group as usize].insert(bind.binding as usize) {
+ if self.flags.contains(super::ValidationFlags::BINDINGS) {
+ return Err(EntryPointError::BindingCollision(var_handle)
+ .with_span_handle(var_handle, &module.global_variables));
+ }
+ }
+ }
+ }
+
+ Ok(info)
+ }
+}
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs
new file mode 100644
index 0000000000..6b3a2e1456
--- /dev/null
+++ b/third_party/rust/naga/src/valid/mod.rs
@@ -0,0 +1,467 @@
+/*!
+Shader validator.
+*/
+
+mod analyzer;
+mod compose;
+mod expression;
+mod function;
+mod handles;
+mod interface;
+mod r#type;
+
+#[cfg(feature = "validate")]
+use crate::arena::{Arena, UniqueArena};
+
+use crate::{
+ arena::Handle,
+ proc::{LayoutError, Layouter},
+ FastHashSet,
+};
+use bit_set::BitSet;
+use std::ops;
+
+//TODO: analyze the model at the same time as we validate it,
+// merge the corresponding matches over expressions and statements.
+
+use crate::span::{AddSpan as _, WithSpan};
+pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
+pub use compose::ComposeError;
+pub use expression::ExpressionError;
+pub use function::{CallError, FunctionError, LocalVariableError};
+pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
+pub use r#type::{Disalignment, TypeError, TypeFlags};
+
+use self::handles::InvalidHandleError;
+
+bitflags::bitflags! {
+ /// Validation flags.
+ ///
+ /// If you are working with trusted shaders, then you may be able
+ /// to save some time by skipping validation.
+ ///
+ /// If you do not perform full validation, invalid shaders may
+ /// cause Naga to panic. If you do perform full validation and
+ /// [`Validator::validate`] returns `Ok`, then Naga promises that
+ /// code generation will either succeed or return an error; it
+ /// should never panic.
+ ///
+ /// The default value for `ValidationFlags` is
+ /// `ValidationFlags::all()`. If Naga's `"validate"` feature is
+ /// enabled, this requests full validation; otherwise, this
+ /// requests no validation. (The `"validate"` feature is disabled
+ /// by default.)
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct ValidationFlags: u8 {
+ /// Expressions.
+ #[cfg(feature = "validate")]
+ const EXPRESSIONS = 0x1;
+ /// Statements and blocks of them.
+ #[cfg(feature = "validate")]
+ const BLOCKS = 0x2;
+ /// Uniformity of control flow for operations that require it.
+ #[cfg(feature = "validate")]
+ const CONTROL_FLOW_UNIFORMITY = 0x4;
+ /// Host-shareable structure layouts.
+ #[cfg(feature = "validate")]
+ const STRUCT_LAYOUTS = 0x8;
+ /// Constants.
+ #[cfg(feature = "validate")]
+ const CONSTANTS = 0x10;
+ /// Group, binding, and location attributes.
+ #[cfg(feature = "validate")]
+ const BINDINGS = 0x20;
+ }
+}
+
+impl Default for ValidationFlags {
+ fn default() -> Self {
+ Self::all()
+ }
+}
+
+bitflags::bitflags! {
+ /// Allowed IR capabilities.
+ #[must_use]
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct Capabilities: u16 {
+ /// Support for [`AddressSpace:PushConstant`].
+ const PUSH_CONSTANT = 0x1;
+ /// Float values with width = 8.
+ const FLOAT64 = 0x2;
+ /// Support for [`Builtin:PrimitiveIndex`].
+ const PRIMITIVE_INDEX = 0x4;
+ /// Support for non-uniform indexing of sampled textures and storage buffer arrays.
+ const SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING = 0x8;
+ /// Support for non-uniform indexing of uniform buffers and storage texture arrays.
+ const UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING = 0x10;
+ /// Support for non-uniform indexing of samplers.
+ const SAMPLER_NON_UNIFORM_INDEXING = 0x20;
+ /// Support for [`Builtin::ClipDistance`].
+ const CLIP_DISTANCE = 0x40;
+ /// Support for [`Builtin::CullDistance`].
+ const CULL_DISTANCE = 0x80;
+ /// Support for 16-bit normalized storage texture formats.
+ const STORAGE_TEXTURE_16BIT_NORM_FORMATS = 0x100;
+ /// Support for [`BuiltIn::ViewIndex`].
+ const MULTIVIEW = 0x200;
+ /// Support for `early_depth_test`.
+ const EARLY_DEPTH_TEST = 0x400;
+ /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`].
+ const MULTISAMPLED_SHADING = 0x800;
+ /// Support for ray queries and acceleration structures.
+ const RAY_QUERY = 0x1000;
+ }
+}
+
+impl Default for Capabilities {
+ fn default() -> Self {
+ Self::MULTISAMPLED_SHADING
+ }
+}
+
+bitflags::bitflags! {
+ /// Validation flags.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct ShaderStages: u8 {
+ const VERTEX = 0x1;
+ const FRAGMENT = 0x2;
+ const COMPUTE = 0x4;
+ }
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct ModuleInfo {
+ type_flags: Vec<TypeFlags>,
+ functions: Vec<FunctionInfo>,
+ entry_points: Vec<FunctionInfo>,
+}
+
+impl ops::Index<Handle<crate::Type>> for ModuleInfo {
+ type Output = TypeFlags;
+ fn index(&self, handle: Handle<crate::Type>) -> &Self::Output {
+ &self.type_flags[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Function>> for ModuleInfo {
+ type Output = FunctionInfo;
+ fn index(&self, handle: Handle<crate::Function>) -> &Self::Output {
+ &self.functions[handle.index()]
+ }
+}
+
+#[derive(Debug)]
+pub struct Validator {
+ flags: ValidationFlags,
+ capabilities: Capabilities,
+ types: Vec<r#type::TypeInfo>,
+ layouter: Layouter,
+ location_mask: BitSet,
+ bind_group_masks: Vec<BitSet>,
+ #[allow(dead_code)]
+ switch_values: FastHashSet<crate::SwitchValue>,
+ valid_expression_list: Vec<Handle<crate::Expression>>,
+ valid_expression_set: BitSet,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ConstantError {
+ #[error("The type doesn't match the constant")]
+ InvalidType,
+ #[error("The component handle {0:?} can not be resolved")]
+ UnresolvedComponent(Handle<crate::Constant>),
+ #[error("The array size handle {0:?} can not be resolved")]
+ UnresolvedSize(Handle<crate::Constant>),
+ #[error(transparent)]
+ Compose(#[from] ComposeError),
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum ValidationError {
+ #[error(transparent)]
+ InvalidHandle(#[from] InvalidHandleError),
+ #[error(transparent)]
+ Layouter(#[from] LayoutError),
+ #[error("Type {handle:?} '{name}' is invalid")]
+ Type {
+ handle: Handle<crate::Type>,
+ name: String,
+ source: TypeError,
+ },
+ #[error("Constant {handle:?} '{name}' is invalid")]
+ Constant {
+ handle: Handle<crate::Constant>,
+ name: String,
+ source: ConstantError,
+ },
+ #[error("Global variable {handle:?} '{name}' is invalid")]
+ GlobalVariable {
+ handle: Handle<crate::GlobalVariable>,
+ name: String,
+ source: GlobalVariableError,
+ },
+ #[error("Function {handle:?} '{name}' is invalid")]
+ Function {
+ handle: Handle<crate::Function>,
+ name: String,
+ source: FunctionError,
+ },
+ #[error("Entry point {name} at {stage:?} is invalid")]
+ EntryPoint {
+ stage: crate::ShaderStage,
+ name: String,
+ source: EntryPointError,
+ },
+ #[error("Module is corrupted")]
+ Corrupted,
+}
+
+impl crate::TypeInner {
+ #[cfg(feature = "validate")]
+ const fn is_sized(&self) -> bool {
+ match *self {
+ Self::Scalar { .. }
+ | Self::Vector { .. }
+ | Self::Matrix { .. }
+ | Self::Array {
+ size: crate::ArraySize::Constant(_),
+ ..
+ }
+ | Self::Atomic { .. }
+ | Self::Pointer { .. }
+ | Self::ValuePointer { .. }
+ | Self::Struct { .. } => true,
+ Self::Array { .. }
+ | Self::Image { .. }
+ | Self::Sampler { .. }
+ | Self::AccelerationStructure
+ | Self::RayQuery
+ | Self::BindingArray { .. } => false,
+ }
+ }
+
+ /// Return the `ImageDimension` for which `self` is an appropriate coordinate.
+ #[cfg(feature = "validate")]
+ const fn image_storage_coordinates(&self) -> Option<crate::ImageDimension> {
+ match *self {
+ Self::Scalar {
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ } => Some(crate::ImageDimension::D1),
+ Self::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ } => Some(crate::ImageDimension::D2),
+ Self::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
+ ..
+ } => Some(crate::ImageDimension::D3),
+ _ => None,
+ }
+ }
+}
+
+impl Validator {
+ /// Construct a new validator instance.
+ pub fn new(flags: ValidationFlags, capabilities: Capabilities) -> Self {
+ Validator {
+ flags,
+ capabilities,
+ types: Vec::new(),
+ layouter: Layouter::default(),
+ location_mask: BitSet::new(),
+ bind_group_masks: Vec::new(),
+ switch_values: FastHashSet::default(),
+ valid_expression_list: Vec::new(),
+ valid_expression_set: BitSet::new(),
+ }
+ }
+
+ /// Reset the validator internals
+ pub fn reset(&mut self) {
+ self.types.clear();
+ self.layouter.clear();
+ self.location_mask.clear();
+ self.bind_group_masks.clear();
+ self.switch_values.clear();
+ self.valid_expression_list.clear();
+ self.valid_expression_set.clear();
+ }
+
+ #[cfg(feature = "validate")]
+ fn validate_constant(
+ &self,
+ handle: Handle<crate::Constant>,
+ constants: &Arena<crate::Constant>,
+ types: &UniqueArena<crate::Type>,
+ ) -> Result<(), ConstantError> {
+ let con = &constants[handle];
+ match con.inner {
+ crate::ConstantInner::Scalar { width, ref value } => {
+ if self.check_width(value.scalar_kind(), width).is_err() {
+ return Err(ConstantError::InvalidType);
+ }
+ }
+ crate::ConstantInner::Composite { ty, ref components } => {
+ match types[ty].inner {
+ crate::TypeInner::Array {
+ size: crate::ArraySize::Constant(size_handle),
+ ..
+ } if handle <= size_handle => {
+ return Err(ConstantError::UnresolvedSize(size_handle));
+ }
+ _ => {}
+ }
+ if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) {
+ return Err(ConstantError::UnresolvedComponent(comp));
+ }
+ compose::validate_compose(
+ ty,
+ constants,
+ types,
+ components
+ .iter()
+ .map(|&component| constants[component].inner.resolve_type()),
+ )?;
+ }
+ }
+ Ok(())
+ }
+
+ /// Check the given module to be valid.
+ pub fn validate(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.reset();
+ self.reset_types(module.types.len());
+
+ #[cfg(feature = "validate")]
+ Self::validate_module_handles(module).map_err(|e| e.with_span())?;
+
+ self.layouter
+ .update(&module.types, &module.constants)
+ .map_err(|e| {
+ let handle = e.ty;
+ ValidationError::from(e).with_span_handle(handle, &module.types)
+ })?;
+
+ #[cfg(feature = "validate")]
+ if self.flags.contains(ValidationFlags::CONSTANTS) {
+ for (handle, constant) in module.constants.iter() {
+ self.validate_constant(handle, &module.constants, &module.types)
+ .map_err(|source| {
+ ValidationError::Constant {
+ handle,
+ name: constant.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.constants)
+ })?
+ }
+ }
+
+ let mut mod_info = ModuleInfo {
+ type_flags: Vec::with_capacity(module.types.len()),
+ functions: Vec::with_capacity(module.functions.len()),
+ entry_points: Vec::with_capacity(module.entry_points.len()),
+ };
+
+ for (handle, ty) in module.types.iter() {
+ let ty_info = self
+ .validate_type(handle, &module.types, &module.constants)
+ .map_err(|source| {
+ ValidationError::Type {
+ handle,
+ name: ty.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.types)
+ })?;
+ mod_info.type_flags.push(ty_info.flags);
+ self.types[handle.index()] = ty_info;
+ }
+
+ #[cfg(feature = "validate")]
+ for (var_handle, var) in module.global_variables.iter() {
+ self.validate_global_var(var, &module.types)
+ .map_err(|source| {
+ ValidationError::GlobalVariable {
+ handle: var_handle,
+ name: var.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(var_handle, &module.global_variables)
+ })?;
+ }
+
+ for (handle, fun) in module.functions.iter() {
+ match self.validate_function(fun, module, &mod_info, false) {
+ Ok(info) => mod_info.functions.push(info),
+ Err(error) => {
+ return Err(error.and_then(|source| {
+ ValidationError::Function {
+ handle,
+ name: fun.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.functions)
+ }))
+ }
+ }
+ }
+
+ let mut ep_map = FastHashSet::default();
+ for ep in module.entry_points.iter() {
+ if !ep_map.insert((ep.stage, &ep.name)) {
+ return Err(ValidationError::EntryPoint {
+ stage: ep.stage,
+ name: ep.name.clone(),
+ source: EntryPointError::Conflict,
+ }
+ .with_span()); // TODO: keep some EP span information?
+ }
+
+ match self.validate_entry_point(ep, module, &mod_info) {
+ Ok(info) => mod_info.entry_points.push(info),
+ Err(error) => {
+ return Err(error.and_then(|source| {
+ ValidationError::EntryPoint {
+ stage: ep.stage,
+ name: ep.name.clone(),
+ source,
+ }
+ .with_span()
+ }));
+ }
+ }
+ }
+
+ Ok(mod_info)
+ }
+}
+
+#[cfg(feature = "validate")]
+fn validate_atomic_compare_exchange_struct(
+ types: &UniqueArena<crate::Type>,
+ members: &[crate::StructMember],
+ scalar_predicate: impl FnOnce(&crate::TypeInner) -> bool,
+) -> bool {
+ members.len() == 2
+ && members[0].name.as_deref() == Some("old_value")
+ && scalar_predicate(&types[members[0].ty].inner)
+ && members[1].name.as_deref() == Some("exchanged")
+ && types[members[1].ty].inner
+ == crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ }
+}
diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs
new file mode 100644
index 0000000000..d8dd37d09b
--- /dev/null
+++ b/third_party/rust/naga/src/valid/type.rs
@@ -0,0 +1,636 @@
+use super::Capabilities;
+use crate::{
+ arena::{Arena, Handle, UniqueArena},
+ proc::Alignment,
+};
+
+bitflags::bitflags! {
+ /// Flags associated with [`Type`]s by [`Validator`].
+ ///
+ /// [`Type`]: crate::Type
+ /// [`Validator`]: crate::valid::Validator
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[repr(transparent)]
+ pub struct TypeFlags: u8 {
+ /// Can be used for data variables.
+ ///
+ /// This flag is required on types of local variables, function
+ /// arguments, array elements, and struct members.
+ ///
+ /// This includes all types except `Image`, `Sampler`,
+ /// and some `Pointer` types.
+ const DATA = 0x1;
+
+ /// The data type has a size known by pipeline creation time.
+ ///
+ /// Unsized types are quite restricted. The only unsized types permitted
+ /// by Naga, other than the non-[`DATA`] types like [`Image`] and
+ /// [`Sampler`], are dynamically-sized [`Array`s], and [`Struct`s] whose
+ /// last members are such arrays. See the documentation for those types
+ /// for details.
+ ///
+ /// [`DATA`]: TypeFlags::DATA
+ /// [`Image`]: crate::Type::Image
+ /// [`Sampler`]: crate::Type::Sampler
+ /// [`Array`]: crate::Type::Array
+ /// [`Struct`]: crate::Type::struct
+ const SIZED = 0x2;
+
+ /// The data can be copied around.
+ const COPY = 0x4;
+
+ /// Can be be used for user-defined IO between pipeline stages.
+ ///
+ /// This covers anything that can be in [`Location`] binding:
+ /// non-bool scalars and vectors, matrices, and structs and
+ /// arrays containing only interface types.
+ const IO_SHAREABLE = 0x8;
+
+ /// Can be used for host-shareable structures.
+ const HOST_SHAREABLE = 0x10;
+
+ /// This type can be passed as a function argument.
+ const ARGUMENT = 0x40;
+
+ /// A WGSL [constructible] type.
+ ///
+ /// The constructible types are scalars, vectors, matrices, fixed-size
+ /// arrays of constructible types, and structs whose members are all
+ /// constructible.
+ ///
+ /// [constructible]: https://gpuweb.github.io/gpuweb/wgsl/#constructible
+ const CONSTRUCTIBLE = 0x80;
+ }
+}
+
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+pub enum Disalignment {
+ #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")]
+ ArrayStride { stride: u32, alignment: Alignment },
+ #[error("The struct span {span}, is not a multiple of the required alignment {alignment}")]
+ StructSpan { span: u32, alignment: Alignment },
+ #[error("The struct member[{index}] offset {offset} is not a multiple of the required alignment {alignment}")]
+ MemberOffset {
+ index: u32,
+ offset: u32,
+ alignment: Alignment,
+ },
+ #[error("The struct member[{index}] offset {offset} must be at least {expected}")]
+ MemberOffsetAfterStruct {
+ index: u32,
+ offset: u32,
+ expected: u32,
+ },
+ #[error("The struct member[{index}] is not statically sized")]
+ UnsizedMember { index: u32 },
+ #[error("The type is not host-shareable")]
+ NonHostShareable,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+pub enum TypeError {
+ #[error("Capability {0:?} is required")]
+ MissingCapability(Capabilities),
+ #[error("The {0:?} scalar width {1} is not supported")]
+ InvalidWidth(crate::ScalarKind, crate::Bytes),
+ #[error("The {0:?} scalar width {1} is not supported for an atomic")]
+ InvalidAtomicWidth(crate::ScalarKind, crate::Bytes),
+ #[error("The base handle {0:?} can not be resolved")]
+ UnresolvedBase(Handle<crate::Type>),
+ #[error("Invalid type for pointer target {0:?}")]
+ InvalidPointerBase(Handle<crate::Type>),
+ #[error("Unsized types like {base:?} must be in the `Storage` address space, not `{space:?}`")]
+ InvalidPointerToUnsized {
+ base: Handle<crate::Type>,
+ space: crate::AddressSpace,
+ },
+ #[error("Expected data type, found {0:?}")]
+ InvalidData(Handle<crate::Type>),
+ #[error("Base type {0:?} for the array is invalid")]
+ InvalidArrayBaseType(Handle<crate::Type>),
+ #[error("The constant {0:?} can not be used for an array size")]
+ InvalidArraySizeConstant(Handle<crate::Constant>),
+ #[error("The constant {0:?} is specialized, and cannot be used as an array size")]
+ UnsupportedSpecializedArrayLength(Handle<crate::Constant>),
+ #[error("Array type {0:?} must have a length of one or more")]
+ NonPositiveArrayLength(Handle<crate::Constant>),
+ #[error("Array stride {stride} does not match the expected {expected}")]
+ InvalidArrayStride { stride: u32, expected: u32 },
+ #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")]
+ InvalidDynamicArray(String, Handle<crate::Type>),
+ #[error("Structure member[{index}] at {offset} overlaps the previous member")]
+ MemberOverlap { index: u32, offset: u32 },
+ #[error(
+ "Structure member[{index}] at {offset} and size {size} crosses the structure boundary of size {span}"
+ )]
+ MemberOutOfBounds {
+ index: u32,
+ offset: u32,
+ size: u32,
+ span: u32,
+ },
+ #[error("Structure types must have at least one member")]
+ EmptyStruct,
+}
+
+// Only makes sense if `flags.contains(HOST_SHAREABLE)`
+type LayoutCompatibility = Result<Alignment, (Handle<crate::Type>, Disalignment)>;
+
+fn check_member_layout(
+ accum: &mut LayoutCompatibility,
+ member: &crate::StructMember,
+ member_index: u32,
+ member_layout: LayoutCompatibility,
+ parent_handle: Handle<crate::Type>,
+) {
+ *accum = match (*accum, member_layout) {
+ (Ok(cur_alignment), Ok(alignment)) => {
+ if alignment.is_aligned(member.offset) {
+ Ok(cur_alignment.max(alignment))
+ } else {
+ Err((
+ parent_handle,
+ Disalignment::MemberOffset {
+ index: member_index,
+ offset: member.offset,
+ alignment,
+ },
+ ))
+ }
+ }
+ (Err(e), _) | (_, Err(e)) => Err(e),
+ };
+}
+
+/// Determine whether a pointer in `space` can be passed as an argument.
+///
+/// If a pointer in `space` is permitted to be passed as an argument to a
+/// user-defined function, return `TypeFlags::ARGUMENT`. Otherwise, return
+/// `TypeFlags::empty()`.
+///
+/// Pointers passed as arguments to user-defined functions must be in the
+/// `Function`, `Private`, or `Workgroup` storage space.
+const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags {
+ use crate::AddressSpace as As;
+ match space {
+ As::Function | As::Private | As::WorkGroup => TypeFlags::ARGUMENT,
+ As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant => TypeFlags::empty(),
+ }
+}
+
+#[derive(Clone, Debug)]
+pub(super) struct TypeInfo {
+ pub flags: TypeFlags,
+ pub uniform_layout: LayoutCompatibility,
+ pub storage_layout: LayoutCompatibility,
+}
+
+impl TypeInfo {
+ const fn dummy() -> Self {
+ TypeInfo {
+ flags: TypeFlags::empty(),
+ uniform_layout: Ok(Alignment::ONE),
+ storage_layout: Ok(Alignment::ONE),
+ }
+ }
+
+ const fn new(flags: TypeFlags, alignment: Alignment) -> Self {
+ TypeInfo {
+ flags,
+ uniform_layout: Ok(alignment),
+ storage_layout: Ok(alignment),
+ }
+ }
+}
+
+impl super::Validator {
+ const fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> {
+ if self.capabilities.contains(capability) {
+ Ok(())
+ } else {
+ Err(TypeError::MissingCapability(capability))
+ }
+ }
+
+ pub(super) fn check_width(
+ &self,
+ kind: crate::ScalarKind,
+ width: crate::Bytes,
+ ) -> Result<(), TypeError> {
+ let good = match kind {
+ crate::ScalarKind::Bool => width == crate::BOOL_WIDTH,
+ crate::ScalarKind::Float => {
+ if width == 8 {
+ self.require_type_capability(Capabilities::FLOAT64)?;
+ true
+ } else {
+ width == 4
+ }
+ }
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4,
+ };
+ if good {
+ Ok(())
+ } else {
+ Err(TypeError::InvalidWidth(kind, width))
+ }
+ }
+
+ pub(super) fn reset_types(&mut self, size: usize) {
+ self.types.clear();
+ self.types.resize(size, TypeInfo::dummy());
+ self.layouter.clear();
+ }
+
+ pub(super) fn validate_type(
+ &self,
+ handle: Handle<crate::Type>,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<TypeInfo, TypeError> {
+ use crate::TypeInner as Ti;
+ Ok(match types[handle].inner {
+ Ti::Scalar { kind, width } => {
+ self.check_width(kind, width)?;
+ let shareable = if kind.is_numeric() {
+ TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE
+ } else {
+ TypeFlags::empty()
+ };
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ | shareable,
+ Alignment::from_width(width),
+ )
+ }
+ Ti::Vector { size, kind, width } => {
+ self.check_width(kind, width)?;
+ let shareable = if kind.is_numeric() {
+ TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE
+ } else {
+ TypeFlags::empty()
+ };
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ | shareable,
+ Alignment::from(size) * Alignment::from_width(width),
+ )
+ }
+ Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ } => {
+ self.check_width(crate::ScalarKind::Float, width)?;
+ TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE,
+ Alignment::from(rows) * Alignment::from_width(width),
+ )
+ }
+ Ti::Atomic { kind, width } => {
+ let good = match kind {
+ crate::ScalarKind::Bool | crate::ScalarKind::Float => false,
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4,
+ };
+ if !good {
+ return Err(TypeError::InvalidAtomicWidth(kind, width));
+ }
+ TypeInfo::new(
+ TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::HOST_SHAREABLE,
+ Alignment::from_width(width),
+ )
+ }
+ Ti::Pointer { base, space } => {
+ use crate::AddressSpace as As;
+
+ if base >= handle {
+ return Err(TypeError::UnresolvedBase(base));
+ }
+
+ let base_info = &self.types[base.index()];
+ if !base_info.flags.contains(TypeFlags::DATA) {
+ return Err(TypeError::InvalidPointerBase(base));
+ }
+
+ // Runtime-sized values can only live in the `Storage` storage
+ // space, so it's useless to have a pointer to such a type in
+ // any other space.
+ //
+ // Detecting this problem here prevents the definition of
+ // functions like:
+ //
+ // fn f(p: ptr<workgroup, UnsizedType>) -> ... { ... }
+ //
+ // which would otherwise be permitted, but uncallable. (They
+ // may also present difficulties in code generation).
+ if !base_info.flags.contains(TypeFlags::SIZED) {
+ match space {
+ As::Storage { .. } => {}
+ _ => {
+ return Err(TypeError::InvalidPointerToUnsized { base, space });
+ }
+ }
+ }
+
+ // `Validator::validate_function` actually checks the storage
+ // space of pointer arguments explicitly before checking the
+ // `ARGUMENT` flag, to give better error messages. But it seems
+ // best to set `ARGUMENT` accurately anyway.
+ let argument_flag = ptr_space_argument_flag(space);
+
+ // Pointers cannot be stored in variables, structure members, or
+ // array elements, so we do not mark them as `DATA`.
+ TypeInfo::new(
+ argument_flag | TypeFlags::SIZED | TypeFlags::COPY,
+ Alignment::ONE,
+ )
+ }
+ Ti::ValuePointer {
+ size: _,
+ kind,
+ width,
+ space,
+ } => {
+ // ValuePointer should be treated the same way as the equivalent
+ // Pointer / Scalar / Vector combination, so each step in those
+ // variants' match arms should have a counterpart here.
+ //
+ // However, some cases are trivial: All our implicit base types
+ // are DATA and SIZED, so we can never return
+ // `InvalidPointerBase` or `InvalidPointerToUnsized`.
+ self.check_width(kind, width)?;
+
+ // `Validator::validate_function` actually checks the storage
+ // space of pointer arguments explicitly before checking the
+ // `ARGUMENT` flag, to give better error messages. But it seems
+ // best to set `ARGUMENT` accurately anyway.
+ let argument_flag = ptr_space_argument_flag(space);
+
+ // Pointers cannot be stored in variables, structure members, or
+ // array elements, so we do not mark them as `DATA`.
+ TypeInfo::new(
+ argument_flag | TypeFlags::SIZED | TypeFlags::COPY,
+ Alignment::ONE,
+ )
+ }
+ Ti::Array { base, size, stride } => {
+ if base >= handle {
+ return Err(TypeError::UnresolvedBase(base));
+ }
+ let base_info = &self.types[base.index()];
+ if !base_info.flags.contains(TypeFlags::DATA | TypeFlags::SIZED) {
+ return Err(TypeError::InvalidArrayBaseType(base));
+ }
+
+ let base_layout = self.layouter[base];
+ let general_alignment = base_layout.alignment;
+ let uniform_layout = match base_info.uniform_layout {
+ Ok(base_alignment) => {
+ let alignment = base_alignment
+ .max(general_alignment)
+ .max(Alignment::MIN_UNIFORM);
+ if alignment.is_aligned(stride) {
+ Ok(alignment)
+ } else {
+ Err((handle, Disalignment::ArrayStride { stride, alignment }))
+ }
+ }
+ Err(e) => Err(e),
+ };
+ let storage_layout = match base_info.storage_layout {
+ Ok(base_alignment) => {
+ let alignment = base_alignment.max(general_alignment);
+ if alignment.is_aligned(stride) {
+ Ok(alignment)
+ } else {
+ Err((handle, Disalignment::ArrayStride { stride, alignment }))
+ }
+ }
+ Err(e) => Err(e),
+ };
+
+ let type_info_mask = match size {
+ crate::ArraySize::Constant(const_handle) => {
+ let constant = &constants[const_handle];
+ let length_is_positive = match *constant {
+ crate::Constant {
+ specialization: Some(_),
+ ..
+ } => {
+ // Many of our back ends don't seem to support
+ // specializable array lengths. If you want to try to make
+ // this work, be sure to address all uses of
+ // `Constant::to_array_length`, which ignores
+ // specialization.
+ return Err(TypeError::UnsupportedSpecializedArrayLength(
+ const_handle,
+ ));
+ }
+ crate::Constant {
+ inner:
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Uint(length),
+ },
+ ..
+ } => length > 0,
+ // Accept a signed integer size to avoid
+ // requiring an explicit uint
+ // literal. Type inference should make
+ // this unnecessary.
+ crate::Constant {
+ inner:
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(length),
+ },
+ ..
+ } => length > 0,
+ _ => {
+ log::warn!("Array size {:?}", constant);
+ return Err(TypeError::InvalidArraySizeConstant(const_handle));
+ }
+ };
+
+ if !length_is_positive {
+ return Err(TypeError::NonPositiveArrayLength(const_handle));
+ }
+
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE
+ }
+ crate::ArraySize::Dynamic => {
+ // Non-SIZED types may only appear as the last element of a structure.
+ // This is enforced by checks for SIZED-ness for all compound types,
+ // and a special case for structs.
+ TypeFlags::DATA | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE
+ }
+ };
+
+ TypeInfo {
+ flags: base_info.flags & type_info_mask,
+ uniform_layout,
+ storage_layout,
+ }
+ }
+ Ti::Struct { ref members, span } => {
+ if members.is_empty() {
+ return Err(TypeError::EmptyStruct);
+ }
+
+ let mut ti = TypeInfo::new(
+ TypeFlags::DATA
+ | TypeFlags::SIZED
+ | TypeFlags::COPY
+ | TypeFlags::HOST_SHAREABLE
+ | TypeFlags::IO_SHAREABLE
+ | TypeFlags::ARGUMENT
+ | TypeFlags::CONSTRUCTIBLE,
+ Alignment::ONE,
+ );
+ ti.uniform_layout = Ok(Alignment::MIN_UNIFORM);
+
+ let mut min_offset = 0;
+
+ let mut prev_struct_data: Option<(u32, u32)> = None;
+
+ for (i, member) in members.iter().enumerate() {
+ if member.ty >= handle {
+ return Err(TypeError::UnresolvedBase(member.ty));
+ }
+ let base_info = &self.types[member.ty.index()];
+ if !base_info.flags.contains(TypeFlags::DATA) {
+ return Err(TypeError::InvalidData(member.ty));
+ }
+ if !base_info.flags.contains(TypeFlags::HOST_SHAREABLE) {
+ if ti.uniform_layout.is_ok() {
+ ti.uniform_layout = Err((member.ty, Disalignment::NonHostShareable));
+ }
+ if ti.storage_layout.is_ok() {
+ ti.storage_layout = Err((member.ty, Disalignment::NonHostShareable));
+ }
+ }
+ ti.flags &= base_info.flags;
+
+ if member.offset < min_offset {
+ // HACK: this could be nicer. We want to allow some structures
+ // to not bother with offsets/alignments if they are never
+ // used for host sharing.
+ if member.offset == 0 {
+ ti.flags.set(TypeFlags::HOST_SHAREABLE, false);
+ } else {
+ return Err(TypeError::MemberOverlap {
+ index: i as u32,
+ offset: member.offset,
+ });
+ }
+ }
+
+ let base_size = types[member.ty].inner.size(constants);
+ min_offset = member.offset + base_size;
+ if min_offset > span {
+ return Err(TypeError::MemberOutOfBounds {
+ index: i as u32,
+ offset: member.offset,
+ size: base_size,
+ span,
+ });
+ }
+
+ check_member_layout(
+ &mut ti.uniform_layout,
+ member,
+ i as u32,
+ base_info.uniform_layout,
+ handle,
+ );
+ check_member_layout(
+ &mut ti.storage_layout,
+ member,
+ i as u32,
+ base_info.storage_layout,
+ handle,
+ );
+
+ // Validate rule: If a structure member itself has a structure type S,
+ // then the number of bytes between the start of that member and
+ // the start of any following member must be at least roundUp(16, SizeOf(S)).
+ if let Some((span, offset)) = prev_struct_data {
+ let diff = member.offset - offset;
+ let min = Alignment::MIN_UNIFORM.round_up(span);
+ if diff < min {
+ ti.uniform_layout = Err((
+ handle,
+ Disalignment::MemberOffsetAfterStruct {
+ index: i as u32,
+ offset: member.offset,
+ expected: offset + min,
+ },
+ ));
+ }
+ };
+
+ prev_struct_data = match types[member.ty].inner {
+ crate::TypeInner::Struct { span, .. } => Some((span, member.offset)),
+ _ => None,
+ };
+
+ // The last field may be an unsized array.
+ if !base_info.flags.contains(TypeFlags::SIZED) {
+ let is_array = match types[member.ty].inner {
+ crate::TypeInner::Array { .. } => true,
+ _ => false,
+ };
+ if !is_array || i + 1 != members.len() {
+ let name = member.name.clone().unwrap_or_default();
+ return Err(TypeError::InvalidDynamicArray(name, member.ty));
+ }
+ if ti.uniform_layout.is_ok() {
+ ti.uniform_layout =
+ Err((handle, Disalignment::UnsizedMember { index: i as u32 }));
+ }
+ }
+ }
+
+ let alignment = self.layouter[handle].alignment;
+ if !alignment.is_aligned(span) {
+ ti.uniform_layout = Err((handle, Disalignment::StructSpan { span, alignment }));
+ ti.storage_layout = Err((handle, Disalignment::StructSpan { span, alignment }));
+ }
+
+ ti
+ }
+ Ti::Image { .. } | Ti::Sampler { .. } => {
+ TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE)
+ }
+ Ti::AccelerationStructure => {
+ self.require_type_capability(Capabilities::RAY_QUERY)?;
+ TypeInfo::new(TypeFlags::empty(), Alignment::ONE)
+ }
+ Ti::RayQuery => {
+ self.require_type_capability(Capabilities::RAY_QUERY)?;
+ TypeInfo::new(TypeFlags::DATA | TypeFlags::SIZED, Alignment::ONE)
+ }
+ Ti::BindingArray { .. } => TypeInfo::new(TypeFlags::empty(), Alignment::ONE),
+ })
+ }
+}