diff options
Diffstat (limited to 'third_party/rust/naga/src/valid/function.rs')
-rw-r--r-- | third_party/rust/naga/src/valid/function.rs | 268 |
1 files changed, 258 insertions, 10 deletions
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs index f0ca22cbda..71128fc86d 100644 --- a/third_party/rust/naga/src/valid/function.rs +++ b/third_party/rust/naga/src/valid/function.rs @@ -49,13 +49,26 @@ pub enum AtomicError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] +pub enum SubgroupError { + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle<crate::Expression>), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle<crate::Expression>), + #[error("Support for subgroup operation {0:?} is required")] + UnsupportedOperation(super::SubgroupOperationSet), + #[error("Unknown operation")] + UnknownOperation, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { #[error("Local variable has a type {0:?} that can't be stored in a local variable.")] InvalidType(Handle<crate::Type>), #[error("Initializer doesn't match the variable type")] InitializerType, - #[error("Initializer is not const")] - NonConstInitializer, + #[error("Initializer is not a const or override expression")] + NonConstOrOverrideInitializer, } #[derive(Clone, Debug, thiserror::Error)] @@ -135,6 +148,8 @@ pub enum FunctionError { InvalidRayDescriptor(Handle<crate::Expression>), #[error("Ray Query {0:?} does not have a matching type")] InvalidRayQueryType(Handle<crate::Type>), + #[error("Shader requires capability {0:?}")] + MissingCapability(super::Capabilities), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -155,6 +170,8 @@ pub enum FunctionError { WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>), #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")] WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>), + #[error("Subgroup operation is invalid")] + InvalidSubgroup(#[from] SubgroupError), } bitflags::bitflags! { @@ -399,6 +416,127 @@ impl super::Validator { } Ok(()) } + fn validate_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + let (is_scalar, scalar) = match *argument_inner { + crate::TypeInner::Scalar(scalar) => (true, scalar), + crate::TypeInner::Vector { scalar, .. } => (false, scalar), + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + match (scalar.kind, *op) { + (sk::Bool, sg::All | sg::Any) if is_scalar => {} + (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} + (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {} + + (_, _) => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::CollectiveOperation as co; + match (*collective_op, *op) { + ( + co::Reduce, + sg::All + | sg::Any + | sg::Add + | sg::Mul + | sg::Min + | sg::Max + | sg::And + | sg::Or + | sg::Xor, + ) => {} + (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {} + + (_, _) => { + return Err(SubgroupError::UnknownOperation.with_span().into_other()); + } + }; + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + fn validate_subgroup_gather( + &mut self, + mode: &crate::GatherMode, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar(crate::Scalar::U32) => {} + _ => { + log::error!( + "Subgroup gather index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } + } + } + } + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + if !matches!(*argument_inner, + crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } + if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup gather operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } fn validate_block_impl( &mut self, @@ -613,8 +751,30 @@ impl super::Validator { stages &= super::ShaderStages::FRAGMENT; finished = true; } - S::Barrier(_) => { + S::Barrier(barrier) => { stages &= super::ShaderStages::COMPUTE; + if barrier.contains(crate::Barrier::SUB_GROUP) { + if !self.capabilities.contains( + super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER, + ) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP + | super::Capabilities::SUBGROUP_BARRIER, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BASIC) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BASIC, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + } } S::Store { pointer, value } => { let mut current = pointer; @@ -904,6 +1064,86 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result, predicate } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BALLOT) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BALLOT, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + if let Some(predicate) = predicate { + let predicate_inner = + context.resolve_type(predicate, &self.valid_expression_set)?; + if !matches!( + *predicate_inner, + crate::TypeInner::Scalar(crate::Scalar::BOOL,) + ) { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); + } + } + self.emit_expression(result, context)?; + } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = op.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_operation(op, collective_op, argument, result, context)?; + } + S::SubgroupGather { + ref mode, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = mode.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_gather(mode, argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) @@ -927,7 +1167,7 @@ impl super::Validator { var: &crate::LocalVariable, gctx: crate::proc::GlobalCtx, fun_info: &FunctionInfo, - expression_constness: &crate::proc::ExpressionConstnessTracker, + local_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -945,8 +1185,8 @@ impl super::Validator { return Err(LocalVariableError::InitializerType); } - if !expression_constness.is_const(init) { - return Err(LocalVariableError::NonConstInitializer); + if !local_expr_kind.is_const_or_override(init) { + return Err(LocalVariableError::NonConstOrOverrideInitializer); } } @@ -959,14 +1199,14 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, entry_point: bool, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<FunctionInfo, WithSpan<FunctionError>> { let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; - let expression_constness = - crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); + let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions); for (var_handle, var) in fun.local_variables.iter() { - self.validate_local_var(var, module.to_ctx(), &info, &expression_constness) + self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind) .map_err(|source| { FunctionError::LocalVariable { handle: var_handle, @@ -1032,7 +1272,15 @@ impl super::Validator { self.valid_expression_set.insert(handle.index()); } if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { - match self.validate_expression(handle, expr, fun, module, &info, mod_info) { + match self.validate_expression( + handle, + expr, + fun, + module, + &info, + mod_info, + global_expr_kind, + ) { Ok(stages) => info.available_stages &= stages, Err(source) => { return Err(FunctionError::Expression { handle, source } |