summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/valid/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/valid/mod.rs')
-rw-r--r--third_party/rust/naga/src/valid/mod.rs228
1 files changed, 212 insertions, 16 deletions
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs
index 5459434f33..a0057f39ac 100644
--- a/third_party/rust/naga/src/valid/mod.rs
+++ b/third_party/rust/naga/src/valid/mod.rs
@@ -12,7 +12,7 @@ mod r#type;
use crate::{
arena::Handle,
- proc::{LayoutError, Layouter, TypeResolution},
+ proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
FastHashSet,
};
use bit_set::BitSet;
@@ -77,7 +77,7 @@ bitflags::bitflags! {
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
- pub struct Capabilities: u16 {
+ pub struct Capabilities: u32 {
/// Support for [`AddressSpace:PushConstant`].
const PUSH_CONSTANT = 0x1;
/// Float values with width = 8.
@@ -110,6 +110,10 @@ bitflags::bitflags! {
const CUBE_ARRAY_TEXTURES = 0x4000;
/// Support for 64-bit signed and unsigned integers.
const SHADER_INT64 = 0x8000;
+ /// Support for subgroup operations.
+ const SUBGROUP = 0x10000;
+ /// Support for subgroup barriers.
+ const SUBGROUP_BARRIER = 0x20000;
}
}
@@ -120,6 +124,57 @@ impl Default for Capabilities {
}
bitflags::bitflags! {
+ /// Supported subgroup operations
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+ pub struct SubgroupOperationSet: u8 {
+ /// Elect, Barrier
+ const BASIC = 1 << 0;
+ /// Any, All
+ const VOTE = 1 << 1;
+ /// reductions, scans
+ const ARITHMETIC = 1 << 2;
+ /// ballot, broadcast
+ const BALLOT = 1 << 3;
+ /// shuffle, shuffle xor
+ const SHUFFLE = 1 << 4;
+ /// shuffle up, down
+ const SHUFFLE_RELATIVE = 1 << 5;
+ // We don't support these operations yet
+ // /// Clustered
+ // const CLUSTERED = 1 << 6;
+ // /// Quad supported
+ // const QUAD_FRAGMENT_COMPUTE = 1 << 7;
+ // /// Quad supported in all stages
+ // const QUAD_ALL_STAGES = 1 << 8;
+ }
+}
+
+impl super::SubgroupOperation {
+ const fn required_operations(&self) -> SubgroupOperationSet {
+ use SubgroupOperationSet as S;
+ match *self {
+ Self::All | Self::Any => S::VOTE,
+ Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
+ S::ARITHMETIC
+ }
+ }
+ }
+}
+
+impl super::GatherMode {
+ const fn required_operations(&self) -> SubgroupOperationSet {
+ use SubgroupOperationSet as S;
+ match *self {
+ Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
+ Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
+ Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
+ }
+ }
+}
+
+bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
@@ -131,7 +186,7 @@ bitflags::bitflags! {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
@@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
pub struct Validator {
flags: ValidationFlags,
capabilities: Capabilities,
+ subgroup_stages: ShaderStages,
+ subgroup_operations: SubgroupOperationSet,
types: Vec<r#type::TypeInfo>,
layouter: Layouter,
location_mask: BitSet,
@@ -174,10 +231,15 @@ pub struct Validator {
switch_values: FastHashSet<crate::SwitchValue>,
valid_expression_list: Vec<Handle<crate::Expression>>,
valid_expression_set: BitSet,
+ override_ids: FastHashSet<u16>,
+ allow_overrides: bool,
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
+ #[error("Initializer must be a const-expression")]
+ InitializerExprType,
#[error("The type doesn't match the constant")]
InvalidType,
#[error("The type is not constructible")]
@@ -185,6 +247,26 @@ pub enum ConstantError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum OverrideError {
+ #[error("Override name and ID are missing")]
+ MissingNameAndID,
+ #[error("Override ID must be unique")]
+ DuplicateID,
+ #[error("Initializer must be a const-expression or override-expression")]
+ InitializerExprType,
+ #[error("The type doesn't match the override")]
+ InvalidType,
+ #[error("The type is not constructible")]
+ NonConstructibleType,
+ #[error("The type is not a scalar")]
+ TypeNotScalar,
+ #[error("Override declarations are not allowed")]
+ NotAllowed,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
#[error(transparent)]
InvalidHandle(#[from] InvalidHandleError),
@@ -207,6 +289,12 @@ pub enum ValidationError {
name: String,
source: ConstantError,
},
+ #[error("Override {handle:?} '{name}' is invalid")]
+ Override {
+ handle: Handle<crate::Override>,
+ name: String,
+ source: OverrideError,
+ },
#[error("Global variable {handle:?} '{name}' is invalid")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
@@ -286,6 +374,8 @@ impl Validator {
Validator {
flags,
capabilities,
+ subgroup_stages: ShaderStages::empty(),
+ subgroup_operations: SubgroupOperationSet::empty(),
types: Vec::new(),
layouter: Layouter::default(),
location_mask: BitSet::new(),
@@ -293,9 +383,21 @@ impl Validator {
switch_values: FastHashSet::default(),
valid_expression_list: Vec::new(),
valid_expression_set: BitSet::new(),
+ override_ids: FastHashSet::default(),
+ allow_overrides: true,
}
}
+ pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
+ self.subgroup_stages = stages;
+ self
+ }
+
+ pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
+ self.subgroup_operations = operations;
+ self
+ }
+
/// Reset the validator internals
pub fn reset(&mut self) {
self.types.clear();
@@ -305,6 +407,7 @@ impl Validator {
self.switch_values.clear();
self.valid_expression_list.clear();
self.valid_expression_set.clear();
+ self.override_ids.clear();
}
fn validate_constant(
@@ -312,6 +415,7 @@ impl Validator {
handle: Handle<crate::Constant>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
+ global_expr_kind: &ExpressionKindTracker,
) -> Result<(), ConstantError> {
let con = &gctx.constants[handle];
@@ -320,6 +424,10 @@ impl Validator {
return Err(ConstantError::NonConstructibleType);
}
+ if !global_expr_kind.is_const(con.init) {
+ return Err(ConstantError::InitializerExprType);
+ }
+
let decl_ty = &gctx.types[con.ty].inner;
let init_ty = mod_info[con.init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
@@ -329,11 +437,80 @@ impl Validator {
Ok(())
}
+ fn validate_override(
+ &mut self,
+ handle: Handle<crate::Override>,
+ gctx: crate::proc::GlobalCtx,
+ mod_info: &ModuleInfo,
+ ) -> Result<(), OverrideError> {
+ if !self.allow_overrides {
+ return Err(OverrideError::NotAllowed);
+ }
+
+ let o = &gctx.overrides[handle];
+
+ if o.name.is_none() && o.id.is_none() {
+ return Err(OverrideError::MissingNameAndID);
+ }
+
+ if let Some(id) = o.id {
+ if !self.override_ids.insert(id) {
+ return Err(OverrideError::DuplicateID);
+ }
+ }
+
+ let type_info = &self.types[o.ty.index()];
+ if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
+ return Err(OverrideError::NonConstructibleType);
+ }
+
+ let decl_ty = &gctx.types[o.ty].inner;
+ match decl_ty {
+ &crate::TypeInner::Scalar(scalar) => match scalar {
+ crate::Scalar::BOOL
+ | crate::Scalar::I32
+ | crate::Scalar::U32
+ | crate::Scalar::F32
+ | crate::Scalar::F64 => {}
+ _ => return Err(OverrideError::TypeNotScalar),
+ },
+ _ => return Err(OverrideError::TypeNotScalar),
+ }
+
+ if let Some(init) = o.init {
+ let init_ty = mod_info[init].inner_with(gctx.types);
+ if !decl_ty.equivalent(init_ty, gctx.types) {
+ return Err(OverrideError::InvalidType);
+ }
+ }
+
+ Ok(())
+ }
+
/// Check the given module to be valid.
pub fn validate(
&mut self,
module: &crate::Module,
) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.allow_overrides = true;
+ self.validate_impl(module)
+ }
+
+ /// Check the given module to be valid.
+ ///
+ /// With the additional restriction that overrides are not present.
+ pub fn validate_no_overrides(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.allow_overrides = false;
+ self.validate_impl(module)
+ }
+
+ fn validate_impl(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
self.reset();
self.reset_types(module.types.len());
@@ -354,7 +531,7 @@ impl Validator {
type_flags: Vec::with_capacity(module.types.len()),
functions: Vec::with_capacity(module.functions.len()),
entry_points: Vec::with_capacity(module.entry_points.len()),
- const_expression_types: vec![placeholder; module.const_expressions.len()]
+ const_expression_types: vec![placeholder; module.global_expressions.len()]
.into_boxed_slice(),
};
@@ -376,27 +553,34 @@ impl Validator {
{
let t = crate::Arena::new();
let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
- for (handle, _) in module.const_expressions.iter() {
+ for (handle, _) in module.global_expressions.iter() {
mod_info
.process_const_expression(handle, &resolve_context, module.to_ctx())
.map_err(|source| {
ValidationError::ConstExpression { handle, source }
- .with_span_handle(handle, &module.const_expressions)
+ .with_span_handle(handle, &module.global_expressions)
})?
}
}
+ let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
+
if self.flags.contains(ValidationFlags::CONSTANTS) {
- for (handle, _) in module.const_expressions.iter() {
- self.validate_const_expression(handle, module.to_ctx(), &mod_info)
- .map_err(|source| {
- ValidationError::ConstExpression { handle, source }
- .with_span_handle(handle, &module.const_expressions)
- })?
+ for (handle, _) in module.global_expressions.iter() {
+ self.validate_const_expression(
+ handle,
+ module.to_ctx(),
+ &mod_info,
+ &global_expr_kind,
+ )
+ .map_err(|source| {
+ ValidationError::ConstExpression { handle, source }
+ .with_span_handle(handle, &module.global_expressions)
+ })?
}
for (handle, constant) in module.constants.iter() {
- self.validate_constant(handle, module.to_ctx(), &mod_info)
+ self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::Constant {
handle,
@@ -406,10 +590,22 @@ impl Validator {
.with_span_handle(handle, &module.constants)
})?
}
+
+ for (handle, override_) in module.overrides.iter() {
+ self.validate_override(handle, module.to_ctx(), &mod_info)
+ .map_err(|source| {
+ ValidationError::Override {
+ handle,
+ name: override_.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.overrides)
+ })?
+ }
}
for (var_handle, var) in module.global_variables.iter() {
- self.validate_global_var(var, module.to_ctx(), &mod_info)
+ self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::GlobalVariable {
handle: var_handle,
@@ -421,7 +617,7 @@ impl Validator {
}
for (handle, fun) in module.functions.iter() {
- match self.validate_function(fun, module, &mod_info, false) {
+ match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
Ok(info) => mod_info.functions.push(info),
Err(error) => {
return Err(error.and_then(|source| {
@@ -447,7 +643,7 @@ impl Validator {
.with_span()); // TODO: keep some EP span information?
}
- match self.validate_entry_point(ep, module, &mod_info) {
+ match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
Ok(info) => mod_info.entry_points.push(info),
Err(error) => {
return Err(error.and_then(|source| {