summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/valid
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:43:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-06-12 05:43:14 +0000
commit8dd16259287f58f9273002717ec4d27e97127719 (patch)
tree3863e62a53829a84037444beab3abd4ed9dfc7d0 /third_party/rust/naga/src/valid
parentReleasing progress-linux version 126.0.1-1~progress7.99u1. (diff)
downloadfirefox-8dd16259287f58f9273002717ec4d27e97127719.tar.xz
firefox-8dd16259287f58f9273002717ec4d27e97127719.zip
Merging upstream version 127.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/valid')
-rw-r--r--third_party/rust/naga/src/valid/analyzer.rs53
-rw-r--r--third_party/rust/naga/src/valid/expression.rs39
-rw-r--r--third_party/rust/naga/src/valid/function.rs268
-rw-r--r--third_party/rust/naga/src/valid/handles.rs91
-rw-r--r--third_party/rust/naga/src/valid/interface.rs48
-rw-r--r--third_party/rust/naga/src/valid/mod.rs228
-rw-r--r--third_party/rust/naga/src/valid/type.rs3
7 files changed, 672 insertions, 58 deletions
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs
index 03fbc4089b..6799e5db27 100644
--- a/third_party/rust/naga/src/valid/analyzer.rs
+++ b/third_party/rust/naga/src/valid/analyzer.rs
@@ -226,7 +226,7 @@ struct Sampling {
sampler: GlobalOrArgument,
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
@@ -574,7 +574,7 @@ impl FunctionInfo {
non_uniform_result: self.add_ref(vector),
requirements: UniformityRequirements::empty(),
},
- E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(),
+ E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
E::Compose { ref components, .. } => {
let non_uniform_result = components
.iter()
@@ -787,6 +787,14 @@ impl FunctionInfo {
non_uniform_result: self.add_ref(query),
requirements: UniformityRequirements::empty(),
},
+ E::SubgroupBallotResult => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::SubgroupOperationResult { .. } => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
};
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
@@ -827,7 +835,7 @@ impl FunctionInfo {
let req = self.expressions[expr.index()].uniformity.requirements;
if self
.flags
- .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
+ .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
&& !req.is_empty()
{
if let Some(cause) = disruptor {
@@ -1029,6 +1037,42 @@ impl FunctionInfo {
}
FunctionUniformity::new()
}
+ S::SubgroupBallot {
+ result: _,
+ predicate,
+ } => {
+ if let Some(predicate) = predicate {
+ let _ = self.add_ref(predicate);
+ }
+ FunctionUniformity::new()
+ }
+ S::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: _,
+ argument,
+ result: _,
+ } => {
+ let _ = self.add_ref(argument);
+ FunctionUniformity::new()
+ }
+ S::SubgroupGather {
+ mode,
+ argument,
+ result: _,
+ } => {
+ let _ = self.add_ref(argument);
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ let _ = self.add_ref(index);
+ }
+ }
+ FunctionUniformity::new()
+ }
};
disruptor = disruptor.or(uniformity.exit_disruptor());
@@ -1047,7 +1091,7 @@ impl ModuleInfo {
gctx: crate::proc::GlobalCtx,
) -> Result<(), super::ConstExpressionError> {
self.const_expression_types[handle.index()] =
- resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?;
+ resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
Ok(())
}
@@ -1186,6 +1230,7 @@ fn uniform_control_flow() {
};
let resolve_context = ResolveContext {
constants: &Arena::new(),
+ overrides: &Arena::new(),
types: &type_arena,
special_types: &crate::SpecialTypes::default(),
global_vars: &global_var_arena,
diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs
index 838ecc4e27..525bd28c17 100644
--- a/third_party/rust/naga/src/valid/expression.rs
+++ b/third_party/rust/naga/src/valid/expression.rs
@@ -90,6 +90,8 @@ pub enum ExpressionError {
sampler: bool,
has_ref: bool,
},
+ #[error("Sample offset must be a const-expression")]
+ InvalidSampleOffsetExprType,
#[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
#[error("Depth reference {0:?} is not a scalar float")]
@@ -129,9 +131,12 @@ pub enum ExpressionError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ConstExpressionError {
- #[error("The expression is not a constant expression")]
- NonConst,
+ #[error("The expression is not a constant or override expression")]
+ NonConstOrOverride,
+ #[error("The expression is not a fully evaluated constant expression")]
+ NonFullyEvaluatedConst,
#[error(transparent)]
Compose(#[from] super::ComposeError),
#[error("Splatting {0:?} can't be done")]
@@ -184,10 +189,15 @@ impl super::Validator {
handle: Handle<crate::Expression>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<(), ConstExpressionError> {
use crate::Expression as E;
- match gctx.const_expressions[handle] {
+ if !global_expr_kind.is_const_or_override(handle) {
+ return Err(ConstExpressionError::NonConstOrOverride);
+ }
+
+ match gctx.global_expressions[handle] {
E::Literal(literal) => {
self.validate_literal(literal)?;
}
@@ -201,14 +211,19 @@ impl super::Validator {
}
E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
crate::TypeInner::Scalar { .. } => {}
- _ => return Err(super::ConstExpressionError::InvalidSplatType(value)),
+ _ => return Err(ConstExpressionError::InvalidSplatType(value)),
},
- _ => return Err(super::ConstExpressionError::NonConst),
+ _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
+ return Err(ConstExpressionError::NonFullyEvaluatedConst)
+ }
+ // the constant evaluator will report errors about override-expressions
+ _ => {}
}
Ok(())
}
+ #[allow(clippy::too_many_arguments)]
pub(super) fn validate_expression(
&self,
root: Handle<crate::Expression>,
@@ -217,6 +232,7 @@ impl super::Validator {
module: &crate::Module,
info: &FunctionInfo,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<ShaderStages, ExpressionError> {
use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
@@ -252,9 +268,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidIndexType(index));
}
}
- if dynamic_indexing_restricted
- && function.expressions[index].is_dynamic_index(module)
- {
+ if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() {
return Err(ExpressionError::IndexMustBeConstant(base));
}
@@ -347,7 +361,7 @@ impl super::Validator {
self.validate_literal(literal)?;
ShaderStages::all()
}
- E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(),
+ E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
E::Compose { ref components, ty } => {
validate_compose(
ty,
@@ -464,6 +478,10 @@ impl super::Validator {
// check constant offset
if let Some(const_expr) = offset {
+ if !global_expr_kind.is_const(const_expr) {
+ return Err(ExpressionError::InvalidSampleOffsetExprType);
+ }
+
match *mod_info[const_expr].inner_with(&module.types) {
Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
Ti::Vector {
@@ -1623,6 +1641,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
+ E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
};
Ok(stages)
}
@@ -1716,7 +1735,7 @@ fn validate_with_const_expression(
use crate::span::Span;
let mut module = crate::Module::default();
- module.const_expressions.append(expr, Span::default());
+ module.global_expressions.append(expr, Span::default());
let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
index f0ca22cbda..71128fc86d 100644
--- a/third_party/rust/naga/src/valid/function.rs
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -49,13 +49,26 @@ pub enum AtomicError {
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
+pub enum SubgroupError {
+ #[error("Operand {0:?} has invalid type.")]
+ InvalidOperand(Handle<crate::Expression>),
+ #[error("Result type for {0:?} doesn't match the statement")]
+ ResultTypeMismatch(Handle<crate::Expression>),
+ #[error("Support for subgroup operation {0:?} is required")]
+ UnsupportedOperation(super::SubgroupOperationSet),
+ #[error("Unknown operation")]
+ UnknownOperation,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
#[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
InvalidType(Handle<crate::Type>),
#[error("Initializer doesn't match the variable type")]
InitializerType,
- #[error("Initializer is not const")]
- NonConstInitializer,
+ #[error("Initializer is not a const or override expression")]
+ NonConstOrOverrideInitializer,
}
#[derive(Clone, Debug, thiserror::Error)]
@@ -135,6 +148,8 @@ pub enum FunctionError {
InvalidRayDescriptor(Handle<crate::Expression>),
#[error("Ray Query {0:?} does not have a matching type")]
InvalidRayQueryType(Handle<crate::Type>),
+ #[error("Shader requires capability {0:?}")]
+ MissingCapability(super::Capabilities),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
@@ -155,6 +170,8 @@ pub enum FunctionError {
WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
#[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
+ #[error("Subgroup operation is invalid")]
+ InvalidSubgroup(#[from] SubgroupError),
}
bitflags::bitflags! {
@@ -399,6 +416,127 @@ impl super::Validator {
}
Ok(())
}
+ fn validate_subgroup_operation(
+ &mut self,
+ op: &crate::SubgroupOperation,
+ collective_op: &crate::CollectiveOperation,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
+
+ let (is_scalar, scalar) = match *argument_inner {
+ crate::TypeInner::Scalar(scalar) => (true, scalar),
+ crate::TypeInner::Vector { scalar, .. } => (false, scalar),
+ _ => {
+ log::error!("Subgroup operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+ };
+
+ use crate::ScalarKind as sk;
+ use crate::SubgroupOperation as sg;
+ match (scalar.kind, *op) {
+ (sk::Bool, sg::All | sg::Any) if is_scalar => {}
+ (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
+ (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
+
+ (_, _) => {
+ log::error!("Subgroup operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+ };
+
+ use crate::CollectiveOperation as co;
+ match (*collective_op, *op) {
+ (
+ co::Reduce,
+ sg::All
+ | sg::Any
+ | sg::Add
+ | sg::Mul
+ | sg::Min
+ | sg::Max
+ | sg::And
+ | sg::Or
+ | sg::Xor,
+ ) => {}
+ (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
+
+ (_, _) => {
+ return Err(SubgroupError::UnknownOperation.with_span().into_other());
+ }
+ };
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::SubgroupOperationResult { ty }
+ if { &context.types[ty].inner == argument_inner } => {}
+ _ => {
+ return Err(SubgroupError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
+ fn validate_subgroup_gather(
+ &mut self,
+ mode: &crate::GatherMode,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
+ match *index_ty {
+ crate::TypeInner::Scalar(crate::Scalar::U32) => {}
+ _ => {
+ log::error!(
+ "Subgroup gather index type {:?}, expected unsigned int",
+ index_ty
+ );
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(index, context.expressions)
+ .into_other());
+ }
+ }
+ }
+ }
+ let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
+ if !matches!(*argument_inner,
+ crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
+ if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
+ ) {
+ log::error!("Subgroup gather operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::SubgroupOperationResult { ty }
+ if { &context.types[ty].inner == argument_inner } => {}
+ _ => {
+ return Err(SubgroupError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
fn validate_block_impl(
&mut self,
@@ -613,8 +751,30 @@ impl super::Validator {
stages &= super::ShaderStages::FRAGMENT;
finished = true;
}
- S::Barrier(_) => {
+ S::Barrier(barrier) => {
stages &= super::ShaderStages::COMPUTE;
+ if barrier.contains(crate::Barrier::SUB_GROUP) {
+ if !self.capabilities.contains(
+ super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
+ ) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP
+ | super::Capabilities::SUBGROUP_BARRIER,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ if !self
+ .subgroup_operations
+ .contains(super::SubgroupOperationSet::BASIC)
+ {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(
+ super::SubgroupOperationSet::BASIC,
+ ),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ }
}
S::Store { pointer, value } => {
let mut current = pointer;
@@ -904,6 +1064,86 @@ impl super::Validator {
crate::RayQueryFunction::Terminate => {}
}
}
+ S::SubgroupBallot { result, predicate } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ if !self
+ .subgroup_operations
+ .contains(super::SubgroupOperationSet::BALLOT)
+ {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(
+ super::SubgroupOperationSet::BALLOT,
+ ),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ if let Some(predicate) = predicate {
+ let predicate_inner =
+ context.resolve_type(predicate, &self.valid_expression_set)?;
+ if !matches!(
+ *predicate_inner,
+ crate::TypeInner::Scalar(crate::Scalar::BOOL,)
+ ) {
+ log::error!(
+ "Subgroup ballot predicate type {:?} expected bool",
+ predicate_inner
+ );
+ return Err(SubgroupError::InvalidOperand(predicate)
+ .with_span_handle(predicate, context.expressions)
+ .into_other());
+ }
+ }
+ self.emit_expression(result, context)?;
+ }
+ S::SubgroupCollectiveOperation {
+ ref op,
+ ref collective_op,
+ argument,
+ result,
+ } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ let operation = op.required_operations();
+ if !self.subgroup_operations.contains(operation) {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(operation),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
+ }
+ S::SubgroupGather {
+ ref mode,
+ argument,
+ result,
+ } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ let operation = mode.required_operations();
+ if !self.subgroup_operations.contains(operation) {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(operation),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ self.validate_subgroup_gather(mode, argument, result, context)?;
+ }
}
}
Ok(BlockInfo { stages, finished })
@@ -927,7 +1167,7 @@ impl super::Validator {
var: &crate::LocalVariable,
gctx: crate::proc::GlobalCtx,
fun_info: &FunctionInfo,
- expression_constness: &crate::proc::ExpressionConstnessTracker,
+ local_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<(), LocalVariableError> {
log::debug!("var {:?}", var);
let type_info = self
@@ -945,8 +1185,8 @@ impl super::Validator {
return Err(LocalVariableError::InitializerType);
}
- if !expression_constness.is_const(init) {
- return Err(LocalVariableError::NonConstInitializer);
+ if !local_expr_kind.is_const_or_override(init) {
+ return Err(LocalVariableError::NonConstOrOverrideInitializer);
}
}
@@ -959,14 +1199,14 @@ impl super::Validator {
module: &crate::Module,
mod_info: &ModuleInfo,
entry_point: bool,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<FunctionInfo, WithSpan<FunctionError>> {
let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
- let expression_constness =
- crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
+ let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
for (var_handle, var) in fun.local_variables.iter() {
- self.validate_local_var(var, module.to_ctx(), &info, &expression_constness)
+ self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
.map_err(|source| {
FunctionError::LocalVariable {
handle: var_handle,
@@ -1032,7 +1272,15 @@ impl super::Validator {
self.valid_expression_set.insert(handle.index());
}
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
- match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
+ match self.validate_expression(
+ handle,
+ expr,
+ fun,
+ module,
+ &info,
+ mod_info,
+ global_expr_kind,
+ ) {
Ok(stages) => info.available_stages &= stages,
Err(source) => {
return Err(FunctionError::Expression { handle, source }
diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs
index e482f293bb..8f78204055 100644
--- a/third_party/rust/naga/src/valid/handles.rs
+++ b/third_party/rust/naga/src/valid/handles.rs
@@ -31,12 +31,13 @@ impl super::Validator {
pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
let &crate::Module {
ref constants,
+ ref overrides,
ref entry_points,
ref functions,
ref global_variables,
ref types,
ref special_types,
- ref const_expressions,
+ ref global_expressions,
} = module;
// NOTE: Types being first is important. All other forms of validation depend on this.
@@ -67,23 +68,31 @@ impl super::Validator {
}
}
- for handle_and_expr in const_expressions.iter() {
- Self::validate_const_expression_handles(handle_and_expr, constants, types)?;
+ for handle_and_expr in global_expressions.iter() {
+ Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?;
}
let validate_type = |handle| Self::validate_type_handle(handle, types);
let validate_const_expr =
- |handle| Self::validate_expression_handle(handle, const_expressions);
+ |handle| Self::validate_expression_handle(handle, global_expressions);
for (_handle, constant) in constants.iter() {
- let &crate::Constant {
+ let &crate::Constant { name: _, ty, init } = constant;
+ validate_type(ty)?;
+ validate_const_expr(init)?;
+ }
+
+ for (_handle, override_) in overrides.iter() {
+ let &crate::Override {
name: _,
- r#override: _,
+ id: _,
ty,
init,
- } = constant;
+ } = override_;
validate_type(ty)?;
- validate_const_expr(init)?;
+ if let Some(init_expr) = init {
+ validate_const_expr(init_expr)?;
+ }
}
for (_handle, global_variable) in global_variables.iter() {
@@ -140,7 +149,8 @@ impl super::Validator {
Self::validate_expression_handles(
handle_and_expr,
constants,
- const_expressions,
+ overrides,
+ global_expressions,
types,
local_variables,
global_variables,
@@ -186,6 +196,13 @@ impl super::Validator {
handle.check_valid_for(constants).map(|_| ())
}
+ fn validate_override_handle(
+ handle: Handle<crate::Override>,
+ overrides: &Arena<crate::Override>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(overrides).map(|_| ())
+ }
+
fn validate_expression_handle(
handle: Handle<crate::Expression>,
expressions: &Arena<crate::Expression>,
@@ -203,9 +220,11 @@ impl super::Validator {
fn validate_const_expression_handles(
(handle, expression): (Handle<crate::Expression>, &crate::Expression),
constants: &Arena<crate::Constant>,
+ overrides: &Arena<crate::Override>,
types: &UniqueArena<crate::Type>,
) -> Result<(), InvalidHandleError> {
let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_override = |handle| Self::validate_override_handle(handle, overrides);
let validate_type = |handle| Self::validate_type_handle(handle, types);
match *expression {
@@ -214,6 +233,12 @@ impl super::Validator {
validate_constant(constant)?;
handle.check_dep(constants[constant].init)?;
}
+ crate::Expression::Override(override_) => {
+ validate_override(override_)?;
+ if let Some(init) = overrides[override_].init {
+ handle.check_dep(init)?;
+ }
+ }
crate::Expression::ZeroValue(ty) => {
validate_type(ty)?;
}
@@ -230,7 +255,8 @@ impl super::Validator {
fn validate_expression_handles(
(handle, expression): (Handle<crate::Expression>, &crate::Expression),
constants: &Arena<crate::Constant>,
- const_expressions: &Arena<crate::Expression>,
+ overrides: &Arena<crate::Override>,
+ global_expressions: &Arena<crate::Expression>,
types: &UniqueArena<crate::Type>,
local_variables: &Arena<crate::LocalVariable>,
global_variables: &Arena<crate::GlobalVariable>,
@@ -239,8 +265,9 @@ impl super::Validator {
current_function: Option<Handle<crate::Function>>,
) -> Result<(), InvalidHandleError> {
let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_override = |handle| Self::validate_override_handle(handle, overrides);
let validate_const_expr =
- |handle| Self::validate_expression_handle(handle, const_expressions);
+ |handle| Self::validate_expression_handle(handle, global_expressions);
let validate_type = |handle| Self::validate_type_handle(handle, types);
match *expression {
@@ -260,6 +287,9 @@ impl super::Validator {
crate::Expression::Constant(constant) => {
validate_constant(constant)?;
}
+ crate::Expression::Override(override_) => {
+ validate_override(override_)?;
+ }
crate::Expression::ZeroValue(ty) => {
validate_type(ty)?;
}
@@ -390,6 +420,8 @@ impl super::Validator {
}
crate::Expression::AtomicResult { .. }
| crate::Expression::RayQueryProceedResult
+ | crate::Expression::SubgroupBallotResult
+ | crate::Expression::SubgroupOperationResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. } => (),
crate::Expression::ArrayLength(array) => {
handle.check_dep(array)?;
@@ -535,6 +567,38 @@ impl super::Validator {
}
Ok(())
}
+ crate::Statement::SubgroupBallot { result, predicate } => {
+ validate_expr_opt(predicate)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: _,
+ argument,
+ result,
+ } => {
+ validate_expr(argument)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ validate_expr(argument)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?,
+ }
+ validate_expr(result)?;
+ Ok(())
+ }
crate::Statement::Break
| crate::Statement::Continue
| crate::Statement::Kill
@@ -562,6 +626,7 @@ impl From<BadRangeError> for ValidationError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum InvalidHandleError {
#[error(transparent)]
BadHandle(#[from] BadHandle),
@@ -572,6 +637,7 @@ pub enum InvalidHandleError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
#[error(
"{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
which has not been processed yet"
@@ -664,6 +730,7 @@ fn constant_deps() {
let mut const_exprs = Arena::new();
let mut fun_exprs = Arena::new();
let mut constants = Arena::new();
+ let overrides = Arena::new();
let i32_handle = types.insert(
Type {
@@ -679,7 +746,6 @@ fn constant_deps() {
let self_referential_const = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: i32_handle,
init: fun_expr,
},
@@ -692,6 +758,7 @@ fn constant_deps() {
assert!(super::Validator::validate_const_expression_handles(
handle_and_expr,
&constants,
+ &overrides,
&types,
)
.is_err());
diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs
index 84c8b09ddb..db890ddbac 100644
--- a/third_party/rust/naga/src/valid/interface.rs
+++ b/third_party/rust/naga/src/valid/interface.rs
@@ -10,6 +10,7 @@ use bit_set::BitSet;
const MAX_WORKGROUP_SIZE: u32 = 0x4000;
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum GlobalVariableError {
#[error("Usage isn't compatible with address space {0:?}")]
InvalidUsage(crate::AddressSpace),
@@ -30,6 +31,8 @@ pub enum GlobalVariableError {
Handle<crate::Type>,
#[source] Disalignment,
),
+ #[error("Initializer must be an override-expression")]
+ InitializerExprType,
#[error("Initializer doesn't match the variable type")]
InitializerType,
#[error("Initializer can't be used with address space {0:?}")]
@@ -39,6 +42,7 @@ pub enum GlobalVariableError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum VaryingError {
#[error("The type {0:?} does not match the varying")]
InvalidType(Handle<crate::Type>),
@@ -73,9 +77,12 @@ pub enum VaryingError {
location: u32,
attribute: &'static str,
},
+ #[error("Workgroup size is multi dimensional, @builtin(subgroup_id) and @builtin(subgroup_invocation_id) are not supported.")]
+ InvalidMultiDimensionalSubgroupBuiltIn,
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum EntryPointError {
#[error("Multiple conflicting entry points")]
Conflict,
@@ -135,6 +142,7 @@ struct VaryingContext<'a> {
impl VaryingContext<'_> {
fn validate_impl(
&mut self,
+ ep: &crate::EntryPoint,
ty: Handle<crate::Type>,
binding: &crate::Binding,
) -> Result<(), VaryingError> {
@@ -162,12 +170,24 @@ impl VaryingContext<'_> {
Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
Bi::ViewIndex => Capabilities::MULTIVIEW,
Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
+ Bi::NumSubgroups
+ | Bi::SubgroupId
+ | Bi::SubgroupSize
+ | Bi::SubgroupInvocationId => Capabilities::SUBGROUP,
_ => Capabilities::empty(),
};
if !self.capabilities.contains(required) {
return Err(VaryingError::UnsupportedCapability(required));
}
+ if matches!(
+ built_in,
+ crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId
+ ) && ep.workgroup_size[1..].iter().any(|&s| s > 1)
+ {
+ return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn);
+ }
+
let (visible, type_good) = match built_in {
Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
self.stage == St::Vertex && !self.output,
@@ -249,6 +269,17 @@ impl VaryingContext<'_> {
scalar: crate::Scalar::U32,
},
),
+ Bi::NumSubgroups | Bi::SubgroupId => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
+ Bi::SubgroupSize | Bi::SubgroupInvocationId => (
+ match self.stage {
+ St::Compute | St::Fragment => !self.output,
+ St::Vertex => false,
+ },
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
};
if !visible {
@@ -349,13 +380,14 @@ impl VaryingContext<'_> {
fn validate(
&mut self,
+ ep: &crate::EntryPoint,
ty: Handle<crate::Type>,
binding: Option<&crate::Binding>,
) -> Result<(), WithSpan<VaryingError>> {
let span_context = self.types.get_span_context(ty);
match binding {
Some(binding) => self
- .validate_impl(ty, binding)
+ .validate_impl(ep, ty, binding)
.map_err(|e| e.with_span_context(span_context)),
None => {
match self.types[ty].inner {
@@ -372,7 +404,7 @@ impl VaryingContext<'_> {
}
}
Some(ref binding) => self
- .validate_impl(member.ty, binding)
+ .validate_impl(ep, member.ty, binding)
.map_err(|e| e.with_span_context(span_context))?,
}
}
@@ -395,6 +427,7 @@ impl super::Validator {
var: &crate::GlobalVariable,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<(), GlobalVariableError> {
use super::TypeFlags;
@@ -523,6 +556,10 @@ impl super::Validator {
}
}
+ if !global_expr_kind.is_const_or_override(init) {
+ return Err(GlobalVariableError::InitializerExprType);
+ }
+
let decl_ty = &gctx.types[var.ty].inner;
let init_ty = mod_info[init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
@@ -538,6 +575,7 @@ impl super::Validator {
ep: &crate::EntryPoint,
module: &crate::Module,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
if ep.early_depth_test.is_some() {
let required = Capabilities::EARLY_DEPTH_TEST;
@@ -566,7 +604,7 @@ impl super::Validator {
}
let mut info = self
- .validate_function(&ep.function, module, mod_info, true)
+ .validate_function(&ep.function, module, mod_info, true, global_expr_kind)
.map_err(WithSpan::into_other)?;
{
@@ -598,7 +636,7 @@ impl super::Validator {
capabilities: self.capabilities,
flags: self.flags,
};
- ctx.validate(fa.ty, fa.binding.as_ref())
+ ctx.validate(ep, fa.ty, fa.binding.as_ref())
.map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
}
@@ -616,7 +654,7 @@ impl super::Validator {
capabilities: self.capabilities,
flags: self.flags,
};
- ctx.validate(fr.ty, fr.binding.as_ref())
+ ctx.validate(ep, fr.ty, fr.binding.as_ref())
.map_err_inner(|e| EntryPointError::Result(e).with_span())?;
if ctx.second_blend_source {
// Only the first location may be used when dual source blending
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs
index 5459434f33..a0057f39ac 100644
--- a/third_party/rust/naga/src/valid/mod.rs
+++ b/third_party/rust/naga/src/valid/mod.rs
@@ -12,7 +12,7 @@ mod r#type;
use crate::{
arena::Handle,
- proc::{LayoutError, Layouter, TypeResolution},
+ proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
FastHashSet,
};
use bit_set::BitSet;
@@ -77,7 +77,7 @@ bitflags::bitflags! {
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
- pub struct Capabilities: u16 {
+ pub struct Capabilities: u32 {
/// Support for [`AddressSpace:PushConstant`].
const PUSH_CONSTANT = 0x1;
/// Float values with width = 8.
@@ -110,6 +110,10 @@ bitflags::bitflags! {
const CUBE_ARRAY_TEXTURES = 0x4000;
/// Support for 64-bit signed and unsigned integers.
const SHADER_INT64 = 0x8000;
+ /// Support for subgroup operations.
+ const SUBGROUP = 0x10000;
+ /// Support for subgroup barriers.
+ const SUBGROUP_BARRIER = 0x20000;
}
}
@@ -120,6 +124,57 @@ impl Default for Capabilities {
}
bitflags::bitflags! {
+ /// Supported subgroup operations
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+ pub struct SubgroupOperationSet: u8 {
+ /// Elect, Barrier
+ const BASIC = 1 << 0;
+ /// Any, All
+ const VOTE = 1 << 1;
+ /// reductions, scans
+ const ARITHMETIC = 1 << 2;
+ /// ballot, broadcast
+ const BALLOT = 1 << 3;
+ /// shuffle, shuffle xor
+ const SHUFFLE = 1 << 4;
+ /// shuffle up, down
+ const SHUFFLE_RELATIVE = 1 << 5;
+ // We don't support these operations yet
+ // /// Clustered
+ // const CLUSTERED = 1 << 6;
+ // /// Quad supported
+ // const QUAD_FRAGMENT_COMPUTE = 1 << 7;
+ // /// Quad supported in all stages
+ // const QUAD_ALL_STAGES = 1 << 8;
+ }
+}
+
+impl super::SubgroupOperation {
+ const fn required_operations(&self) -> SubgroupOperationSet {
+ use SubgroupOperationSet as S;
+ match *self {
+ Self::All | Self::Any => S::VOTE,
+ Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
+ S::ARITHMETIC
+ }
+ }
+ }
+}
+
+impl super::GatherMode {
+ const fn required_operations(&self) -> SubgroupOperationSet {
+ use SubgroupOperationSet as S;
+ match *self {
+ Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
+ Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
+ Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
+ }
+ }
+}
+
+bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
@@ -131,7 +186,7 @@ bitflags::bitflags! {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
@@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
pub struct Validator {
flags: ValidationFlags,
capabilities: Capabilities,
+ subgroup_stages: ShaderStages,
+ subgroup_operations: SubgroupOperationSet,
types: Vec<r#type::TypeInfo>,
layouter: Layouter,
location_mask: BitSet,
@@ -174,10 +231,15 @@ pub struct Validator {
switch_values: FastHashSet<crate::SwitchValue>,
valid_expression_list: Vec<Handle<crate::Expression>>,
valid_expression_set: BitSet,
+ override_ids: FastHashSet<u16>,
+ allow_overrides: bool,
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
+ #[error("Initializer must be a const-expression")]
+ InitializerExprType,
#[error("The type doesn't match the constant")]
InvalidType,
#[error("The type is not constructible")]
@@ -185,6 +247,26 @@ pub enum ConstantError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum OverrideError {
+ #[error("Override name and ID are missing")]
+ MissingNameAndID,
+ #[error("Override ID must be unique")]
+ DuplicateID,
+ #[error("Initializer must be a const-expression or override-expression")]
+ InitializerExprType,
+ #[error("The type doesn't match the override")]
+ InvalidType,
+ #[error("The type is not constructible")]
+ NonConstructibleType,
+ #[error("The type is not a scalar")]
+ TypeNotScalar,
+ #[error("Override declarations are not allowed")]
+ NotAllowed,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
#[error(transparent)]
InvalidHandle(#[from] InvalidHandleError),
@@ -207,6 +289,12 @@ pub enum ValidationError {
name: String,
source: ConstantError,
},
+ #[error("Override {handle:?} '{name}' is invalid")]
+ Override {
+ handle: Handle<crate::Override>,
+ name: String,
+ source: OverrideError,
+ },
#[error("Global variable {handle:?} '{name}' is invalid")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
@@ -286,6 +374,8 @@ impl Validator {
Validator {
flags,
capabilities,
+ subgroup_stages: ShaderStages::empty(),
+ subgroup_operations: SubgroupOperationSet::empty(),
types: Vec::new(),
layouter: Layouter::default(),
location_mask: BitSet::new(),
@@ -293,9 +383,21 @@ impl Validator {
switch_values: FastHashSet::default(),
valid_expression_list: Vec::new(),
valid_expression_set: BitSet::new(),
+ override_ids: FastHashSet::default(),
+ allow_overrides: true,
}
}
+ pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
+ self.subgroup_stages = stages;
+ self
+ }
+
+ pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
+ self.subgroup_operations = operations;
+ self
+ }
+
/// Reset the validator internals
pub fn reset(&mut self) {
self.types.clear();
@@ -305,6 +407,7 @@ impl Validator {
self.switch_values.clear();
self.valid_expression_list.clear();
self.valid_expression_set.clear();
+ self.override_ids.clear();
}
fn validate_constant(
@@ -312,6 +415,7 @@ impl Validator {
handle: Handle<crate::Constant>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
+ global_expr_kind: &ExpressionKindTracker,
) -> Result<(), ConstantError> {
let con = &gctx.constants[handle];
@@ -320,6 +424,10 @@ impl Validator {
return Err(ConstantError::NonConstructibleType);
}
+ if !global_expr_kind.is_const(con.init) {
+ return Err(ConstantError::InitializerExprType);
+ }
+
let decl_ty = &gctx.types[con.ty].inner;
let init_ty = mod_info[con.init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
@@ -329,11 +437,80 @@ impl Validator {
Ok(())
}
+ fn validate_override(
+ &mut self,
+ handle: Handle<crate::Override>,
+ gctx: crate::proc::GlobalCtx,
+ mod_info: &ModuleInfo,
+ ) -> Result<(), OverrideError> {
+ if !self.allow_overrides {
+ return Err(OverrideError::NotAllowed);
+ }
+
+ let o = &gctx.overrides[handle];
+
+ if o.name.is_none() && o.id.is_none() {
+ return Err(OverrideError::MissingNameAndID);
+ }
+
+ if let Some(id) = o.id {
+ if !self.override_ids.insert(id) {
+ return Err(OverrideError::DuplicateID);
+ }
+ }
+
+ let type_info = &self.types[o.ty.index()];
+ if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
+ return Err(OverrideError::NonConstructibleType);
+ }
+
+ let decl_ty = &gctx.types[o.ty].inner;
+ match decl_ty {
+ &crate::TypeInner::Scalar(scalar) => match scalar {
+ crate::Scalar::BOOL
+ | crate::Scalar::I32
+ | crate::Scalar::U32
+ | crate::Scalar::F32
+ | crate::Scalar::F64 => {}
+ _ => return Err(OverrideError::TypeNotScalar),
+ },
+ _ => return Err(OverrideError::TypeNotScalar),
+ }
+
+ if let Some(init) = o.init {
+ let init_ty = mod_info[init].inner_with(gctx.types);
+ if !decl_ty.equivalent(init_ty, gctx.types) {
+ return Err(OverrideError::InvalidType);
+ }
+ }
+
+ Ok(())
+ }
+
/// Check the given module to be valid.
pub fn validate(
&mut self,
module: &crate::Module,
) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.allow_overrides = true;
+ self.validate_impl(module)
+ }
+
+ /// Check the given module to be valid.
+ ///
+ /// With the additional restriction that overrides are not present.
+ pub fn validate_no_overrides(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.allow_overrides = false;
+ self.validate_impl(module)
+ }
+
+ fn validate_impl(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
self.reset();
self.reset_types(module.types.len());
@@ -354,7 +531,7 @@ impl Validator {
type_flags: Vec::with_capacity(module.types.len()),
functions: Vec::with_capacity(module.functions.len()),
entry_points: Vec::with_capacity(module.entry_points.len()),
- const_expression_types: vec![placeholder; module.const_expressions.len()]
+ const_expression_types: vec![placeholder; module.global_expressions.len()]
.into_boxed_slice(),
};
@@ -376,27 +553,34 @@ impl Validator {
{
let t = crate::Arena::new();
let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
- for (handle, _) in module.const_expressions.iter() {
+ for (handle, _) in module.global_expressions.iter() {
mod_info
.process_const_expression(handle, &resolve_context, module.to_ctx())
.map_err(|source| {
ValidationError::ConstExpression { handle, source }
- .with_span_handle(handle, &module.const_expressions)
+ .with_span_handle(handle, &module.global_expressions)
})?
}
}
+ let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
+
if self.flags.contains(ValidationFlags::CONSTANTS) {
- for (handle, _) in module.const_expressions.iter() {
- self.validate_const_expression(handle, module.to_ctx(), &mod_info)
- .map_err(|source| {
- ValidationError::ConstExpression { handle, source }
- .with_span_handle(handle, &module.const_expressions)
- })?
+ for (handle, _) in module.global_expressions.iter() {
+ self.validate_const_expression(
+ handle,
+ module.to_ctx(),
+ &mod_info,
+ &global_expr_kind,
+ )
+ .map_err(|source| {
+ ValidationError::ConstExpression { handle, source }
+ .with_span_handle(handle, &module.global_expressions)
+ })?
}
for (handle, constant) in module.constants.iter() {
- self.validate_constant(handle, module.to_ctx(), &mod_info)
+ self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::Constant {
handle,
@@ -406,10 +590,22 @@ impl Validator {
.with_span_handle(handle, &module.constants)
})?
}
+
+ for (handle, override_) in module.overrides.iter() {
+ self.validate_override(handle, module.to_ctx(), &mod_info)
+ .map_err(|source| {
+ ValidationError::Override {
+ handle,
+ name: override_.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.overrides)
+ })?
+ }
}
for (var_handle, var) in module.global_variables.iter() {
- self.validate_global_var(var, module.to_ctx(), &mod_info)
+ self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::GlobalVariable {
handle: var_handle,
@@ -421,7 +617,7 @@ impl Validator {
}
for (handle, fun) in module.functions.iter() {
- match self.validate_function(fun, module, &mod_info, false) {
+ match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
Ok(info) => mod_info.functions.push(info),
Err(error) => {
return Err(error.and_then(|source| {
@@ -447,7 +643,7 @@ impl Validator {
.with_span()); // TODO: keep some EP span information?
}
- match self.validate_entry_point(ep, module, &mod_info) {
+ match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
Ok(info) => mod_info.entry_points.push(info),
Err(error) => {
return Err(error.and_then(|source| {
diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs
index b8eb618ed4..f5b9856074 100644
--- a/third_party/rust/naga/src/valid/type.rs
+++ b/third_party/rust/naga/src/valid/type.rs
@@ -63,6 +63,7 @@ bitflags::bitflags! {
}
#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum Disalignment {
#[error("The array stride {stride} is not a multiple of the required alignment {alignment}")]
ArrayStride { stride: u32, alignment: Alignment },
@@ -87,6 +88,7 @@ pub enum Disalignment {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum TypeError {
#[error("Capability {0:?} is required")]
MissingCapability(Capabilities),
@@ -326,7 +328,6 @@ impl super::Validator {
TypeFlags::DATA
| TypeFlags::SIZED
| TypeFlags::COPY
- | TypeFlags::HOST_SHAREABLE
| TypeFlags::ARGUMENT
| TypeFlags::CONSTRUCTIBLE
| shareable,