summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/valid/analyzer.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/valid/analyzer.rs')
-rw-r--r--third_party/rust/naga/src/valid/analyzer.rs1209
1 files changed, 1209 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs
new file mode 100644
index 0000000000..ec8e9d96e5
--- /dev/null
+++ b/third_party/rust/naga/src/valid/analyzer.rs
@@ -0,0 +1,1209 @@
+/*! Module analyzer.
+
+Figures out the following properties:
+ - control flow uniformity
+ - texture/sampler pairs
+ - expression reference counts
+!*/
+
+use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags};
+use crate::span::{AddSpan as _, WithSpan};
+use crate::{
+ arena::{Arena, Handle},
+ proc::{ResolveContext, ResolveError, TypeResolution},
+};
+use std::ops;
+
+pub type NonUniformResult = Option<Handle<crate::Expression>>;
+
+bitflags::bitflags! {
+ /// Kinds of expressions that require uniform control flow.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct UniformityRequirements: u8 {
+ const WORK_GROUP_BARRIER = 0x1;
+ const DERIVATIVE = 0x2;
+ const IMPLICIT_LEVEL = 0x4;
+ }
+}
+
+/// Uniform control flow characteristics.
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+#[cfg_attr(test, derive(PartialEq))]
+pub struct Uniformity {
+ /// A child expression with non-uniform result.
+ ///
+ /// This means, when the relevant invocations are scheduled on a compute unit,
+ /// they have to use vector registers to store an individual value
+ /// per invocation.
+ ///
+ /// Whenever the control flow is conditioned on such value,
+ /// the hardware needs to keep track of the mask of invocations,
+ /// and process all branches of the control flow.
+ ///
+ /// Any operations that depend on non-uniform results also produce non-uniform.
+ pub non_uniform_result: NonUniformResult,
+ /// If this expression requires uniform control flow, store the reason here.
+ pub requirements: UniformityRequirements,
+}
+
+impl Uniformity {
+ const fn new() -> Self {
+ Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+}
+
+bitflags::bitflags! {
+ struct ExitFlags: u8 {
+ /// Control flow may return from the function, which makes all the
+ /// subsequent statements within the current function (only!)
+ /// to be executed in a non-uniform control flow.
+ const MAY_RETURN = 0x1;
+ /// Control flow may be killed. Anything after `Statement::Kill` is
+ /// considered inside non-uniform context.
+ const MAY_KILL = 0x2;
+ }
+}
+
+/// Uniformity characteristics of a function.
+#[cfg_attr(test, derive(Debug, PartialEq))]
+struct FunctionUniformity {
+ result: Uniformity,
+ exit: ExitFlags,
+}
+
+impl ops::BitOr for FunctionUniformity {
+ type Output = Self;
+ fn bitor(self, other: Self) -> Self {
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: self
+ .result
+ .non_uniform_result
+ .or(other.result.non_uniform_result),
+ requirements: self.result.requirements | other.result.requirements,
+ },
+ exit: self.exit | other.exit,
+ }
+ }
+}
+
+impl FunctionUniformity {
+ const fn new() -> Self {
+ FunctionUniformity {
+ result: Uniformity::new(),
+ exit: ExitFlags::empty(),
+ }
+ }
+
+ /// Returns a disruptor based on the stored exit flags, if any.
+ const fn exit_disruptor(&self) -> Option<UniformityDisruptor> {
+ if self.exit.contains(ExitFlags::MAY_RETURN) {
+ Some(UniformityDisruptor::Return)
+ } else if self.exit.contains(ExitFlags::MAY_KILL) {
+ Some(UniformityDisruptor::Discard)
+ } else {
+ None
+ }
+ }
+}
+
+bitflags::bitflags! {
+ /// Indicates how a global variable is used.
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ pub struct GlobalUse: u8 {
+ /// Data will be read from the variable.
+ const READ = 0x1;
+ /// Data will be written to the variable.
+ const WRITE = 0x2;
+ /// The information about the data is queried.
+ const QUERY = 0x4;
+ }
+}
+
+#[derive(Clone, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct SamplingKey {
+ pub image: Handle<crate::GlobalVariable>,
+ pub sampler: Handle<crate::GlobalVariable>,
+}
+
+#[derive(Clone, Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct ExpressionInfo {
+ pub uniformity: Uniformity,
+ pub ref_count: usize,
+ assignable_global: Option<Handle<crate::GlobalVariable>>,
+ pub ty: TypeResolution,
+}
+
+impl ExpressionInfo {
+ const fn new() -> Self {
+ ExpressionInfo {
+ uniformity: Uniformity::new(),
+ ref_count: 0,
+ assignable_global: None,
+ // this doesn't matter at this point, will be overwritten
+ ty: TypeResolution::Value(crate::TypeInner::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: 0,
+ }),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+enum GlobalOrArgument {
+ Global(Handle<crate::GlobalVariable>),
+ Argument(u32),
+}
+
+impl GlobalOrArgument {
+ fn from_expression(
+ expression_arena: &Arena<crate::Expression>,
+ expression: Handle<crate::Expression>,
+ ) -> Result<GlobalOrArgument, ExpressionError> {
+ Ok(match expression_arena[expression] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i),
+ crate::Expression::Access { base, .. }
+ | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] {
+ crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var),
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ },
+ _ => return Err(ExpressionError::ExpectedGlobalOrArgument),
+ })
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+struct Sampling {
+ image: GlobalOrArgument,
+ sampler: GlobalOrArgument,
+}
+
+#[derive(Debug)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct FunctionInfo {
+ /// Validation flags.
+ #[allow(dead_code)]
+ flags: ValidationFlags,
+ /// Set of shader stages where calling this function is valid.
+ pub available_stages: ShaderStages,
+ /// Uniformity characteristics.
+ pub uniformity: Uniformity,
+ /// Function may kill the invocation.
+ pub may_kill: bool,
+
+ /// All pairs of (texture, sampler) globals that may be used together in
+ /// sampling operations by this function and its callees. This includes
+ /// pairings that arise when this function passes textures and samplers as
+ /// arguments to its callees.
+ ///
+ /// This table does not include uses of textures and samplers passed as
+ /// arguments to this function itself, since we do not know which globals
+ /// those will be. However, this table *is* exhaustive when computed for an
+ /// entry point function: entry points never receive textures or samplers as
+ /// arguments, so all an entry point's sampling can be reported in terms of
+ /// globals.
+ ///
+ /// The GLSL back end uses this table to construct reflection info that
+ /// clients need to construct texture-combined sampler values.
+ pub sampling_set: crate::FastHashSet<SamplingKey>,
+
+ /// How this function and its callees use this module's globals.
+ ///
+ /// This is indexed by `Handle<GlobalVariable>` indices. However,
+ /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`,
+ /// so you can simply index this struct with a global handle to retrieve
+ /// its usage information.
+ global_uses: Box<[GlobalUse]>,
+
+ /// Information about each expression in this function's body.
+ ///
+ /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo`
+ /// implements `std::ops::Index<Handle<Expression>>`, so you can simply
+ /// index this struct with an expression handle to retrieve its
+ /// `ExpressionInfo`.
+ expressions: Box<[ExpressionInfo]>,
+
+ /// All (texture, sampler) pairs that may be used together in sampling
+ /// operations by this function and its callees, whether they are accessed
+ /// as globals or passed as arguments.
+ ///
+ /// Participants are represented by [`GlobalVariable`] handles whenever
+ /// possible, and otherwise by indices of this function's arguments.
+ ///
+ /// When analyzing a function call, we combine this data about the callee
+ /// with the actual arguments being passed to produce the callers' own
+ /// `sampling_set` and `sampling` tables.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ sampling: crate::FastHashSet<Sampling>,
+}
+
+impl FunctionInfo {
+ pub const fn global_variable_count(&self) -> usize {
+ self.global_uses.len()
+ }
+ pub const fn expression_count(&self) -> usize {
+ self.expressions.len()
+ }
+ pub fn dominates_global_use(&self, other: &Self) -> bool {
+ for (self_global_uses, other_global_uses) in
+ self.global_uses.iter().zip(other.global_uses.iter())
+ {
+ if !self_global_uses.contains(*other_global_uses) {
+ return false;
+ }
+ }
+ true
+ }
+}
+
+impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo {
+ type Output = GlobalUse;
+ fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse {
+ &self.global_uses[handle.index()]
+ }
+}
+
+impl ops::Index<Handle<crate::Expression>> for FunctionInfo {
+ type Output = ExpressionInfo;
+ fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo {
+ &self.expressions[handle.index()]
+ }
+}
+
+/// Disruptor of the uniform control flow.
+#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum UniformityDisruptor {
+ #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")]
+ Expression(Handle<crate::Expression>),
+ #[error("There is a Return earlier in the control flow of the function")]
+ Return,
+ #[error("There is a Discard earlier in the entry point across all called functions")]
+ Discard,
+}
+
+impl FunctionInfo {
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref_impl(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ global_use: GlobalUse,
+ ) -> NonUniformResult {
+ //Note: if the expression doesn't exist, this function
+ // will return `None`, but the later validation of
+ // expressions should detect this and error properly.
+ let info = self.expressions.get_mut(handle.index())?;
+ info.ref_count += 1;
+ // mark the used global as read
+ if let Some(global) = info.assignable_global {
+ self.global_uses[global.index()] |= global_use;
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Adds a value-type reference to an expression.
+ #[must_use]
+ fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult {
+ self.add_ref_impl(handle, GlobalUse::READ)
+ }
+
+ /// Adds a potentially assignable reference to an expression.
+ /// These are destinations for `Store` and `ImageStore` statements,
+ /// which can transit through `Access` and `AccessIndex`.
+ #[must_use]
+ fn add_assignable_ref(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ assignable_global: &mut Option<Handle<crate::GlobalVariable>>,
+ ) -> NonUniformResult {
+ //Note: similarly to `add_ref_impl`, this ignores invalid expressions.
+ let info = self.expressions.get_mut(handle.index())?;
+ info.ref_count += 1;
+ // propagate the assignable global up the chain, till it either hits
+ // a value-type expression, or the assignment statement.
+ if let Some(global) = info.assignable_global {
+ if let Some(_old) = assignable_global.replace(global) {
+ unreachable!()
+ }
+ }
+ info.uniformity.non_uniform_result
+ }
+
+ /// Inherit information from a called function.
+ fn process_call(
+ &mut self,
+ callee: &Self,
+ arguments: &[Handle<crate::Expression>],
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ self.sampling_set
+ .extend(callee.sampling_set.iter().cloned());
+ for sampling in callee.sampling.iter() {
+ // If the callee was passed the texture or sampler as an argument,
+ // we may now be able to determine which globals those referred to.
+ let image_storage = match sampling.image {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ let sampler_storage = match sampling.sampler {
+ GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var),
+ GlobalOrArgument::Argument(i) => {
+ let handle = arguments[i as usize];
+ GlobalOrArgument::from_expression(expression_arena, handle).map_err(
+ |source| {
+ FunctionError::Expression { handle, source }
+ .with_span_handle(handle, expression_arena)
+ },
+ )?
+ }
+ };
+
+ // If we've managed to pin both the image and sampler down to
+ // specific globals, record that in our `sampling_set`. Otherwise,
+ // record as much as we do know in our own `sampling` table, for our
+ // callers to sort out.
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ (image, sampler) => {
+ self.sampling.insert(Sampling { image, sampler });
+ }
+ }
+ }
+
+ // Inherit global use from our callees.
+ for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) {
+ *mine |= *other;
+ }
+
+ Ok(FunctionUniformity {
+ result: callee.uniformity.clone(),
+ exit: if callee.may_kill {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ })
+ }
+
+ /// Computes the expression info and stores it in `self.expressions`.
+ /// Also, bumps the reference counts on dependent expressions.
+ #[allow(clippy::or_fun_call)]
+ fn process_expression(
+ &mut self,
+ handle: Handle<crate::Expression>,
+ expression: &crate::Expression,
+ expression_arena: &Arena<crate::Expression>,
+ other_functions: &[FunctionInfo],
+ resolve_context: &ResolveContext,
+ capabilities: super::Capabilities,
+ ) -> Result<(), ExpressionError> {
+ use crate::{Expression as E, SampleLevel as Sl};
+
+ let mut assignable_global = None;
+ let uniformity = match *expression {
+ E::Access { base, index } => {
+ let base_ty = self[base].ty.inner_with(resolve_context.types);
+
+ // build up the caps needed if this is indexed non-uniformly
+ let mut needed_caps = super::Capabilities::empty();
+ let is_binding_array = match *base_ty {
+ crate::TypeInner::BindingArray {
+ base: array_element_ty_handle,
+ ..
+ } => {
+ // these are nasty aliases, but these idents are too long and break rustfmt
+ let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING;
+ let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING;
+ let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING;
+
+ // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it.
+ let array_element_ty =
+ &resolve_context.types[array_element_ty_handle].inner;
+
+ needed_caps |= match *array_element_ty {
+ // If we're an image, use the appropriate limit.
+ crate::TypeInner::Image { class, .. } => match class {
+ crate::ImageClass::Storage { .. } => ub_st,
+ _ => st_sb,
+ },
+ crate::TypeInner::Sampler { .. } => sampler,
+ // If we're anything but an image, assume we're a buffer and use the address space.
+ _ => {
+ if let E::GlobalVariable(global_handle) = expression_arena[base] {
+ let global = &resolve_context.global_vars[global_handle];
+ match global.space {
+ crate::AddressSpace::Uniform => ub_st,
+ crate::AddressSpace::Storage { .. } => st_sb,
+ _ => unreachable!(),
+ }
+ } else {
+ unreachable!()
+ }
+ }
+ };
+
+ true
+ }
+ _ => false,
+ };
+
+ if self[index].uniformity.non_uniform_result.is_some()
+ && !capabilities.contains(needed_caps)
+ && is_binding_array
+ {
+ return Err(ExpressionError::MissingCapabilities(needed_caps));
+ }
+
+ Uniformity {
+ non_uniform_result: self
+ .add_assignable_ref(base, &mut assignable_global)
+ .or(self.add_ref(index)),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::AccessIndex { base, .. } => Uniformity {
+ non_uniform_result: self.add_assignable_ref(base, &mut assignable_global),
+ requirements: UniformityRequirements::empty(),
+ },
+ // always uniform
+ E::Constant(_) => Uniformity::new(),
+ E::Splat { size: _, value } => Uniformity {
+ non_uniform_result: self.add_ref(value),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Swizzle { vector, .. } => Uniformity {
+ non_uniform_result: self.add_ref(vector),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Compose { ref components, .. } => {
+ let non_uniform_result = components
+ .iter()
+ .fold(None, |nur, &comp| nur.or(self.add_ref(comp)));
+ Uniformity {
+ non_uniform_result,
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the builtin or interpolation
+ E::FunctionArgument(index) => {
+ let arg = &resolve_context.arguments[index as usize];
+ let uniform = match arg.binding {
+ Some(crate::Binding::BuiltIn(built_in)) => match built_in {
+ // per-polygon built-ins are uniform
+ crate::BuiltIn::FrontFacing
+ // per-work-group built-ins are uniform
+ | crate::BuiltIn::WorkGroupId
+ | crate::BuiltIn::WorkGroupSize
+ | crate::BuiltIn::NumWorkGroups => true,
+ _ => false,
+ },
+ // only flat inputs are uniform
+ Some(crate::Binding::Location {
+ interpolation: Some(crate::Interpolation::Flat),
+ ..
+ }) => true,
+ _ => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ // depends on the address space
+ E::GlobalVariable(gh) => {
+ use crate::AddressSpace as As;
+ assignable_global = Some(gh);
+ let var = &resolve_context.global_vars[gh];
+ let uniform = match var.space {
+ // local data is non-uniform
+ As::Function | As::Private => false,
+ // workgroup memory is exclusively accessed by the group
+ As::WorkGroup => true,
+ // uniform data
+ As::Uniform | As::PushConstant => true,
+ // storage data is only uniform when read-only
+ As::Storage { access } => !access.contains(crate::StorageAccess::STORE),
+ As::Handle => false,
+ };
+ Uniformity {
+ non_uniform_result: if uniform { None } else { Some(handle) },
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::LocalVariable(_) => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Load { pointer } => Uniformity {
+ non_uniform_result: self.add_ref(pointer),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ImageSample {
+ image,
+ sampler,
+ gather: _,
+ coordinate,
+ array_index,
+ offset: _,
+ level,
+ depth_ref,
+ } => {
+ let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?;
+ let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?;
+
+ match (image_storage, sampler_storage) {
+ (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => {
+ self.sampling_set.insert(SamplingKey { image, sampler });
+ }
+ _ => {
+ self.sampling.insert(Sampling {
+ image: image_storage,
+ sampler: sampler_storage,
+ });
+ }
+ }
+
+ // "nur" == "Non-Uniform Result"
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let level_nur = match level {
+ Sl::Auto | Sl::Zero => None,
+ Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h),
+ Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)),
+ };
+ let dref_nur = depth_ref.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(sampler))
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(level_nur)
+ .or(dref_nur),
+ requirements: if level.implicit_derivatives() {
+ UniformityRequirements::IMPLICIT_LEVEL
+ } else {
+ UniformityRequirements::empty()
+ },
+ }
+ }
+ E::ImageLoad {
+ image,
+ coordinate,
+ array_index,
+ sample,
+ level,
+ } => {
+ let array_nur = array_index.and_then(|h| self.add_ref(h));
+ let sample_nur = sample.and_then(|h| self.add_ref(h));
+ let level_nur = level.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self
+ .add_ref(image)
+ .or(self.add_ref(coordinate))
+ .or(array_nur)
+ .or(sample_nur)
+ .or(level_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::ImageQuery { image, query } => {
+ let query_nur = match query {
+ crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h),
+ _ => None,
+ };
+ Uniformity {
+ non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::Unary { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Binary { left, right, .. } => Uniformity {
+ non_uniform_result: self.add_ref(left).or(self.add_ref(right)),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Select {
+ condition,
+ accept,
+ reject,
+ } => Uniformity {
+ non_uniform_result: self
+ .add_ref(condition)
+ .or(self.add_ref(accept))
+ .or(self.add_ref(reject)),
+ requirements: UniformityRequirements::empty(),
+ },
+ // explicit derivatives require uniform
+ E::Derivative { expr, .. } => Uniformity {
+ //Note: taking a derivative of a uniform doesn't make it non-uniform
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ E::Relational { argument, .. } => Uniformity {
+ non_uniform_result: self.add_ref(argument),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::Math {
+ arg, arg1, arg2, ..
+ } => {
+ let arg1_nur = arg1.and_then(|h| self.add_ref(h));
+ let arg2_nur = arg2.and_then(|h| self.add_ref(h));
+ Uniformity {
+ non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur),
+ requirements: UniformityRequirements::empty(),
+ }
+ }
+ E::As { expr, .. } => Uniformity {
+ non_uniform_result: self.add_ref(expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::CallResult(function) => {
+ let info = other_functions
+ .get(function.index())
+ .ok_or(ExpressionError::CallToUndeclaredFunction(function))?;
+
+ info.uniformity.clone()
+ }
+ E::AtomicResult { .. } => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::ArrayLength(expr) => Uniformity {
+ non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY),
+ requirements: UniformityRequirements::empty(),
+ },
+ };
+
+ let ty = resolve_context.resolve(expression, |h| {
+ self.expressions
+ .get(h.index())
+ .map(|ei| &ei.ty)
+ .ok_or(ResolveError::ExpressionForwardDependency(h))
+ })?;
+ self.expressions[handle.index()] = ExpressionInfo {
+ uniformity,
+ ref_count: 0,
+ assignable_global,
+ ty,
+ };
+ Ok(())
+ }
+
+ /// Analyzes the uniformity requirements of a block (as a sequence of statements).
+ /// Returns the uniformity characteristics at the *function* level, i.e.
+ /// whether or not the function requires to be called in uniform control flow,
+ /// and whether the produced result is not disrupting the control flow.
+ ///
+ /// The parent control flow is uniform if `disruptor.is_none()`.
+ ///
+ /// Returns a `NonUniformControlFlow` error if any of the expressions in the block
+ /// require uniformity, but the current flow is non-uniform.
+ #[allow(clippy::or_fun_call)]
+ fn process_block(
+ &mut self,
+ statements: &crate::Block,
+ other_functions: &[FunctionInfo],
+ mut disruptor: Option<UniformityDisruptor>,
+ expression_arena: &Arena<crate::Expression>,
+ ) -> Result<FunctionUniformity, WithSpan<FunctionError>> {
+ use crate::Statement as S;
+
+ let mut combined_uniformity = FunctionUniformity::new();
+ for (statement, &span) in statements.span_iter() {
+ let uniformity = match *statement {
+ S::Emit(ref range) => {
+ let mut requirements = UniformityRequirements::empty();
+ for expr in range.clone() {
+ let req = match self.expressions.get(expr.index()) {
+ Some(expr) => expr.uniformity.requirements,
+ None => UniformityRequirements::empty(),
+ };
+ #[cfg(feature = "validate")]
+ if self
+ .flags
+ .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
+ && !req.is_empty()
+ {
+ if let Some(cause) = disruptor {
+ return Err(FunctionError::NonUniformControlFlow(req, expr, cause)
+ .with_span_handle(expr, expression_arena));
+ }
+ }
+ requirements |= req;
+ }
+ FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements,
+ },
+ exit: ExitFlags::empty(),
+ }
+ }
+ S::Break | S::Continue => FunctionUniformity::new(),
+ S::Kill => FunctionUniformity {
+ result: Uniformity::new(),
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_KILL
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ S::Barrier(_) => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::WORK_GROUP_BARRIER,
+ },
+ exit: ExitFlags::empty(),
+ },
+ S::Block(ref b) => {
+ self.process_block(b, other_functions, disruptor, expression_arena)?
+ }
+ S::If {
+ condition,
+ ref accept,
+ ref reject,
+ } => {
+ let condition_nur = self.add_ref(condition);
+ let branch_disruptor =
+ disruptor.or(condition_nur.map(UniformityDisruptor::Expression));
+ let accept_uniformity = self.process_block(
+ accept,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ let reject_uniformity = self.process_block(
+ reject,
+ other_functions,
+ branch_disruptor,
+ expression_arena,
+ )?;
+ accept_uniformity | reject_uniformity
+ }
+ S::Switch {
+ selector,
+ ref cases,
+ } => {
+ let selector_nur = self.add_ref(selector);
+ let branch_disruptor =
+ disruptor.or(selector_nur.map(UniformityDisruptor::Expression));
+ let mut uniformity = FunctionUniformity::new();
+ let mut case_disruptor = branch_disruptor;
+ for case in cases.iter() {
+ let case_uniformity = self.process_block(
+ &case.body,
+ other_functions,
+ case_disruptor,
+ expression_arena,
+ )?;
+ case_disruptor = if case.fall_through {
+ case_disruptor.or(case_uniformity.exit_disruptor())
+ } else {
+ branch_disruptor
+ };
+ uniformity = uniformity | case_uniformity;
+ }
+ uniformity
+ }
+ S::Loop {
+ ref body,
+ ref continuing,
+ break_if: _,
+ } => {
+ let body_uniformity =
+ self.process_block(body, other_functions, disruptor, expression_arena)?;
+ let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor());
+ let continuing_uniformity = self.process_block(
+ continuing,
+ other_functions,
+ continuing_disruptor,
+ expression_arena,
+ )?;
+ body_uniformity | continuing_uniformity
+ }
+ S::Return { value } => FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: value.and_then(|expr| self.add_ref(expr)),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: if disruptor.is_some() {
+ ExitFlags::MAY_RETURN
+ } else {
+ ExitFlags::empty()
+ },
+ },
+ // Here and below, the used expressions are already emitted,
+ // and their results do not affect the function return value,
+ // so we can ignore their non-uniformity.
+ S::Store { pointer, value } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::ImageStore {
+ image,
+ coordinate,
+ array_index,
+ value,
+ } => {
+ let _ = self.add_ref_impl(image, GlobalUse::WRITE);
+ if let Some(expr) = array_index {
+ let _ = self.add_ref(expr);
+ }
+ let _ = self.add_ref(coordinate);
+ let _ = self.add_ref(value);
+ FunctionUniformity::new()
+ }
+ S::Call {
+ function,
+ ref arguments,
+ result: _,
+ } => {
+ for &argument in arguments {
+ let _ = self.add_ref(argument);
+ }
+ let info = other_functions.get(function.index()).ok_or(
+ FunctionError::InvalidCall {
+ function,
+ error: CallError::ForwardDeclaredFunction,
+ }
+ .with_span_static(span, "forward call"),
+ )?;
+ //Note: the result is validated by the Validator, not here
+ self.process_call(info, arguments, expression_arena)?
+ }
+ S::Atomic {
+ pointer,
+ ref fun,
+ value,
+ result: _,
+ } => {
+ let _ = self.add_ref_impl(pointer, GlobalUse::WRITE);
+ let _ = self.add_ref(value);
+ if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun {
+ let _ = self.add_ref(cmp);
+ }
+ FunctionUniformity::new()
+ }
+ };
+
+ disruptor = disruptor.or(uniformity.exit_disruptor());
+ combined_uniformity = combined_uniformity | uniformity;
+ }
+ Ok(combined_uniformity)
+ }
+}
+
+impl ModuleInfo {
+ /// Builds the `FunctionInfo` based on the function, and validates the
+ /// uniform control flow if required by the expressions of this function.
+ pub(super) fn process_function(
+ &self,
+ fun: &crate::Function,
+ module: &crate::Module,
+ flags: ValidationFlags,
+ capabilities: super::Capabilities,
+ ) -> Result<FunctionInfo, WithSpan<FunctionError>> {
+ let mut info = FunctionInfo {
+ flags,
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ };
+ let resolve_context = ResolveContext {
+ constants: &module.constants,
+ types: &module.types,
+ global_vars: &module.global_variables,
+ local_vars: &fun.local_variables,
+ functions: &module.functions,
+ arguments: &fun.arguments,
+ };
+
+ for (handle, expr) in fun.expressions.iter() {
+ if let Err(source) = info.process_expression(
+ handle,
+ expr,
+ &fun.expressions,
+ &self.functions,
+ &resolve_context,
+ capabilities,
+ ) {
+ return Err(FunctionError::Expression { handle, source }
+ .with_span_handle(handle, &fun.expressions));
+ }
+ }
+
+ let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?;
+ info.uniformity = uniformity.result;
+ info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL);
+
+ Ok(info)
+ }
+
+ pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
+ &self.entry_points[index]
+ }
+}
+
+#[test]
+#[cfg(feature = "validate")]
+fn uniform_control_flow() {
+ use crate::{Expression as E, Statement as S};
+
+ let mut constant_arena = Arena::new();
+ let constant = constant_arena.append(
+ crate::Constant {
+ name: None,
+ specialization: None,
+ inner: crate::ConstantInner::Scalar {
+ width: 4,
+ value: crate::ScalarValue::Uint(0),
+ },
+ },
+ Default::default(),
+ );
+ let mut type_arena = crate::UniqueArena::new();
+ let ty = type_arena.insert(
+ crate::Type {
+ name: None,
+ inner: crate::TypeInner::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ },
+ Default::default(),
+ );
+ let mut global_var_arena = Arena::new();
+ let non_uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ space: crate::AddressSpace::Handle,
+ binding: None,
+ },
+ Default::default(),
+ );
+ let uniform_global = global_var_arena.append(
+ crate::GlobalVariable {
+ name: None,
+ init: None,
+ ty,
+ binding: None,
+ space: crate::AddressSpace::Uniform,
+ },
+ Default::default(),
+ );
+
+ let mut expressions = Arena::new();
+ // checks the uniform control flow
+ let constant_expr = expressions.append(E::Constant(constant), Default::default());
+ // checks the non-uniform control flow
+ let derivative_expr = expressions.append(
+ E::Derivative {
+ axis: crate::DerivativeAxis::X,
+ expr: constant_expr,
+ },
+ Default::default(),
+ );
+ let emit_range_constant_derivative = expressions.range_from(0);
+ let non_uniform_global_expr =
+ expressions.append(E::GlobalVariable(non_uniform_global), Default::default());
+ let uniform_global_expr =
+ expressions.append(E::GlobalVariable(uniform_global), Default::default());
+ let emit_range_globals = expressions.range_from(2);
+
+ // checks the QUERY flag
+ let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default());
+ // checks the transitive WRITE flag
+ let access_expr = expressions.append(
+ E::AccessIndex {
+ base: non_uniform_global_expr,
+ index: 1,
+ },
+ Default::default(),
+ );
+ let emit_range_query_access_globals = expressions.range_from(2);
+
+ let mut info = FunctionInfo {
+ flags: ValidationFlags::all(),
+ available_stages: ShaderStages::all(),
+ uniformity: Uniformity::new(),
+ may_kill: false,
+ sampling_set: crate::FastHashSet::default(),
+ global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(),
+ expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(),
+ sampling: crate::FastHashSet::default(),
+ };
+ let resolve_context = ResolveContext {
+ constants: &constant_arena,
+ types: &type_arena,
+ global_vars: &global_var_arena,
+ local_vars: &Arena::new(),
+ functions: &Arena::new(),
+ arguments: &[],
+ };
+ for (handle, expression) in expressions.iter() {
+ info.process_expression(
+ handle,
+ expression,
+ &expressions,
+ &[],
+ &resolve_context,
+ super::Capabilities::empty(),
+ )
+ .unwrap();
+ }
+ assert_eq!(info[non_uniform_global_expr].ref_count, 1);
+ assert_eq!(info[uniform_global_expr].ref_count, 1);
+ assert_eq!(info[query_expr].ref_count, 0);
+ assert_eq!(info[access_expr].ref_count, 0);
+ assert_eq!(info[non_uniform_global], GlobalUse::empty());
+ assert_eq!(info[uniform_global], GlobalUse::QUERY);
+
+ let stmt_emit1 = S::Emit(emit_range_globals.clone());
+ let stmt_if_uniform = S::If {
+ condition: uniform_global_expr,
+ accept: crate::Block::new(),
+ reject: vec![
+ S::Emit(emit_range_constant_derivative.clone()),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit1, stmt_if_uniform].into(),
+ &[],
+ None,
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: None,
+ requirements: UniformityRequirements::DERIVATIVE,
+ },
+ exit: ExitFlags::empty(),
+ }),
+ );
+ assert_eq!(info[constant_expr].ref_count, 2);
+ assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY);
+
+ let stmt_emit2 = S::Emit(emit_range_globals.clone());
+ let stmt_if_non_uniform = S::If {
+ condition: non_uniform_global_expr,
+ accept: vec![
+ S::Emit(emit_range_constant_derivative),
+ S::Store {
+ pointer: constant_expr,
+ value: derivative_expr,
+ },
+ ]
+ .into(),
+ reject: crate::Block::new(),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit2, stmt_if_non_uniform].into(),
+ &[],
+ None,
+ &expressions
+ ),
+ Err(FunctionError::NonUniformControlFlow(
+ UniformityRequirements::DERIVATIVE,
+ derivative_expr,
+ UniformityDisruptor::Expression(non_uniform_global_expr)
+ )
+ .with_span()),
+ );
+ assert_eq!(info[derivative_expr].ref_count, 1);
+ assert_eq!(info[non_uniform_global], GlobalUse::READ);
+
+ let stmt_emit3 = S::Emit(emit_range_globals);
+ let stmt_return_non_uniform = S::Return {
+ value: Some(non_uniform_global_expr),
+ };
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit3, stmt_return_non_uniform].into(),
+ &[],
+ Some(UniformityDisruptor::Return),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::MAY_RETURN,
+ }),
+ );
+ assert_eq!(info[non_uniform_global_expr].ref_count, 3);
+
+ // Check that uniformity requirements reach through a pointer
+ let stmt_emit4 = S::Emit(emit_range_query_access_globals);
+ let stmt_assign = S::Store {
+ pointer: access_expr,
+ value: query_expr,
+ };
+ let stmt_return_pointer = S::Return {
+ value: Some(access_expr),
+ };
+ let stmt_kill = S::Kill;
+ assert_eq!(
+ info.process_block(
+ &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(),
+ &[],
+ Some(UniformityDisruptor::Discard),
+ &expressions
+ ),
+ Ok(FunctionUniformity {
+ result: Uniformity {
+ non_uniform_result: Some(non_uniform_global_expr),
+ requirements: UniformityRequirements::empty(),
+ },
+ exit: ExitFlags::all(),
+ }),
+ );
+ assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE);
+}