diff options
Diffstat (limited to 'third_party/rust/naga/src/valid/mod.rs')
-rw-r--r-- | third_party/rust/naga/src/valid/mod.rs | 228 |
1 files changed, 212 insertions, 16 deletions
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs index 5459434f33..a0057f39ac 100644 --- a/third_party/rust/naga/src/valid/mod.rs +++ b/third_party/rust/naga/src/valid/mod.rs @@ -12,7 +12,7 @@ mod r#type; use crate::{ arena::Handle, - proc::{LayoutError, Layouter, TypeResolution}, + proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution}, FastHashSet, }; use bit_set::BitSet; @@ -77,7 +77,7 @@ bitflags::bitflags! { #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[derive(Clone, Copy, Debug, Eq, PartialEq)] - pub struct Capabilities: u16 { + pub struct Capabilities: u32 { /// Support for [`AddressSpace:PushConstant`]. const PUSH_CONSTANT = 0x1; /// Float values with width = 8. @@ -110,6 +110,10 @@ bitflags::bitflags! { const CUBE_ARRAY_TEXTURES = 0x4000; /// Support for 64-bit signed and unsigned integers. const SHADER_INT64 = 0x8000; + /// Support for subgroup operations. + const SUBGROUP = 0x10000; + /// Support for subgroup barriers. + const SUBGROUP_BARRIER = 0x20000; } } @@ -120,6 +124,57 @@ impl Default for Capabilities { } bitflags::bitflags! { + /// Supported subgroup operations + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct SubgroupOperationSet: u8 { + /// Elect, Barrier + const BASIC = 1 << 0; + /// Any, All + const VOTE = 1 << 1; + /// reductions, scans + const ARITHMETIC = 1 << 2; + /// ballot, broadcast + const BALLOT = 1 << 3; + /// shuffle, shuffle xor + const SHUFFLE = 1 << 4; + /// shuffle up, down + const SHUFFLE_RELATIVE = 1 << 5; + // We don't support these operations yet + // /// Clustered + // const CLUSTERED = 1 << 6; + // /// Quad supported + // const QUAD_FRAGMENT_COMPUTE = 1 << 7; + // /// Quad supported in all stages + // const QUAD_ALL_STAGES = 1 << 8; + } +} + +impl super::SubgroupOperation { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::All | Self::Any => S::VOTE, + Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => { + S::ARITHMETIC + } + } + } +} + +impl super::GatherMode { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, + Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, + Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + } + } +} + +bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] @@ -131,7 +186,7 @@ bitflags::bitflags! { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { @@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo { pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, + subgroup_stages: ShaderStages, + subgroup_operations: SubgroupOperationSet, types: Vec<r#type::TypeInfo>, layouter: Layouter, location_mask: BitSet, @@ -174,10 +231,15 @@ pub struct Validator { switch_values: FastHashSet<crate::SwitchValue>, valid_expression_list: Vec<Handle<crate::Expression>>, valid_expression_set: BitSet, + override_ids: FastHashSet<u16>, + allow_overrides: bool, } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantError { + #[error("Initializer must be a const-expression")] + InitializerExprType, #[error("The type doesn't match the constant")] InvalidType, #[error("The type is not constructible")] @@ -185,6 +247,26 @@ pub enum ConstantError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum OverrideError { + #[error("Override name and ID are missing")] + MissingNameAndID, + #[error("Override ID must be unique")] + DuplicateID, + #[error("Initializer must be a const-expression or override-expression")] + InitializerExprType, + #[error("The type doesn't match the override")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, + #[error("The type is not a scalar")] + TypeNotScalar, + #[error("Override declarations are not allowed")] + NotAllowed, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ValidationError { #[error(transparent)] InvalidHandle(#[from] InvalidHandleError), @@ -207,6 +289,12 @@ pub enum ValidationError { name: String, source: ConstantError, }, + #[error("Override {handle:?} '{name}' is invalid")] + Override { + handle: Handle<crate::Override>, + name: String, + source: OverrideError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle<crate::GlobalVariable>, @@ -286,6 +374,8 @@ impl Validator { Validator { flags, capabilities, + subgroup_stages: ShaderStages::empty(), + subgroup_operations: SubgroupOperationSet::empty(), types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), @@ -293,9 +383,21 @@ impl Validator { switch_values: FastHashSet::default(), valid_expression_list: Vec::new(), valid_expression_set: BitSet::new(), + override_ids: FastHashSet::default(), + allow_overrides: true, } } + pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self { + self.subgroup_stages = stages; + self + } + + pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self { + self.subgroup_operations = operations; + self + } + /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); @@ -305,6 +407,7 @@ impl Validator { self.switch_values.clear(); self.valid_expression_list.clear(); self.valid_expression_set.clear(); + self.override_ids.clear(); } fn validate_constant( @@ -312,6 +415,7 @@ impl Validator { handle: Handle<crate::Constant>, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &ExpressionKindTracker, ) -> Result<(), ConstantError> { let con = &gctx.constants[handle]; @@ -320,6 +424,10 @@ impl Validator { return Err(ConstantError::NonConstructibleType); } + if !global_expr_kind.is_const(con.init) { + return Err(ConstantError::InitializerExprType); + } + let decl_ty = &gctx.types[con.ty].inner; let init_ty = mod_info[con.init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -329,11 +437,80 @@ impl Validator { Ok(()) } + fn validate_override( + &mut self, + handle: Handle<crate::Override>, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), OverrideError> { + if !self.allow_overrides { + return Err(OverrideError::NotAllowed); + } + + let o = &gctx.overrides[handle]; + + if o.name.is_none() && o.id.is_none() { + return Err(OverrideError::MissingNameAndID); + } + + if let Some(id) = o.id { + if !self.override_ids.insert(id) { + return Err(OverrideError::DuplicateID); + } + } + + let type_info = &self.types[o.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(OverrideError::NonConstructibleType); + } + + let decl_ty = &gctx.types[o.ty].inner; + match decl_ty { + &crate::TypeInner::Scalar(scalar) => match scalar { + crate::Scalar::BOOL + | crate::Scalar::I32 + | crate::Scalar::U32 + | crate::Scalar::F32 + | crate::Scalar::F64 => {} + _ => return Err(OverrideError::TypeNotScalar), + }, + _ => return Err(OverrideError::TypeNotScalar), + } + + if let Some(init) = o.init { + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(OverrideError::InvalidType); + } + } + + Ok(()) + } + /// Check the given module to be valid. pub fn validate( &mut self, module: &crate::Module, ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.allow_overrides = true; + self.validate_impl(module) + } + + /// Check the given module to be valid. + /// + /// With the additional restriction that overrides are not present. + pub fn validate_no_overrides( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.allow_overrides = false; + self.validate_impl(module) + } + + fn validate_impl( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { self.reset(); self.reset_types(module.types.len()); @@ -354,7 +531,7 @@ impl Validator { type_flags: Vec::with_capacity(module.types.len()), functions: Vec::with_capacity(module.functions.len()), entry_points: Vec::with_capacity(module.entry_points.len()), - const_expression_types: vec![placeholder; module.const_expressions.len()] + const_expression_types: vec![placeholder; module.global_expressions.len()] .into_boxed_slice(), }; @@ -376,27 +553,34 @@ impl Validator { { let t = crate::Arena::new(); let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]); - for (handle, _) in module.const_expressions.iter() { + for (handle, _) in module.global_expressions.iter() { mod_info .process_const_expression(handle, &resolve_context, module.to_ctx()) .map_err(|source| { ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) + .with_span_handle(handle, &module.global_expressions) })? } } + let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions); + if self.flags.contains(ValidationFlags::CONSTANTS) { - for (handle, _) in module.const_expressions.iter() { - self.validate_const_expression(handle, module.to_ctx(), &mod_info) - .map_err(|source| { - ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) - })? + for (handle, _) in module.global_expressions.iter() { + self.validate_const_expression( + handle, + module.to_ctx(), + &mod_info, + &global_expr_kind, + ) + .map_err(|source| { + ValidationError::ConstExpression { handle, source } + .with_span_handle(handle, &module.global_expressions) + })? } for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, module.to_ctx(), &mod_info) + self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::Constant { handle, @@ -406,10 +590,22 @@ impl Validator { .with_span_handle(handle, &module.constants) })? } + + for (handle, override_) in module.overrides.iter() { + self.validate_override(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Override { + handle, + name: override_.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })? + } } for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, module.to_ctx(), &mod_info) + self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::GlobalVariable { handle: var_handle, @@ -421,7 +617,7 @@ impl Validator { } for (handle, fun) in module.functions.iter() { - match self.validate_function(fun, module, &mod_info, false) { + match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { @@ -447,7 +643,7 @@ impl Validator { .with_span()); // TODO: keep some EP span information? } - match self.validate_entry_point(ep, module, &mod_info) { + match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { |