diff options
Diffstat (limited to 'third_party/rust/naga/src/valid')
-rw-r--r-- | third_party/rust/naga/src/valid/analyzer.rs | 53 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/expression.rs | 39 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/function.rs | 268 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/handles.rs | 91 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/interface.rs | 48 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/mod.rs | 228 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/type.rs | 3 |
7 files changed, 672 insertions, 58 deletions
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs index 03fbc4089b..6799e5db27 100644 --- a/third_party/rust/naga/src/valid/analyzer.rs +++ b/third_party/rust/naga/src/valid/analyzer.rs @@ -226,7 +226,7 @@ struct Sampling { sampler: GlobalOrArgument, } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct FunctionInfo { @@ -574,7 +574,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(vector), requirements: UniformityRequirements::empty(), }, - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(), E::Compose { ref components, .. } => { let non_uniform_result = components .iter() @@ -787,6 +787,14 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::SubgroupBallotResult => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::SubgroupOperationResult { .. } => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -827,7 +835,7 @@ impl FunctionInfo { let req = self.expressions[expr.index()].uniformity.requirements; if self .flags - .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY) + .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY) && !req.is_empty() { if let Some(cause) = disruptor { @@ -1029,6 +1037,42 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::SubgroupBallot { + result: _, + predicate, + } => { + if let Some(predicate) = predicate { + let _ = self.add_ref(predicate); + } + FunctionUniformity::new() + } + S::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } + S::SubgroupGather { + mode, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let _ = self.add_ref(index); + } + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); @@ -1047,7 +1091,7 @@ impl ModuleInfo { gctx: crate::proc::GlobalCtx, ) -> Result<(), super::ConstExpressionError> { self.const_expression_types[handle.index()] = - resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?; + resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?; Ok(()) } @@ -1186,6 +1230,7 @@ fn uniform_control_flow() { }; let resolve_context = ResolveContext { constants: &Arena::new(), + overrides: &Arena::new(), types: &type_arena, special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs index 838ecc4e27..525bd28c17 100644 --- a/third_party/rust/naga/src/valid/expression.rs +++ b/third_party/rust/naga/src/valid/expression.rs @@ -90,6 +90,8 @@ pub enum ExpressionError { sampler: bool, has_ref: bool, }, + #[error("Sample offset must be a const-expression")] + InvalidSampleOffsetExprType, #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>), #[error("Depth reference {0:?} is not a scalar float")] @@ -129,9 +131,12 @@ pub enum ExpressionError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstExpressionError { - #[error("The expression is not a constant expression")] - NonConst, + #[error("The expression is not a constant or override expression")] + NonConstOrOverride, + #[error("The expression is not a fully evaluated constant expression")] + NonFullyEvaluatedConst, #[error(transparent)] Compose(#[from] super::ComposeError), #[error("Splatting {0:?} can't be done")] @@ -184,10 +189,15 @@ impl super::Validator { handle: Handle<crate::Expression>, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), ConstExpressionError> { use crate::Expression as E; - match gctx.const_expressions[handle] { + if !global_expr_kind.is_const_or_override(handle) { + return Err(ConstExpressionError::NonConstOrOverride); + } + + match gctx.global_expressions[handle] { E::Literal(literal) => { self.validate_literal(literal)?; } @@ -201,14 +211,19 @@ impl super::Validator { } E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) { crate::TypeInner::Scalar { .. } => {} - _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), + _ => return Err(ConstExpressionError::InvalidSplatType(value)), }, - _ => return Err(super::ConstExpressionError::NonConst), + _ if global_expr_kind.is_const(handle) || !self.allow_overrides => { + return Err(ConstExpressionError::NonFullyEvaluatedConst) + } + // the constant evaluator will report errors about override-expressions + _ => {} } Ok(()) } + #[allow(clippy::too_many_arguments)] pub(super) fn validate_expression( &self, root: Handle<crate::Expression>, @@ -217,6 +232,7 @@ impl super::Validator { module: &crate::Module, info: &FunctionInfo, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<ShaderStages, ExpressionError> { use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti}; @@ -252,9 +268,7 @@ impl super::Validator { return Err(ExpressionError::InvalidIndexType(index)); } } - if dynamic_indexing_restricted - && function.expressions[index].is_dynamic_index(module) - { + if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() { return Err(ExpressionError::IndexMustBeConstant(base)); } @@ -347,7 +361,7 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, @@ -464,6 +478,10 @@ impl super::Validator { // check constant offset if let Some(const_expr) = offset { + if !global_expr_kind.is_const(const_expr) { + return Err(ExpressionError::InvalidSampleOffsetExprType); + } + match *mod_info[const_expr].inner_with(&module.types) { Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {} Ti::Vector { @@ -1623,6 +1641,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, + E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, }; Ok(stages) } @@ -1716,7 +1735,7 @@ fn validate_with_const_expression( use crate::span::Span; let mut module = crate::Module::default(); - module.const_expressions.append(expr, Span::default()); + module.global_expressions.append(expr, Span::default()); let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps); diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs index f0ca22cbda..71128fc86d 100644 --- a/third_party/rust/naga/src/valid/function.rs +++ b/third_party/rust/naga/src/valid/function.rs @@ -49,13 +49,26 @@ pub enum AtomicError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] +pub enum SubgroupError { + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle<crate::Expression>), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle<crate::Expression>), + #[error("Support for subgroup operation {0:?} is required")] + UnsupportedOperation(super::SubgroupOperationSet), + #[error("Unknown operation")] + UnknownOperation, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { #[error("Local variable has a type {0:?} that can't be stored in a local variable.")] InvalidType(Handle<crate::Type>), #[error("Initializer doesn't match the variable type")] InitializerType, - #[error("Initializer is not const")] - NonConstInitializer, + #[error("Initializer is not a const or override expression")] + NonConstOrOverrideInitializer, } #[derive(Clone, Debug, thiserror::Error)] @@ -135,6 +148,8 @@ pub enum FunctionError { InvalidRayDescriptor(Handle<crate::Expression>), #[error("Ray Query {0:?} does not have a matching type")] InvalidRayQueryType(Handle<crate::Type>), + #[error("Shader requires capability {0:?}")] + MissingCapability(super::Capabilities), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -155,6 +170,8 @@ pub enum FunctionError { WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>), #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")] WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>), + #[error("Subgroup operation is invalid")] + InvalidSubgroup(#[from] SubgroupError), } bitflags::bitflags! { @@ -399,6 +416,127 @@ impl super::Validator { } Ok(()) } + fn validate_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + let (is_scalar, scalar) = match *argument_inner { + crate::TypeInner::Scalar(scalar) => (true, scalar), + crate::TypeInner::Vector { scalar, .. } => (false, scalar), + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + match (scalar.kind, *op) { + (sk::Bool, sg::All | sg::Any) if is_scalar => {} + (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} + (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {} + + (_, _) => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::CollectiveOperation as co; + match (*collective_op, *op) { + ( + co::Reduce, + sg::All + | sg::Any + | sg::Add + | sg::Mul + | sg::Min + | sg::Max + | sg::And + | sg::Or + | sg::Xor, + ) => {} + (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {} + + (_, _) => { + return Err(SubgroupError::UnknownOperation.with_span().into_other()); + } + }; + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + fn validate_subgroup_gather( + &mut self, + mode: &crate::GatherMode, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar(crate::Scalar::U32) => {} + _ => { + log::error!( + "Subgroup gather index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } + } + } + } + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + if !matches!(*argument_inner, + crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } + if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup gather operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } fn validate_block_impl( &mut self, @@ -613,8 +751,30 @@ impl super::Validator { stages &= super::ShaderStages::FRAGMENT; finished = true; } - S::Barrier(_) => { + S::Barrier(barrier) => { stages &= super::ShaderStages::COMPUTE; + if barrier.contains(crate::Barrier::SUB_GROUP) { + if !self.capabilities.contains( + super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER, + ) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP + | super::Capabilities::SUBGROUP_BARRIER, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BASIC) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BASIC, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + } } S::Store { pointer, value } => { let mut current = pointer; @@ -904,6 +1064,86 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result, predicate } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BALLOT) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BALLOT, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + if let Some(predicate) = predicate { + let predicate_inner = + context.resolve_type(predicate, &self.valid_expression_set)?; + if !matches!( + *predicate_inner, + crate::TypeInner::Scalar(crate::Scalar::BOOL,) + ) { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); + } + } + self.emit_expression(result, context)?; + } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = op.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_operation(op, collective_op, argument, result, context)?; + } + S::SubgroupGather { + ref mode, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = mode.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_gather(mode, argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) @@ -927,7 +1167,7 @@ impl super::Validator { var: &crate::LocalVariable, gctx: crate::proc::GlobalCtx, fun_info: &FunctionInfo, - expression_constness: &crate::proc::ExpressionConstnessTracker, + local_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -945,8 +1185,8 @@ impl super::Validator { return Err(LocalVariableError::InitializerType); } - if !expression_constness.is_const(init) { - return Err(LocalVariableError::NonConstInitializer); + if !local_expr_kind.is_const_or_override(init) { + return Err(LocalVariableError::NonConstOrOverrideInitializer); } } @@ -959,14 +1199,14 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, entry_point: bool, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<FunctionInfo, WithSpan<FunctionError>> { let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; - let expression_constness = - crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); + let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions); for (var_handle, var) in fun.local_variables.iter() { - self.validate_local_var(var, module.to_ctx(), &info, &expression_constness) + self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind) .map_err(|source| { FunctionError::LocalVariable { handle: var_handle, @@ -1032,7 +1272,15 @@ impl super::Validator { self.valid_expression_set.insert(handle.index()); } if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { - match self.validate_expression(handle, expr, fun, module, &info, mod_info) { + match self.validate_expression( + handle, + expr, + fun, + module, + &info, + mod_info, + global_expr_kind, + ) { Ok(stages) => info.available_stages &= stages, Err(source) => { return Err(FunctionError::Expression { handle, source } diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs index e482f293bb..8f78204055 100644 --- a/third_party/rust/naga/src/valid/handles.rs +++ b/third_party/rust/naga/src/valid/handles.rs @@ -31,12 +31,13 @@ impl super::Validator { pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { let &crate::Module { ref constants, + ref overrides, ref entry_points, ref functions, ref global_variables, ref types, ref special_types, - ref const_expressions, + ref global_expressions, } = module; // NOTE: Types being first is important. All other forms of validation depend on this. @@ -67,23 +68,31 @@ impl super::Validator { } } - for handle_and_expr in const_expressions.iter() { - Self::validate_const_expression_handles(handle_and_expr, constants, types)?; + for handle_and_expr in global_expressions.iter() { + Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); let validate_const_expr = - |handle| Self::validate_expression_handle(handle, const_expressions); + |handle| Self::validate_expression_handle(handle, global_expressions); for (_handle, constant) in constants.iter() { - let &crate::Constant { + let &crate::Constant { name: _, ty, init } = constant; + validate_type(ty)?; + validate_const_expr(init)?; + } + + for (_handle, override_) in overrides.iter() { + let &crate::Override { name: _, - r#override: _, + id: _, ty, init, - } = constant; + } = override_; validate_type(ty)?; - validate_const_expr(init)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; + } } for (_handle, global_variable) in global_variables.iter() { @@ -140,7 +149,8 @@ impl super::Validator { Self::validate_expression_handles( handle_and_expr, constants, - const_expressions, + overrides, + global_expressions, types, local_variables, global_variables, @@ -186,6 +196,13 @@ impl super::Validator { handle.check_valid_for(constants).map(|_| ()) } + fn validate_override_handle( + handle: Handle<crate::Override>, + overrides: &Arena<crate::Override>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(overrides).map(|_| ()) + } + fn validate_expression_handle( handle: Handle<crate::Expression>, expressions: &Arena<crate::Expression>, @@ -203,9 +220,11 @@ impl super::Validator { fn validate_const_expression_handles( (handle, expression): (Handle<crate::Expression>, &crate::Expression), constants: &Arena<crate::Constant>, + overrides: &Arena<crate::Override>, types: &UniqueArena<crate::Type>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -214,6 +233,12 @@ impl super::Validator { validate_constant(constant)?; handle.check_dep(constants[constant].init)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + if let Some(init) = overrides[override_].init { + handle.check_dep(init)?; + } + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -230,7 +255,8 @@ impl super::Validator { fn validate_expression_handles( (handle, expression): (Handle<crate::Expression>, &crate::Expression), constants: &Arena<crate::Constant>, - const_expressions: &Arena<crate::Expression>, + overrides: &Arena<crate::Override>, + global_expressions: &Arena<crate::Expression>, types: &UniqueArena<crate::Type>, local_variables: &Arena<crate::LocalVariable>, global_variables: &Arena<crate::GlobalVariable>, @@ -239,8 +265,9 @@ impl super::Validator { current_function: Option<Handle<crate::Function>>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_const_expr = - |handle| Self::validate_expression_handle(handle, const_expressions); + |handle| Self::validate_expression_handle(handle, global_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -260,6 +287,9 @@ impl super::Validator { crate::Expression::Constant(constant) => { validate_constant(constant)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -390,6 +420,8 @@ impl super::Validator { } crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -535,6 +567,38 @@ impl super::Validator { } Ok(()) } + crate::Statement::SubgroupBallot { result, predicate } => { + validate_expr_opt(predicate)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + validate_expr(argument)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, + } + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill @@ -562,6 +626,7 @@ impl From<BadRangeError> for ValidationError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum InvalidHandleError { #[error(transparent)] BadHandle(#[from] BadHandle), @@ -572,6 +637,7 @@ pub enum InvalidHandleError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] #[error( "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \ which has not been processed yet" @@ -664,6 +730,7 @@ fn constant_deps() { let mut const_exprs = Arena::new(); let mut fun_exprs = Arena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let i32_handle = types.insert( Type { @@ -679,7 +746,6 @@ fn constant_deps() { let self_referential_const = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_handle, init: fun_expr, }, @@ -692,6 +758,7 @@ fn constant_deps() { assert!(super::Validator::validate_const_expression_handles( handle_and_expr, &constants, + &overrides, &types, ) .is_err()); diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs index 84c8b09ddb..db890ddbac 100644 --- a/third_party/rust/naga/src/valid/interface.rs +++ b/third_party/rust/naga/src/valid/interface.rs @@ -10,6 +10,7 @@ use bit_set::BitSet; const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum GlobalVariableError { #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), @@ -30,6 +31,8 @@ pub enum GlobalVariableError { Handle<crate::Type>, #[source] Disalignment, ), + #[error("Initializer must be an override-expression")] + InitializerExprType, #[error("Initializer doesn't match the variable type")] InitializerType, #[error("Initializer can't be used with address space {0:?}")] @@ -39,6 +42,7 @@ pub enum GlobalVariableError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum VaryingError { #[error("The type {0:?} does not match the varying")] InvalidType(Handle<crate::Type>), @@ -73,9 +77,12 @@ pub enum VaryingError { location: u32, attribute: &'static str, }, + #[error("Workgroup size is multi dimensional, @builtin(subgroup_id) and @builtin(subgroup_invocation_id) are not supported.")] + InvalidMultiDimensionalSubgroupBuiltIn, } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum EntryPointError { #[error("Multiple conflicting entry points")] Conflict, @@ -135,6 +142,7 @@ struct VaryingContext<'a> { impl VaryingContext<'_> { fn validate_impl( &mut self, + ep: &crate::EntryPoint, ty: Handle<crate::Type>, binding: &crate::Binding, ) -> Result<(), VaryingError> { @@ -162,12 +170,24 @@ impl VaryingContext<'_> { Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX, Bi::ViewIndex => Capabilities::MULTIVIEW, Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING, + Bi::NumSubgroups + | Bi::SubgroupId + | Bi::SubgroupSize + | Bi::SubgroupInvocationId => Capabilities::SUBGROUP, _ => Capabilities::empty(), }; if !self.capabilities.contains(required) { return Err(VaryingError::UnsupportedCapability(required)); } + if matches!( + built_in, + crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId + ) && ep.workgroup_size[1..].iter().any(|&s| s > 1) + { + return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn); + } + let (visible, type_good) = match built_in { Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( self.stage == St::Vertex && !self.output, @@ -249,6 +269,17 @@ impl VaryingContext<'_> { scalar: crate::Scalar::U32, }, ), + Bi::NumSubgroups | Bi::SubgroupId => ( + self.stage == St::Compute && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::SubgroupSize | Bi::SubgroupInvocationId => ( + match self.stage { + St::Compute | St::Fragment => !self.output, + St::Vertex => false, + }, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), }; if !visible { @@ -349,13 +380,14 @@ impl VaryingContext<'_> { fn validate( &mut self, + ep: &crate::EntryPoint, ty: Handle<crate::Type>, binding: Option<&crate::Binding>, ) -> Result<(), WithSpan<VaryingError>> { let span_context = self.types.get_span_context(ty); match binding { Some(binding) => self - .validate_impl(ty, binding) + .validate_impl(ep, ty, binding) .map_err(|e| e.with_span_context(span_context)), None => { match self.types[ty].inner { @@ -372,7 +404,7 @@ impl VaryingContext<'_> { } } Some(ref binding) => self - .validate_impl(member.ty, binding) + .validate_impl(ep, member.ty, binding) .map_err(|e| e.with_span_context(span_context))?, } } @@ -395,6 +427,7 @@ impl super::Validator { var: &crate::GlobalVariable, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), GlobalVariableError> { use super::TypeFlags; @@ -523,6 +556,10 @@ impl super::Validator { } } + if !global_expr_kind.is_const_or_override(init) { + return Err(GlobalVariableError::InitializerExprType); + } + let decl_ty = &gctx.types[var.ty].inner; let init_ty = mod_info[init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -538,6 +575,7 @@ impl super::Validator { ep: &crate::EntryPoint, module: &crate::Module, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<FunctionInfo, WithSpan<EntryPointError>> { if ep.early_depth_test.is_some() { let required = Capabilities::EARLY_DEPTH_TEST; @@ -566,7 +604,7 @@ impl super::Validator { } let mut info = self - .validate_function(&ep.function, module, mod_info, true) + .validate_function(&ep.function, module, mod_info, true, global_expr_kind) .map_err(WithSpan::into_other)?; { @@ -598,7 +636,7 @@ impl super::Validator { capabilities: self.capabilities, flags: self.flags, }; - ctx.validate(fa.ty, fa.binding.as_ref()) + ctx.validate(ep, fa.ty, fa.binding.as_ref()) .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; } @@ -616,7 +654,7 @@ impl super::Validator { capabilities: self.capabilities, flags: self.flags, }; - ctx.validate(fr.ty, fr.binding.as_ref()) + ctx.validate(ep, fr.ty, fr.binding.as_ref()) .map_err_inner(|e| EntryPointError::Result(e).with_span())?; if ctx.second_blend_source { // Only the first location may be used when dual source blending diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs index 5459434f33..a0057f39ac 100644 --- a/third_party/rust/naga/src/valid/mod.rs +++ b/third_party/rust/naga/src/valid/mod.rs @@ -12,7 +12,7 @@ mod r#type; use crate::{ arena::Handle, - proc::{LayoutError, Layouter, TypeResolution}, + proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution}, FastHashSet, }; use bit_set::BitSet; @@ -77,7 +77,7 @@ bitflags::bitflags! { #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[derive(Clone, Copy, Debug, Eq, PartialEq)] - pub struct Capabilities: u16 { + pub struct Capabilities: u32 { /// Support for [`AddressSpace:PushConstant`]. const PUSH_CONSTANT = 0x1; /// Float values with width = 8. @@ -110,6 +110,10 @@ bitflags::bitflags! { const CUBE_ARRAY_TEXTURES = 0x4000; /// Support for 64-bit signed and unsigned integers. const SHADER_INT64 = 0x8000; + /// Support for subgroup operations. + const SUBGROUP = 0x10000; + /// Support for subgroup barriers. + const SUBGROUP_BARRIER = 0x20000; } } @@ -120,6 +124,57 @@ impl Default for Capabilities { } bitflags::bitflags! { + /// Supported subgroup operations + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct SubgroupOperationSet: u8 { + /// Elect, Barrier + const BASIC = 1 << 0; + /// Any, All + const VOTE = 1 << 1; + /// reductions, scans + const ARITHMETIC = 1 << 2; + /// ballot, broadcast + const BALLOT = 1 << 3; + /// shuffle, shuffle xor + const SHUFFLE = 1 << 4; + /// shuffle up, down + const SHUFFLE_RELATIVE = 1 << 5; + // We don't support these operations yet + // /// Clustered + // const CLUSTERED = 1 << 6; + // /// Quad supported + // const QUAD_FRAGMENT_COMPUTE = 1 << 7; + // /// Quad supported in all stages + // const QUAD_ALL_STAGES = 1 << 8; + } +} + +impl super::SubgroupOperation { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::All | Self::Any => S::VOTE, + Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => { + S::ARITHMETIC + } + } + } +} + +impl super::GatherMode { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, + Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, + Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + } + } +} + +bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] @@ -131,7 +186,7 @@ bitflags::bitflags! { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { @@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo { pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, + subgroup_stages: ShaderStages, + subgroup_operations: SubgroupOperationSet, types: Vec<r#type::TypeInfo>, layouter: Layouter, location_mask: BitSet, @@ -174,10 +231,15 @@ pub struct Validator { switch_values: FastHashSet<crate::SwitchValue>, valid_expression_list: Vec<Handle<crate::Expression>>, valid_expression_set: BitSet, + override_ids: FastHashSet<u16>, + allow_overrides: bool, } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantError { + #[error("Initializer must be a const-expression")] + InitializerExprType, #[error("The type doesn't match the constant")] InvalidType, #[error("The type is not constructible")] @@ -185,6 +247,26 @@ pub enum ConstantError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum OverrideError { + #[error("Override name and ID are missing")] + MissingNameAndID, + #[error("Override ID must be unique")] + DuplicateID, + #[error("Initializer must be a const-expression or override-expression")] + InitializerExprType, + #[error("The type doesn't match the override")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, + #[error("The type is not a scalar")] + TypeNotScalar, + #[error("Override declarations are not allowed")] + NotAllowed, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ValidationError { #[error(transparent)] InvalidHandle(#[from] InvalidHandleError), @@ -207,6 +289,12 @@ pub enum ValidationError { name: String, source: ConstantError, }, + #[error("Override {handle:?} '{name}' is invalid")] + Override { + handle: Handle<crate::Override>, + name: String, + source: OverrideError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle<crate::GlobalVariable>, @@ -286,6 +374,8 @@ impl Validator { Validator { flags, capabilities, + subgroup_stages: ShaderStages::empty(), + subgroup_operations: SubgroupOperationSet::empty(), types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), @@ -293,9 +383,21 @@ impl Validator { switch_values: FastHashSet::default(), valid_expression_list: Vec::new(), valid_expression_set: BitSet::new(), + override_ids: FastHashSet::default(), + allow_overrides: true, } } + pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self { + self.subgroup_stages = stages; + self + } + + pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self { + self.subgroup_operations = operations; + self + } + /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); @@ -305,6 +407,7 @@ impl Validator { self.switch_values.clear(); self.valid_expression_list.clear(); self.valid_expression_set.clear(); + self.override_ids.clear(); } fn validate_constant( @@ -312,6 +415,7 @@ impl Validator { handle: Handle<crate::Constant>, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &ExpressionKindTracker, ) -> Result<(), ConstantError> { let con = &gctx.constants[handle]; @@ -320,6 +424,10 @@ impl Validator { return Err(ConstantError::NonConstructibleType); } + if !global_expr_kind.is_const(con.init) { + return Err(ConstantError::InitializerExprType); + } + let decl_ty = &gctx.types[con.ty].inner; let init_ty = mod_info[con.init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -329,11 +437,80 @@ impl Validator { Ok(()) } + fn validate_override( + &mut self, + handle: Handle<crate::Override>, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), OverrideError> { + if !self.allow_overrides { + return Err(OverrideError::NotAllowed); + } + + let o = &gctx.overrides[handle]; + + if o.name.is_none() && o.id.is_none() { + return Err(OverrideError::MissingNameAndID); + } + + if let Some(id) = o.id { + if !self.override_ids.insert(id) { + return Err(OverrideError::DuplicateID); + } + } + + let type_info = &self.types[o.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(OverrideError::NonConstructibleType); + } + + let decl_ty = &gctx.types[o.ty].inner; + match decl_ty { + &crate::TypeInner::Scalar(scalar) => match scalar { + crate::Scalar::BOOL + | crate::Scalar::I32 + | crate::Scalar::U32 + | crate::Scalar::F32 + | crate::Scalar::F64 => {} + _ => return Err(OverrideError::TypeNotScalar), + }, + _ => return Err(OverrideError::TypeNotScalar), + } + + if let Some(init) = o.init { + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(OverrideError::InvalidType); + } + } + + Ok(()) + } + /// Check the given module to be valid. pub fn validate( &mut self, module: &crate::Module, ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.allow_overrides = true; + self.validate_impl(module) + } + + /// Check the given module to be valid. + /// + /// With the additional restriction that overrides are not present. + pub fn validate_no_overrides( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.allow_overrides = false; + self.validate_impl(module) + } + + fn validate_impl( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { self.reset(); self.reset_types(module.types.len()); @@ -354,7 +531,7 @@ impl Validator { type_flags: Vec::with_capacity(module.types.len()), functions: Vec::with_capacity(module.functions.len()), entry_points: Vec::with_capacity(module.entry_points.len()), - const_expression_types: vec![placeholder; module.const_expressions.len()] + const_expression_types: vec![placeholder; module.global_expressions.len()] .into_boxed_slice(), }; @@ -376,27 +553,34 @@ impl Validator { { let t = crate::Arena::new(); let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]); - for (handle, _) in module.const_expressions.iter() { + for (handle, _) in module.global_expressions.iter() { mod_info .process_const_expression(handle, &resolve_context, module.to_ctx()) .map_err(|source| { ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) + .with_span_handle(handle, &module.global_expressions) })? } } + let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions); + if self.flags.contains(ValidationFlags::CONSTANTS) { - for (handle, _) in module.const_expressions.iter() { - self.validate_const_expression(handle, module.to_ctx(), &mod_info) - .map_err(|source| { - ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) - })? + for (handle, _) in module.global_expressions.iter() { + self.validate_const_expression( + handle, + module.to_ctx(), + &mod_info, + &global_expr_kind, + ) + .map_err(|source| { + ValidationError::ConstExpression { handle, source } + .with_span_handle(handle, &module.global_expressions) + })? } for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, module.to_ctx(), &mod_info) + self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::Constant { handle, @@ -406,10 +590,22 @@ impl Validator { .with_span_handle(handle, &module.constants) })? } + + for (handle, override_) in module.overrides.iter() { + self.validate_override(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Override { + handle, + name: override_.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })? + } } for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, module.to_ctx(), &mod_info) + self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::GlobalVariable { handle: var_handle, @@ -421,7 +617,7 @@ impl Validator { } for (handle, fun) in module.functions.iter() { - match self.validate_function(fun, module, &mod_info, false) { + match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { @@ -447,7 +643,7 @@ impl Validator { .with_span()); // TODO: keep some EP span information? } - match self.validate_entry_point(ep, module, &mod_info) { + match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs index b8eb618ed4..f5b9856074 100644 --- a/third_party/rust/naga/src/valid/type.rs +++ b/third_party/rust/naga/src/valid/type.rs @@ -63,6 +63,7 @@ bitflags::bitflags! { } #[derive(Clone, Copy, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum Disalignment { #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")] ArrayStride { stride: u32, alignment: Alignment }, @@ -87,6 +88,7 @@ pub enum Disalignment { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum TypeError { #[error("Capability {0:?} is required")] MissingCapability(Capabilities), @@ -326,7 +328,6 @@ impl super::Validator { TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::COPY - | TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE | shareable, |