summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/valid/function.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/valid/function.rs')
-rw-r--r--third_party/rust/naga/src/valid/function.rs268
1 files changed, 258 insertions, 10 deletions
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
index f0ca22cbda..71128fc86d 100644
--- a/third_party/rust/naga/src/valid/function.rs
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -49,13 +49,26 @@ pub enum AtomicError {
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
+pub enum SubgroupError {
+ #[error("Operand {0:?} has invalid type.")]
+ InvalidOperand(Handle<crate::Expression>),
+ #[error("Result type for {0:?} doesn't match the statement")]
+ ResultTypeMismatch(Handle<crate::Expression>),
+ #[error("Support for subgroup operation {0:?} is required")]
+ UnsupportedOperation(super::SubgroupOperationSet),
+ #[error("Unknown operation")]
+ UnknownOperation,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
#[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
InvalidType(Handle<crate::Type>),
#[error("Initializer doesn't match the variable type")]
InitializerType,
- #[error("Initializer is not const")]
- NonConstInitializer,
+ #[error("Initializer is not a const or override expression")]
+ NonConstOrOverrideInitializer,
}
#[derive(Clone, Debug, thiserror::Error)]
@@ -135,6 +148,8 @@ pub enum FunctionError {
InvalidRayDescriptor(Handle<crate::Expression>),
#[error("Ray Query {0:?} does not have a matching type")]
InvalidRayQueryType(Handle<crate::Type>),
+ #[error("Shader requires capability {0:?}")]
+ MissingCapability(super::Capabilities),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
@@ -155,6 +170,8 @@ pub enum FunctionError {
WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
#[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
+ #[error("Subgroup operation is invalid")]
+ InvalidSubgroup(#[from] SubgroupError),
}
bitflags::bitflags! {
@@ -399,6 +416,127 @@ impl super::Validator {
}
Ok(())
}
+ fn validate_subgroup_operation(
+ &mut self,
+ op: &crate::SubgroupOperation,
+ collective_op: &crate::CollectiveOperation,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
+
+ let (is_scalar, scalar) = match *argument_inner {
+ crate::TypeInner::Scalar(scalar) => (true, scalar),
+ crate::TypeInner::Vector { scalar, .. } => (false, scalar),
+ _ => {
+ log::error!("Subgroup operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+ };
+
+ use crate::ScalarKind as sk;
+ use crate::SubgroupOperation as sg;
+ match (scalar.kind, *op) {
+ (sk::Bool, sg::All | sg::Any) if is_scalar => {}
+ (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
+ (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
+
+ (_, _) => {
+ log::error!("Subgroup operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+ };
+
+ use crate::CollectiveOperation as co;
+ match (*collective_op, *op) {
+ (
+ co::Reduce,
+ sg::All
+ | sg::Any
+ | sg::Add
+ | sg::Mul
+ | sg::Min
+ | sg::Max
+ | sg::And
+ | sg::Or
+ | sg::Xor,
+ ) => {}
+ (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
+
+ (_, _) => {
+ return Err(SubgroupError::UnknownOperation.with_span().into_other());
+ }
+ };
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::SubgroupOperationResult { ty }
+ if { &context.types[ty].inner == argument_inner } => {}
+ _ => {
+ return Err(SubgroupError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
+ fn validate_subgroup_gather(
+ &mut self,
+ mode: &crate::GatherMode,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
+ match *index_ty {
+ crate::TypeInner::Scalar(crate::Scalar::U32) => {}
+ _ => {
+ log::error!(
+ "Subgroup gather index type {:?}, expected unsigned int",
+ index_ty
+ );
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(index, context.expressions)
+ .into_other());
+ }
+ }
+ }
+ }
+ let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
+ if !matches!(*argument_inner,
+ crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
+ if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
+ ) {
+ log::error!("Subgroup gather operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::SubgroupOperationResult { ty }
+ if { &context.types[ty].inner == argument_inner } => {}
+ _ => {
+ return Err(SubgroupError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
fn validate_block_impl(
&mut self,
@@ -613,8 +751,30 @@ impl super::Validator {
stages &= super::ShaderStages::FRAGMENT;
finished = true;
}
- S::Barrier(_) => {
+ S::Barrier(barrier) => {
stages &= super::ShaderStages::COMPUTE;
+ if barrier.contains(crate::Barrier::SUB_GROUP) {
+ if !self.capabilities.contains(
+ super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
+ ) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP
+ | super::Capabilities::SUBGROUP_BARRIER,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ if !self
+ .subgroup_operations
+ .contains(super::SubgroupOperationSet::BASIC)
+ {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(
+ super::SubgroupOperationSet::BASIC,
+ ),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ }
}
S::Store { pointer, value } => {
let mut current = pointer;
@@ -904,6 +1064,86 @@ impl super::Validator {
crate::RayQueryFunction::Terminate => {}
}
}
+ S::SubgroupBallot { result, predicate } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ if !self
+ .subgroup_operations
+ .contains(super::SubgroupOperationSet::BALLOT)
+ {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(
+ super::SubgroupOperationSet::BALLOT,
+ ),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ if let Some(predicate) = predicate {
+ let predicate_inner =
+ context.resolve_type(predicate, &self.valid_expression_set)?;
+ if !matches!(
+ *predicate_inner,
+ crate::TypeInner::Scalar(crate::Scalar::BOOL,)
+ ) {
+ log::error!(
+ "Subgroup ballot predicate type {:?} expected bool",
+ predicate_inner
+ );
+ return Err(SubgroupError::InvalidOperand(predicate)
+ .with_span_handle(predicate, context.expressions)
+ .into_other());
+ }
+ }
+ self.emit_expression(result, context)?;
+ }
+ S::SubgroupCollectiveOperation {
+ ref op,
+ ref collective_op,
+ argument,
+ result,
+ } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ let operation = op.required_operations();
+ if !self.subgroup_operations.contains(operation) {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(operation),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
+ }
+ S::SubgroupGather {
+ ref mode,
+ argument,
+ result,
+ } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ let operation = mode.required_operations();
+ if !self.subgroup_operations.contains(operation) {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(operation),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ self.validate_subgroup_gather(mode, argument, result, context)?;
+ }
}
}
Ok(BlockInfo { stages, finished })
@@ -927,7 +1167,7 @@ impl super::Validator {
var: &crate::LocalVariable,
gctx: crate::proc::GlobalCtx,
fun_info: &FunctionInfo,
- expression_constness: &crate::proc::ExpressionConstnessTracker,
+ local_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<(), LocalVariableError> {
log::debug!("var {:?}", var);
let type_info = self
@@ -945,8 +1185,8 @@ impl super::Validator {
return Err(LocalVariableError::InitializerType);
}
- if !expression_constness.is_const(init) {
- return Err(LocalVariableError::NonConstInitializer);
+ if !local_expr_kind.is_const_or_override(init) {
+ return Err(LocalVariableError::NonConstOrOverrideInitializer);
}
}
@@ -959,14 +1199,14 @@ impl super::Validator {
module: &crate::Module,
mod_info: &ModuleInfo,
entry_point: bool,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<FunctionInfo, WithSpan<FunctionError>> {
let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
- let expression_constness =
- crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
+ let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
for (var_handle, var) in fun.local_variables.iter() {
- self.validate_local_var(var, module.to_ctx(), &info, &expression_constness)
+ self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
.map_err(|source| {
FunctionError::LocalVariable {
handle: var_handle,
@@ -1032,7 +1272,15 @@ impl super::Validator {
self.valid_expression_set.insert(handle.index());
}
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
- match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
+ match self.validate_expression(
+ handle,
+ expr,
+ fun,
+ module,
+ &info,
+ mod_info,
+ global_expr_kind,
+ ) {
Ok(stages) => info.available_stages &= stages,
Err(source) => {
return Err(FunctionError::Expression { handle, source }