diff options
Diffstat (limited to 'third_party/rust/naga/src/valid/analyzer.rs')
-rw-r--r-- | third_party/rust/naga/src/valid/analyzer.rs | 1209 |
1 files changed, 1209 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs new file mode 100644 index 0000000000..ec8e9d96e5 --- /dev/null +++ b/third_party/rust/naga/src/valid/analyzer.rs @@ -0,0 +1,1209 @@ +/*! Module analyzer. + +Figures out the following properties: + - control flow uniformity + - texture/sampler pairs + - expression reference counts +!*/ + +use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags}; +use crate::span::{AddSpan as _, WithSpan}; +use crate::{ + arena::{Arena, Handle}, + proc::{ResolveContext, ResolveError, TypeResolution}, +}; +use std::ops; + +pub type NonUniformResult = Option<Handle<crate::Expression>>; + +bitflags::bitflags! { + /// Kinds of expressions that require uniform control flow. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct UniformityRequirements: u8 { + const WORK_GROUP_BARRIER = 0x1; + const DERIVATIVE = 0x2; + const IMPLICIT_LEVEL = 0x4; + } +} + +/// Uniform control flow characteristics. +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(test, derive(PartialEq))] +pub struct Uniformity { + /// A child expression with non-uniform result. + /// + /// This means, when the relevant invocations are scheduled on a compute unit, + /// they have to use vector registers to store an individual value + /// per invocation. + /// + /// Whenever the control flow is conditioned on such value, + /// the hardware needs to keep track of the mask of invocations, + /// and process all branches of the control flow. + /// + /// Any operations that depend on non-uniform results also produce non-uniform. + pub non_uniform_result: NonUniformResult, + /// If this expression requires uniform control flow, store the reason here. + pub requirements: UniformityRequirements, +} + +impl Uniformity { + const fn new() -> Self { + Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::empty(), + } + } +} + +bitflags::bitflags! { + struct ExitFlags: u8 { + /// Control flow may return from the function, which makes all the + /// subsequent statements within the current function (only!) + /// to be executed in a non-uniform control flow. + const MAY_RETURN = 0x1; + /// Control flow may be killed. Anything after `Statement::Kill` is + /// considered inside non-uniform context. + const MAY_KILL = 0x2; + } +} + +/// Uniformity characteristics of a function. +#[cfg_attr(test, derive(Debug, PartialEq))] +struct FunctionUniformity { + result: Uniformity, + exit: ExitFlags, +} + +impl ops::BitOr for FunctionUniformity { + type Output = Self; + fn bitor(self, other: Self) -> Self { + FunctionUniformity { + result: Uniformity { + non_uniform_result: self + .result + .non_uniform_result + .or(other.result.non_uniform_result), + requirements: self.result.requirements | other.result.requirements, + }, + exit: self.exit | other.exit, + } + } +} + +impl FunctionUniformity { + const fn new() -> Self { + FunctionUniformity { + result: Uniformity::new(), + exit: ExitFlags::empty(), + } + } + + /// Returns a disruptor based on the stored exit flags, if any. + const fn exit_disruptor(&self) -> Option<UniformityDisruptor> { + if self.exit.contains(ExitFlags::MAY_RETURN) { + Some(UniformityDisruptor::Return) + } else if self.exit.contains(ExitFlags::MAY_KILL) { + Some(UniformityDisruptor::Discard) + } else { + None + } + } +} + +bitflags::bitflags! { + /// Indicates how a global variable is used. + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + pub struct GlobalUse: u8 { + /// Data will be read from the variable. + const READ = 0x1; + /// Data will be written to the variable. + const WRITE = 0x2; + /// The information about the data is queried. + const QUERY = 0x4; + } +} + +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct SamplingKey { + pub image: Handle<crate::GlobalVariable>, + pub sampler: Handle<crate::GlobalVariable>, +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct ExpressionInfo { + pub uniformity: Uniformity, + pub ref_count: usize, + assignable_global: Option<Handle<crate::GlobalVariable>>, + pub ty: TypeResolution, +} + +impl ExpressionInfo { + const fn new() -> Self { + ExpressionInfo { + uniformity: Uniformity::new(), + ref_count: 0, + assignable_global: None, + // this doesn't matter at this point, will be overwritten + ty: TypeResolution::Value(crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: 0, + }), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +enum GlobalOrArgument { + Global(Handle<crate::GlobalVariable>), + Argument(u32), +} + +impl GlobalOrArgument { + fn from_expression( + expression_arena: &Arena<crate::Expression>, + expression: Handle<crate::Expression>, + ) -> Result<GlobalOrArgument, ExpressionError> { + Ok(match expression_arena[expression] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + crate::Expression::FunctionArgument(i) => GlobalOrArgument::Argument(i), + crate::Expression::Access { base, .. } + | crate::Expression::AccessIndex { base, .. } => match expression_arena[base] { + crate::Expression::GlobalVariable(var) => GlobalOrArgument::Global(var), + _ => return Err(ExpressionError::ExpectedGlobalOrArgument), + }, + _ => return Err(ExpressionError::ExpectedGlobalOrArgument), + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +struct Sampling { + image: GlobalOrArgument, + sampler: GlobalOrArgument, +} + +#[derive(Debug)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct FunctionInfo { + /// Validation flags. + #[allow(dead_code)] + flags: ValidationFlags, + /// Set of shader stages where calling this function is valid. + pub available_stages: ShaderStages, + /// Uniformity characteristics. + pub uniformity: Uniformity, + /// Function may kill the invocation. + pub may_kill: bool, + + /// All pairs of (texture, sampler) globals that may be used together in + /// sampling operations by this function and its callees. This includes + /// pairings that arise when this function passes textures and samplers as + /// arguments to its callees. + /// + /// This table does not include uses of textures and samplers passed as + /// arguments to this function itself, since we do not know which globals + /// those will be. However, this table *is* exhaustive when computed for an + /// entry point function: entry points never receive textures or samplers as + /// arguments, so all an entry point's sampling can be reported in terms of + /// globals. + /// + /// The GLSL back end uses this table to construct reflection info that + /// clients need to construct texture-combined sampler values. + pub sampling_set: crate::FastHashSet<SamplingKey>, + + /// How this function and its callees use this module's globals. + /// + /// This is indexed by `Handle<GlobalVariable>` indices. However, + /// `FunctionInfo` implements `std::ops::Index<Handle<GlobalVariable>>`, + /// so you can simply index this struct with a global handle to retrieve + /// its usage information. + global_uses: Box<[GlobalUse]>, + + /// Information about each expression in this function's body. + /// + /// This is indexed by `Handle<Expression>` indices. However, `FunctionInfo` + /// implements `std::ops::Index<Handle<Expression>>`, so you can simply + /// index this struct with an expression handle to retrieve its + /// `ExpressionInfo`. + expressions: Box<[ExpressionInfo]>, + + /// All (texture, sampler) pairs that may be used together in sampling + /// operations by this function and its callees, whether they are accessed + /// as globals or passed as arguments. + /// + /// Participants are represented by [`GlobalVariable`] handles whenever + /// possible, and otherwise by indices of this function's arguments. + /// + /// When analyzing a function call, we combine this data about the callee + /// with the actual arguments being passed to produce the callers' own + /// `sampling_set` and `sampling` tables. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + sampling: crate::FastHashSet<Sampling>, +} + +impl FunctionInfo { + pub const fn global_variable_count(&self) -> usize { + self.global_uses.len() + } + pub const fn expression_count(&self) -> usize { + self.expressions.len() + } + pub fn dominates_global_use(&self, other: &Self) -> bool { + for (self_global_uses, other_global_uses) in + self.global_uses.iter().zip(other.global_uses.iter()) + { + if !self_global_uses.contains(*other_global_uses) { + return false; + } + } + true + } +} + +impl ops::Index<Handle<crate::GlobalVariable>> for FunctionInfo { + type Output = GlobalUse; + fn index(&self, handle: Handle<crate::GlobalVariable>) -> &GlobalUse { + &self.global_uses[handle.index()] + } +} + +impl ops::Index<Handle<crate::Expression>> for FunctionInfo { + type Output = ExpressionInfo; + fn index(&self, handle: Handle<crate::Expression>) -> &ExpressionInfo { + &self.expressions[handle.index()] + } +} + +/// Disruptor of the uniform control flow. +#[derive(Clone, Copy, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum UniformityDisruptor { + #[error("Expression {0:?} produced non-uniform result, and control flow depends on it")] + Expression(Handle<crate::Expression>), + #[error("There is a Return earlier in the control flow of the function")] + Return, + #[error("There is a Discard earlier in the entry point across all called functions")] + Discard, +} + +impl FunctionInfo { + /// Adds a value-type reference to an expression. + #[must_use] + fn add_ref_impl( + &mut self, + handle: Handle<crate::Expression>, + global_use: GlobalUse, + ) -> NonUniformResult { + //Note: if the expression doesn't exist, this function + // will return `None`, but the later validation of + // expressions should detect this and error properly. + let info = self.expressions.get_mut(handle.index())?; + info.ref_count += 1; + // mark the used global as read + if let Some(global) = info.assignable_global { + self.global_uses[global.index()] |= global_use; + } + info.uniformity.non_uniform_result + } + + /// Adds a value-type reference to an expression. + #[must_use] + fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult { + self.add_ref_impl(handle, GlobalUse::READ) + } + + /// Adds a potentially assignable reference to an expression. + /// These are destinations for `Store` and `ImageStore` statements, + /// which can transit through `Access` and `AccessIndex`. + #[must_use] + fn add_assignable_ref( + &mut self, + handle: Handle<crate::Expression>, + assignable_global: &mut Option<Handle<crate::GlobalVariable>>, + ) -> NonUniformResult { + //Note: similarly to `add_ref_impl`, this ignores invalid expressions. + let info = self.expressions.get_mut(handle.index())?; + info.ref_count += 1; + // propagate the assignable global up the chain, till it either hits + // a value-type expression, or the assignment statement. + if let Some(global) = info.assignable_global { + if let Some(_old) = assignable_global.replace(global) { + unreachable!() + } + } + info.uniformity.non_uniform_result + } + + /// Inherit information from a called function. + fn process_call( + &mut self, + callee: &Self, + arguments: &[Handle<crate::Expression>], + expression_arena: &Arena<crate::Expression>, + ) -> Result<FunctionUniformity, WithSpan<FunctionError>> { + self.sampling_set + .extend(callee.sampling_set.iter().cloned()); + for sampling in callee.sampling.iter() { + // If the callee was passed the texture or sampler as an argument, + // we may now be able to determine which globals those referred to. + let image_storage = match sampling.image { + GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var), + GlobalOrArgument::Argument(i) => { + let handle = arguments[i as usize]; + GlobalOrArgument::from_expression(expression_arena, handle).map_err( + |source| { + FunctionError::Expression { handle, source } + .with_span_handle(handle, expression_arena) + }, + )? + } + }; + + let sampler_storage = match sampling.sampler { + GlobalOrArgument::Global(var) => GlobalOrArgument::Global(var), + GlobalOrArgument::Argument(i) => { + let handle = arguments[i as usize]; + GlobalOrArgument::from_expression(expression_arena, handle).map_err( + |source| { + FunctionError::Expression { handle, source } + .with_span_handle(handle, expression_arena) + }, + )? + } + }; + + // If we've managed to pin both the image and sampler down to + // specific globals, record that in our `sampling_set`. Otherwise, + // record as much as we do know in our own `sampling` table, for our + // callers to sort out. + match (image_storage, sampler_storage) { + (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => { + self.sampling_set.insert(SamplingKey { image, sampler }); + } + (image, sampler) => { + self.sampling.insert(Sampling { image, sampler }); + } + } + } + + // Inherit global use from our callees. + for (mine, other) in self.global_uses.iter_mut().zip(callee.global_uses.iter()) { + *mine |= *other; + } + + Ok(FunctionUniformity { + result: callee.uniformity.clone(), + exit: if callee.may_kill { + ExitFlags::MAY_KILL + } else { + ExitFlags::empty() + }, + }) + } + + /// Computes the expression info and stores it in `self.expressions`. + /// Also, bumps the reference counts on dependent expressions. + #[allow(clippy::or_fun_call)] + fn process_expression( + &mut self, + handle: Handle<crate::Expression>, + expression: &crate::Expression, + expression_arena: &Arena<crate::Expression>, + other_functions: &[FunctionInfo], + resolve_context: &ResolveContext, + capabilities: super::Capabilities, + ) -> Result<(), ExpressionError> { + use crate::{Expression as E, SampleLevel as Sl}; + + let mut assignable_global = None; + let uniformity = match *expression { + E::Access { base, index } => { + let base_ty = self[base].ty.inner_with(resolve_context.types); + + // build up the caps needed if this is indexed non-uniformly + let mut needed_caps = super::Capabilities::empty(); + let is_binding_array = match *base_ty { + crate::TypeInner::BindingArray { + base: array_element_ty_handle, + .. + } => { + // these are nasty aliases, but these idents are too long and break rustfmt + let ub_st = super::Capabilities::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING; + let st_sb = super::Capabilities::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING; + let sampler = super::Capabilities::SAMPLER_NON_UNIFORM_INDEXING; + + // We're a binding array, so lets use the type of _what_ we are array of to determine if we can non-uniformly index it. + let array_element_ty = + &resolve_context.types[array_element_ty_handle].inner; + + needed_caps |= match *array_element_ty { + // If we're an image, use the appropriate limit. + crate::TypeInner::Image { class, .. } => match class { + crate::ImageClass::Storage { .. } => ub_st, + _ => st_sb, + }, + crate::TypeInner::Sampler { .. } => sampler, + // If we're anything but an image, assume we're a buffer and use the address space. + _ => { + if let E::GlobalVariable(global_handle) = expression_arena[base] { + let global = &resolve_context.global_vars[global_handle]; + match global.space { + crate::AddressSpace::Uniform => ub_st, + crate::AddressSpace::Storage { .. } => st_sb, + _ => unreachable!(), + } + } else { + unreachable!() + } + } + }; + + true + } + _ => false, + }; + + if self[index].uniformity.non_uniform_result.is_some() + && !capabilities.contains(needed_caps) + && is_binding_array + { + return Err(ExpressionError::MissingCapabilities(needed_caps)); + } + + Uniformity { + non_uniform_result: self + .add_assignable_ref(base, &mut assignable_global) + .or(self.add_ref(index)), + requirements: UniformityRequirements::empty(), + } + } + E::AccessIndex { base, .. } => Uniformity { + non_uniform_result: self.add_assignable_ref(base, &mut assignable_global), + requirements: UniformityRequirements::empty(), + }, + // always uniform + E::Constant(_) => Uniformity::new(), + E::Splat { size: _, value } => Uniformity { + non_uniform_result: self.add_ref(value), + requirements: UniformityRequirements::empty(), + }, + E::Swizzle { vector, .. } => Uniformity { + non_uniform_result: self.add_ref(vector), + requirements: UniformityRequirements::empty(), + }, + E::Compose { ref components, .. } => { + let non_uniform_result = components + .iter() + .fold(None, |nur, &comp| nur.or(self.add_ref(comp))); + Uniformity { + non_uniform_result, + requirements: UniformityRequirements::empty(), + } + } + // depends on the builtin or interpolation + E::FunctionArgument(index) => { + let arg = &resolve_context.arguments[index as usize]; + let uniform = match arg.binding { + Some(crate::Binding::BuiltIn(built_in)) => match built_in { + // per-polygon built-ins are uniform + crate::BuiltIn::FrontFacing + // per-work-group built-ins are uniform + | crate::BuiltIn::WorkGroupId + | crate::BuiltIn::WorkGroupSize + | crate::BuiltIn::NumWorkGroups => true, + _ => false, + }, + // only flat inputs are uniform + Some(crate::Binding::Location { + interpolation: Some(crate::Interpolation::Flat), + .. + }) => true, + _ => false, + }; + Uniformity { + non_uniform_result: if uniform { None } else { Some(handle) }, + requirements: UniformityRequirements::empty(), + } + } + // depends on the address space + E::GlobalVariable(gh) => { + use crate::AddressSpace as As; + assignable_global = Some(gh); + let var = &resolve_context.global_vars[gh]; + let uniform = match var.space { + // local data is non-uniform + As::Function | As::Private => false, + // workgroup memory is exclusively accessed by the group + As::WorkGroup => true, + // uniform data + As::Uniform | As::PushConstant => true, + // storage data is only uniform when read-only + As::Storage { access } => !access.contains(crate::StorageAccess::STORE), + As::Handle => false, + }; + Uniformity { + non_uniform_result: if uniform { None } else { Some(handle) }, + requirements: UniformityRequirements::empty(), + } + } + E::LocalVariable(_) => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::Load { pointer } => Uniformity { + non_uniform_result: self.add_ref(pointer), + requirements: UniformityRequirements::empty(), + }, + E::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset: _, + level, + depth_ref, + } => { + let image_storage = GlobalOrArgument::from_expression(expression_arena, image)?; + let sampler_storage = GlobalOrArgument::from_expression(expression_arena, sampler)?; + + match (image_storage, sampler_storage) { + (GlobalOrArgument::Global(image), GlobalOrArgument::Global(sampler)) => { + self.sampling_set.insert(SamplingKey { image, sampler }); + } + _ => { + self.sampling.insert(Sampling { + image: image_storage, + sampler: sampler_storage, + }); + } + } + + // "nur" == "Non-Uniform Result" + let array_nur = array_index.and_then(|h| self.add_ref(h)); + let level_nur = match level { + Sl::Auto | Sl::Zero => None, + Sl::Exact(h) | Sl::Bias(h) => self.add_ref(h), + Sl::Gradient { x, y } => self.add_ref(x).or(self.add_ref(y)), + }; + let dref_nur = depth_ref.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self + .add_ref(image) + .or(self.add_ref(sampler)) + .or(self.add_ref(coordinate)) + .or(array_nur) + .or(level_nur) + .or(dref_nur), + requirements: if level.implicit_derivatives() { + UniformityRequirements::IMPLICIT_LEVEL + } else { + UniformityRequirements::empty() + }, + } + } + E::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + let array_nur = array_index.and_then(|h| self.add_ref(h)); + let sample_nur = sample.and_then(|h| self.add_ref(h)); + let level_nur = level.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self + .add_ref(image) + .or(self.add_ref(coordinate)) + .or(array_nur) + .or(sample_nur) + .or(level_nur), + requirements: UniformityRequirements::empty(), + } + } + E::ImageQuery { image, query } => { + let query_nur = match query { + crate::ImageQuery::Size { level: Some(h) } => self.add_ref(h), + _ => None, + }; + Uniformity { + non_uniform_result: self.add_ref_impl(image, GlobalUse::QUERY).or(query_nur), + requirements: UniformityRequirements::empty(), + } + } + E::Unary { expr, .. } => Uniformity { + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::empty(), + }, + E::Binary { left, right, .. } => Uniformity { + non_uniform_result: self.add_ref(left).or(self.add_ref(right)), + requirements: UniformityRequirements::empty(), + }, + E::Select { + condition, + accept, + reject, + } => Uniformity { + non_uniform_result: self + .add_ref(condition) + .or(self.add_ref(accept)) + .or(self.add_ref(reject)), + requirements: UniformityRequirements::empty(), + }, + // explicit derivatives require uniform + E::Derivative { expr, .. } => Uniformity { + //Note: taking a derivative of a uniform doesn't make it non-uniform + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::DERIVATIVE, + }, + E::Relational { argument, .. } => Uniformity { + non_uniform_result: self.add_ref(argument), + requirements: UniformityRequirements::empty(), + }, + E::Math { + arg, arg1, arg2, .. + } => { + let arg1_nur = arg1.and_then(|h| self.add_ref(h)); + let arg2_nur = arg2.and_then(|h| self.add_ref(h)); + Uniformity { + non_uniform_result: self.add_ref(arg).or(arg1_nur).or(arg2_nur), + requirements: UniformityRequirements::empty(), + } + } + E::As { expr, .. } => Uniformity { + non_uniform_result: self.add_ref(expr), + requirements: UniformityRequirements::empty(), + }, + E::CallResult(function) => { + let info = other_functions + .get(function.index()) + .ok_or(ExpressionError::CallToUndeclaredFunction(function))?; + + info.uniformity.clone() + } + E::AtomicResult { .. } => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::ArrayLength(expr) => Uniformity { + non_uniform_result: self.add_ref_impl(expr, GlobalUse::QUERY), + requirements: UniformityRequirements::empty(), + }, + }; + + let ty = resolve_context.resolve(expression, |h| { + self.expressions + .get(h.index()) + .map(|ei| &ei.ty) + .ok_or(ResolveError::ExpressionForwardDependency(h)) + })?; + self.expressions[handle.index()] = ExpressionInfo { + uniformity, + ref_count: 0, + assignable_global, + ty, + }; + Ok(()) + } + + /// Analyzes the uniformity requirements of a block (as a sequence of statements). + /// Returns the uniformity characteristics at the *function* level, i.e. + /// whether or not the function requires to be called in uniform control flow, + /// and whether the produced result is not disrupting the control flow. + /// + /// The parent control flow is uniform if `disruptor.is_none()`. + /// + /// Returns a `NonUniformControlFlow` error if any of the expressions in the block + /// require uniformity, but the current flow is non-uniform. + #[allow(clippy::or_fun_call)] + fn process_block( + &mut self, + statements: &crate::Block, + other_functions: &[FunctionInfo], + mut disruptor: Option<UniformityDisruptor>, + expression_arena: &Arena<crate::Expression>, + ) -> Result<FunctionUniformity, WithSpan<FunctionError>> { + use crate::Statement as S; + + let mut combined_uniformity = FunctionUniformity::new(); + for (statement, &span) in statements.span_iter() { + let uniformity = match *statement { + S::Emit(ref range) => { + let mut requirements = UniformityRequirements::empty(); + for expr in range.clone() { + let req = match self.expressions.get(expr.index()) { + Some(expr) => expr.uniformity.requirements, + None => UniformityRequirements::empty(), + }; + #[cfg(feature = "validate")] + if self + .flags + .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY) + && !req.is_empty() + { + if let Some(cause) = disruptor { + return Err(FunctionError::NonUniformControlFlow(req, expr, cause) + .with_span_handle(expr, expression_arena)); + } + } + requirements |= req; + } + FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements, + }, + exit: ExitFlags::empty(), + } + } + S::Break | S::Continue => FunctionUniformity::new(), + S::Kill => FunctionUniformity { + result: Uniformity::new(), + exit: if disruptor.is_some() { + ExitFlags::MAY_KILL + } else { + ExitFlags::empty() + }, + }, + S::Barrier(_) => FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::WORK_GROUP_BARRIER, + }, + exit: ExitFlags::empty(), + }, + S::Block(ref b) => { + self.process_block(b, other_functions, disruptor, expression_arena)? + } + S::If { + condition, + ref accept, + ref reject, + } => { + let condition_nur = self.add_ref(condition); + let branch_disruptor = + disruptor.or(condition_nur.map(UniformityDisruptor::Expression)); + let accept_uniformity = self.process_block( + accept, + other_functions, + branch_disruptor, + expression_arena, + )?; + let reject_uniformity = self.process_block( + reject, + other_functions, + branch_disruptor, + expression_arena, + )?; + accept_uniformity | reject_uniformity + } + S::Switch { + selector, + ref cases, + } => { + let selector_nur = self.add_ref(selector); + let branch_disruptor = + disruptor.or(selector_nur.map(UniformityDisruptor::Expression)); + let mut uniformity = FunctionUniformity::new(); + let mut case_disruptor = branch_disruptor; + for case in cases.iter() { + let case_uniformity = self.process_block( + &case.body, + other_functions, + case_disruptor, + expression_arena, + )?; + case_disruptor = if case.fall_through { + case_disruptor.or(case_uniformity.exit_disruptor()) + } else { + branch_disruptor + }; + uniformity = uniformity | case_uniformity; + } + uniformity + } + S::Loop { + ref body, + ref continuing, + break_if: _, + } => { + let body_uniformity = + self.process_block(body, other_functions, disruptor, expression_arena)?; + let continuing_disruptor = disruptor.or(body_uniformity.exit_disruptor()); + let continuing_uniformity = self.process_block( + continuing, + other_functions, + continuing_disruptor, + expression_arena, + )?; + body_uniformity | continuing_uniformity + } + S::Return { value } => FunctionUniformity { + result: Uniformity { + non_uniform_result: value.and_then(|expr| self.add_ref(expr)), + requirements: UniformityRequirements::empty(), + }, + exit: if disruptor.is_some() { + ExitFlags::MAY_RETURN + } else { + ExitFlags::empty() + }, + }, + // Here and below, the used expressions are already emitted, + // and their results do not affect the function return value, + // so we can ignore their non-uniformity. + S::Store { pointer, value } => { + let _ = self.add_ref_impl(pointer, GlobalUse::WRITE); + let _ = self.add_ref(value); + FunctionUniformity::new() + } + S::ImageStore { + image, + coordinate, + array_index, + value, + } => { + let _ = self.add_ref_impl(image, GlobalUse::WRITE); + if let Some(expr) = array_index { + let _ = self.add_ref(expr); + } + let _ = self.add_ref(coordinate); + let _ = self.add_ref(value); + FunctionUniformity::new() + } + S::Call { + function, + ref arguments, + result: _, + } => { + for &argument in arguments { + let _ = self.add_ref(argument); + } + let info = other_functions.get(function.index()).ok_or( + FunctionError::InvalidCall { + function, + error: CallError::ForwardDeclaredFunction, + } + .with_span_static(span, "forward call"), + )?; + //Note: the result is validated by the Validator, not here + self.process_call(info, arguments, expression_arena)? + } + S::Atomic { + pointer, + ref fun, + value, + result: _, + } => { + let _ = self.add_ref_impl(pointer, GlobalUse::WRITE); + let _ = self.add_ref(value); + if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { + let _ = self.add_ref(cmp); + } + FunctionUniformity::new() + } + }; + + disruptor = disruptor.or(uniformity.exit_disruptor()); + combined_uniformity = combined_uniformity | uniformity; + } + Ok(combined_uniformity) + } +} + +impl ModuleInfo { + /// Builds the `FunctionInfo` based on the function, and validates the + /// uniform control flow if required by the expressions of this function. + pub(super) fn process_function( + &self, + fun: &crate::Function, + module: &crate::Module, + flags: ValidationFlags, + capabilities: super::Capabilities, + ) -> Result<FunctionInfo, WithSpan<FunctionError>> { + let mut info = FunctionInfo { + flags, + available_stages: ShaderStages::all(), + uniformity: Uniformity::new(), + may_kill: false, + sampling_set: crate::FastHashSet::default(), + global_uses: vec![GlobalUse::empty(); module.global_variables.len()].into_boxed_slice(), + expressions: vec![ExpressionInfo::new(); fun.expressions.len()].into_boxed_slice(), + sampling: crate::FastHashSet::default(), + }; + let resolve_context = ResolveContext { + constants: &module.constants, + types: &module.types, + global_vars: &module.global_variables, + local_vars: &fun.local_variables, + functions: &module.functions, + arguments: &fun.arguments, + }; + + for (handle, expr) in fun.expressions.iter() { + if let Err(source) = info.process_expression( + handle, + expr, + &fun.expressions, + &self.functions, + &resolve_context, + capabilities, + ) { + return Err(FunctionError::Expression { handle, source } + .with_span_handle(handle, &fun.expressions)); + } + } + + let uniformity = info.process_block(&fun.body, &self.functions, None, &fun.expressions)?; + info.uniformity = uniformity.result; + info.may_kill = uniformity.exit.contains(ExitFlags::MAY_KILL); + + Ok(info) + } + + pub fn get_entry_point(&self, index: usize) -> &FunctionInfo { + &self.entry_points[index] + } +} + +#[test] +#[cfg(feature = "validate")] +fn uniform_control_flow() { + use crate::{Expression as E, Statement as S}; + + let mut constant_arena = Arena::new(); + let constant = constant_arena.append( + crate::Constant { + name: None, + specialization: None, + inner: crate::ConstantInner::Scalar { + width: 4, + value: crate::ScalarValue::Uint(0), + }, + }, + Default::default(), + ); + let mut type_arena = crate::UniqueArena::new(); + let ty = type_arena.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Vector { + size: crate::VectorSize::Bi, + kind: crate::ScalarKind::Float, + width: 4, + }, + }, + Default::default(), + ); + let mut global_var_arena = Arena::new(); + let non_uniform_global = global_var_arena.append( + crate::GlobalVariable { + name: None, + init: None, + ty, + space: crate::AddressSpace::Handle, + binding: None, + }, + Default::default(), + ); + let uniform_global = global_var_arena.append( + crate::GlobalVariable { + name: None, + init: None, + ty, + binding: None, + space: crate::AddressSpace::Uniform, + }, + Default::default(), + ); + + let mut expressions = Arena::new(); + // checks the uniform control flow + let constant_expr = expressions.append(E::Constant(constant), Default::default()); + // checks the non-uniform control flow + let derivative_expr = expressions.append( + E::Derivative { + axis: crate::DerivativeAxis::X, + expr: constant_expr, + }, + Default::default(), + ); + let emit_range_constant_derivative = expressions.range_from(0); + let non_uniform_global_expr = + expressions.append(E::GlobalVariable(non_uniform_global), Default::default()); + let uniform_global_expr = + expressions.append(E::GlobalVariable(uniform_global), Default::default()); + let emit_range_globals = expressions.range_from(2); + + // checks the QUERY flag + let query_expr = expressions.append(E::ArrayLength(uniform_global_expr), Default::default()); + // checks the transitive WRITE flag + let access_expr = expressions.append( + E::AccessIndex { + base: non_uniform_global_expr, + index: 1, + }, + Default::default(), + ); + let emit_range_query_access_globals = expressions.range_from(2); + + let mut info = FunctionInfo { + flags: ValidationFlags::all(), + available_stages: ShaderStages::all(), + uniformity: Uniformity::new(), + may_kill: false, + sampling_set: crate::FastHashSet::default(), + global_uses: vec![GlobalUse::empty(); global_var_arena.len()].into_boxed_slice(), + expressions: vec![ExpressionInfo::new(); expressions.len()].into_boxed_slice(), + sampling: crate::FastHashSet::default(), + }; + let resolve_context = ResolveContext { + constants: &constant_arena, + types: &type_arena, + global_vars: &global_var_arena, + local_vars: &Arena::new(), + functions: &Arena::new(), + arguments: &[], + }; + for (handle, expression) in expressions.iter() { + info.process_expression( + handle, + expression, + &expressions, + &[], + &resolve_context, + super::Capabilities::empty(), + ) + .unwrap(); + } + assert_eq!(info[non_uniform_global_expr].ref_count, 1); + assert_eq!(info[uniform_global_expr].ref_count, 1); + assert_eq!(info[query_expr].ref_count, 0); + assert_eq!(info[access_expr].ref_count, 0); + assert_eq!(info[non_uniform_global], GlobalUse::empty()); + assert_eq!(info[uniform_global], GlobalUse::QUERY); + + let stmt_emit1 = S::Emit(emit_range_globals.clone()); + let stmt_if_uniform = S::If { + condition: uniform_global_expr, + accept: crate::Block::new(), + reject: vec![ + S::Emit(emit_range_constant_derivative.clone()), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ] + .into(), + }; + assert_eq!( + info.process_block( + &vec![stmt_emit1, stmt_if_uniform].into(), + &[], + None, + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::DERIVATIVE, + }, + exit: ExitFlags::empty(), + }), + ); + assert_eq!(info[constant_expr].ref_count, 2); + assert_eq!(info[uniform_global], GlobalUse::READ | GlobalUse::QUERY); + + let stmt_emit2 = S::Emit(emit_range_globals.clone()); + let stmt_if_non_uniform = S::If { + condition: non_uniform_global_expr, + accept: vec![ + S::Emit(emit_range_constant_derivative), + S::Store { + pointer: constant_expr, + value: derivative_expr, + }, + ] + .into(), + reject: crate::Block::new(), + }; + assert_eq!( + info.process_block( + &vec![stmt_emit2, stmt_if_non_uniform].into(), + &[], + None, + &expressions + ), + Err(FunctionError::NonUniformControlFlow( + UniformityRequirements::DERIVATIVE, + derivative_expr, + UniformityDisruptor::Expression(non_uniform_global_expr) + ) + .with_span()), + ); + assert_eq!(info[derivative_expr].ref_count, 1); + assert_eq!(info[non_uniform_global], GlobalUse::READ); + + let stmt_emit3 = S::Emit(emit_range_globals); + let stmt_return_non_uniform = S::Return { + value: Some(non_uniform_global_expr), + }; + assert_eq!( + info.process_block( + &vec![stmt_emit3, stmt_return_non_uniform].into(), + &[], + Some(UniformityDisruptor::Return), + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: Some(non_uniform_global_expr), + requirements: UniformityRequirements::empty(), + }, + exit: ExitFlags::MAY_RETURN, + }), + ); + assert_eq!(info[non_uniform_global_expr].ref_count, 3); + + // Check that uniformity requirements reach through a pointer + let stmt_emit4 = S::Emit(emit_range_query_access_globals); + let stmt_assign = S::Store { + pointer: access_expr, + value: query_expr, + }; + let stmt_return_pointer = S::Return { + value: Some(access_expr), + }; + let stmt_kill = S::Kill; + assert_eq!( + info.process_block( + &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(), + &[], + Some(UniformityDisruptor::Discard), + &expressions + ), + Ok(FunctionUniformity { + result: Uniformity { + non_uniform_result: Some(non_uniform_global_expr), + requirements: UniformityRequirements::empty(), + }, + exit: ExitFlags::all(), + }), + ); + assert_eq!(info[non_uniform_global], GlobalUse::READ | GlobalUse::WRITE); +} |