diff options
Diffstat (limited to 'third_party/rust/naga/src')
63 files changed, 4815 insertions, 843 deletions
diff --git a/third_party/rust/naga/src/arena.rs b/third_party/rust/naga/src/arena.rs index c37538667f..740df85b86 100644 --- a/third_party/rust/naga/src/arena.rs +++ b/third_party/rust/naga/src/arena.rs @@ -122,6 +122,7 @@ impl<T> Handle<T> { serde(transparent) )] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct Range<T> { inner: ops::Range<u32>, #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] @@ -140,6 +141,7 @@ impl<T> Range<T> { // NOTE: Keep this diagnostic in sync with that of [`BadHandle`]. #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] #[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")] pub struct BadRangeError { // This error is used for many `Handle` types, but there's no point in making this generic, so @@ -239,7 +241,7 @@ impl<T> Range<T> { /// Adding new items to the arena produces a strongly-typed [`Handle`]. /// The arena can be indexed using the given handle to obtain /// a reference to the stored item. -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "serialize", serde(transparent))] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] @@ -297,6 +299,17 @@ impl<T> Arena<T> { .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) }) } + /// Drains the arena, returning an iterator over the items stored. + pub fn drain(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, T, Span)> { + let arena = std::mem::take(self); + arena + .data + .into_iter() + .zip(arena.span_info) + .enumerate() + .map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) }) + } + /// Returns a iterator over the items stored in this arena, /// returning both the item's handle and a mutable reference to it. pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> { @@ -531,7 +544,7 @@ mod tests { /// /// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like, /// `UniqueArena` is `HashSet`-like. -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Clone)] pub struct UniqueArena<T> { set: FastIndexSet<T>, diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs index 1556371df1..9a7702b3f6 100644 --- a/third_party/rust/naga/src/back/dot/mod.rs +++ b/third_party/rust/naga/src/back/dot/mod.rs @@ -279,6 +279,94 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.dependencies.push((id, predicate, "predicate")); + } + self.emits.push((id, result)); + "SubgroupBallot" + } + S::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupInclusiveMul", + _ => unimplemented!(), + } + } + S::SubgroupGather { + mode, + argument, + result, + } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -404,6 +492,7 @@ fn write_function_expressions( let (label, color_id) = match *expression { E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), + E::Override(_) => ("Override".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Compose { ref components, .. } => { payload = Some(Payload::Arguments(components)); @@ -586,6 +675,8 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{}Intersection", ty).into(), 4) } + E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), + E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4), }; // give uniform expressions an outline diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs index 99c128c6d9..e5a43f3e02 100644 --- a/third_party/rust/naga/src/back/glsl/features.rs +++ b/third_party/rust/naga/src/back/glsl/features.rs @@ -50,6 +50,8 @@ bitflags::bitflags! { const INSTANCE_INDEX = 1 << 22; /// Sample specific LODs of cube / array shadow textures const TEXTURE_SHADOW_LOD = 1 << 23; + /// Subgroup operations + const SUBGROUP_OPERATIONS = 1 << 24; } } @@ -117,6 +119,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + check_feature!(SUBGROUP_OPERATIONS, 430, 310); match version { Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), @@ -259,6 +262,22 @@ impl FeaturesManager { writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?; } + if self.0.contains(Features::SUBGROUP_OPERATIONS) { + // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt + writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_arithmetic : require" + )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_shuffle_relative : require" + )?; + } + Ok(()) } } @@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::SubgroupBallotResult | + Expression::SubgroupOperationResult { .. } => { + features.request(Features::SUBGROUP_OPERATIONS) + } _ => {} } } diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs index 9bda594610..c8c7ea557d 100644 --- a/third_party/rust/naga/src/back/glsl/mod.rs +++ b/third_party/rust/naga/src/back/glsl/mod.rs @@ -282,7 +282,7 @@ impl Default for Options { } /// A subset of options meant to be changed per pipeline. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { @@ -497,6 +497,8 @@ pub enum Error { ImageMultipleSamplers, #[error("{0}")] Custom(String), + #[error("overrides should not be present at this stage")] + Override, } /// Binary operation with a different logic on the GLSL side. @@ -565,6 +567,10 @@ impl<'a, W: Write> Writer<'a, W> { pipeline_options: &'a PipelineOptions, policies: proc::BoundsCheckPolicies, ) -> Result<Self, Error> { + if !module.overrides.is_empty() { + return Err(Error::Override); + } + // Check if the requested version is supported if !options.version.is_supported() { log::error!("Version {}", options.version); @@ -2384,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + write!(self.out, "subgroupBallot(")?; + match predicate { + Some(predicate) => self.write_expr(predicate, ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -2402,7 +2527,7 @@ impl<'a, W: Write> Writer<'a, W> { fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult { self.write_possibly_const_expr( expr, - &self.module.const_expressions, + &self.module.global_expressions, |expr| &self.info[expr], |writer, expr| writer.write_const_expr(expr), ) @@ -2536,6 +2661,7 @@ impl<'a, W: Write> Writer<'a, W> { |writer, expr| writer.write_expr(expr, ctx), )?; } + Expression::Override(_) => return Err(Error::Override), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; @@ -3411,7 +3537,8 @@ impl<'a, W: Write> Writer<'a, W> { let scalar_bits = ctx .resolve_type(arg, &self.module.types) .scalar_width() - .unwrap(); + .unwrap() + * 8; write!(self.out, "bitfieldExtract(")?; self.write_expr(arg, ctx)?; @@ -3430,7 +3557,8 @@ impl<'a, W: Write> Writer<'a, W> { let scalar_bits = ctx .resolve_type(arg, &self.module.types) .scalar_width() - .unwrap(); + .unwrap() + * 8; write!(self.out, "bitfieldInsert(")?; self.write_expr(arg, ctx)?; @@ -3649,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult - | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupOperationResult { .. } + | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; @@ -4218,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> { if flags.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}memoryBarrierShared();")?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; + } writeln!(self.out, "{level}barrier();")?; Ok(()) } @@ -4487,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s Bi::WorkGroupId => "gl_WorkGroupID", Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", + // subgroup + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", + Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs index 2a6db35db8..7d15f43f6c 100644 --- a/third_party/rust/naga/src/back/hlsl/conv.rs +++ b/third_party/rust/naga/src/back/hlsl/conv.rs @@ -179,6 +179,11 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", + // These builtins map to functions + Self::SubgroupSize + | Self::SubgroupInvocationId + | Self::NumSubgroups + | Self::SubgroupId => unreachable!(), Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { return Err(Error::Unimplemented(format!("builtin {self:?}"))) } diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs index 4dd9ea5987..d3bb1ce7f5 100644 --- a/third_party/rust/naga/src/back/hlsl/help.rs +++ b/third_party/rust/naga/src/back/hlsl/help.rs @@ -70,6 +70,11 @@ pub(super) struct WrappedMath { pub(super) components: Option<u32>, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedZeroValue { + pub(super) ty: Handle<crate::Type>, +} + /// HLSL backend requires its own `ImageQuery` enum. /// /// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function. @@ -359,7 +364,7 @@ impl<'a, W: Write> super::Writer<'a, W> { } /// Helper function that write wrapped function for `Expression::Compose` for structures. - pub(super) fn write_wrapped_constructor_function( + fn write_wrapped_constructor_function( &mut self, module: &crate::Module, constructor: WrappedConstructor, @@ -862,6 +867,25 @@ impl<'a, W: Write> super::Writer<'a, W> { Ok(()) } + // TODO: we could merge this with iteration in write_wrapped_compose_functions... + // + /// Helper function that writes zero value wrapped functions + pub(super) fn write_wrapped_zero_value_functions( + &mut self, + module: &crate::Module, + expressions: &crate::Arena<crate::Expression>, + ) -> BackendResult { + for (handle, _) in expressions.iter() { + if let crate::Expression::ZeroValue(ty) = expressions[handle] { + let zero_value = WrappedZeroValue { ty }; + if self.wrapped.zero_values.insert(zero_value) { + self.write_wrapped_zero_value_function(module, zero_value)?; + } + } + } + Ok(()) + } + pub(super) fn write_wrapped_math_functions( &mut self, module: &crate::Module, @@ -1006,6 +1030,7 @@ impl<'a, W: Write> super::Writer<'a, W> { ) -> BackendResult { self.write_wrapped_math_functions(module, func_ctx)?; self.write_wrapped_compose_functions(module, func_ctx.expressions)?; + self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?; for (handle, _) in func_ctx.expressions.iter() { match func_ctx.expressions[handle] { @@ -1283,4 +1308,71 @@ impl<'a, W: Write> super::Writer<'a, W> { Ok(()) } + + pub(super) fn write_wrapped_zero_value_function_name( + &mut self, + module: &crate::Module, + zero_value: WrappedZeroValue, + ) -> BackendResult { + let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?; + write!(self.out, "ZeroValue{name}")?; + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::ZeroValue` + /// + /// This is necessary since we might have a member access after the zero value expression, e.g. + /// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc). + /// + /// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly. + /// + /// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle + /// cases like: + /// + /// ```ignore + /// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet + /// t_1.am = (__mat4x2[2])((float4x2[2])0); + /// ^ + /// ``` + fn write_wrapped_zero_value_function( + &mut self, + module: &crate::Module, + zero_value: WrappedZeroValue, + ) -> BackendResult { + use crate::back::INDENT; + + const RETURN_VARIABLE_NAME: &str = "ret"; + + // Write function return type and name + if let crate::TypeInner::Array { base, size, .. } = module.types[zero_value.ty].inner { + write!(self.out, "typedef ")?; + self.write_type(module, zero_value.ty)?; + write!(self.out, " ret_")?; + self.write_wrapped_zero_value_function_name(module, zero_value)?; + self.write_array_size(module, base, size)?; + writeln!(self.out, ";")?; + + write!(self.out, "ret_")?; + self.write_wrapped_zero_value_function_name(module, zero_value)?; + } else { + self.write_type(module, zero_value.ty)?; + } + write!(self.out, " ")?; + self.write_wrapped_zero_value_function_name(module, zero_value)?; + + // Write function parameters (none) and start function body + writeln!(self.out, "() {{")?; + + // Write `ZeroValue` function. + write!(self.out, "{INDENT}return ")?; + self.write_default_init(module, zero_value.ty)?; + writeln!(self.out, ";")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } } diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs index f37a223f47..28edbf70e1 100644 --- a/third_party/rust/naga/src/back/hlsl/mod.rs +++ b/third_party/rust/naga/src/back/hlsl/mod.rs @@ -131,6 +131,13 @@ pub enum ShaderModel { V5_0, V5_1, V6_0, + V6_1, + V6_2, + V6_3, + V6_4, + V6_5, + V6_6, + V6_7, } impl ShaderModel { @@ -139,6 +146,13 @@ impl ShaderModel { Self::V5_0 => "5_0", Self::V5_1 => "5_1", Self::V6_0 => "6_0", + Self::V6_1 => "6_1", + Self::V6_2 => "6_2", + Self::V6_3 => "6_3", + Self::V6_4 => "6_4", + Self::V6_5 => "6_5", + Self::V6_6 => "6_6", + Self::V6_7 => "6_7", } } } @@ -247,10 +261,13 @@ pub enum Error { Unimplemented(String), // TODO: Error used only during development #[error("{0}")] Custom(String), + #[error("overrides should not be present at this stage")] + Override, } #[derive(Default)] struct Wrapped { + zero_values: crate::FastHashSet<help::WrappedZeroValue>, array_lengths: crate::FastHashSet<help::WrappedArrayLength>, image_queries: crate::FastHashSet<help::WrappedImageQuery>, constructors: crate::FastHashSet<help::WrappedConstructor>, diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs index 4ba856946b..86d8f89035 100644 --- a/third_party/rust/naga/src/back/hlsl/writer.rs +++ b/third_party/rust/naga/src/back/hlsl/writer.rs @@ -1,5 +1,8 @@ use super::{ - help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess}, + help::{ + WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess, + WrappedZeroValue, + }, storage::StoreValue, BackendResult, Error, Options, }; @@ -77,6 +80,19 @@ enum Io { Output, } +const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool { + let &Some(crate::Binding::BuiltIn(builtin)) = binding else { + return false; + }; + matches!( + builtin, + crate::BuiltIn::SubgroupSize + | crate::BuiltIn::SubgroupInvocationId + | crate::BuiltIn::NumSubgroups + | crate::BuiltIn::SubgroupId + ) +} + impl<'a, W: fmt::Write> super::Writer<'a, W> { pub fn new(out: W, options: &'a Options) -> Self { Self { @@ -161,6 +177,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } } + for statement in func.body.iter() { + match *statement { + crate::Statement::SubgroupCollectiveOperation { + op: _, + collective_op: crate::CollectiveOperation::InclusiveScan, + argument, + result: _, + } => { + self.need_bake_expressions.insert(argument); + } + _ => {} + } + } } pub fn write( @@ -168,6 +197,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { module: &Module, module_info: &valid::ModuleInfo, ) -> Result<super::ReflectionInfo, Error> { + if !module.overrides.is_empty() { + return Err(Error::Override); + } + self.reset(module); // Write special constants, if needed @@ -233,7 +266,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_special_functions(module)?; - self.write_wrapped_compose_functions(module, &module.const_expressions)?; + self.write_wrapped_compose_functions(module, &module.global_expressions)?; + self.write_wrapped_zero_value_functions(module, &module.global_expressions)?; // Write all named constants let mut constants = module @@ -397,31 +431,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // if they are struct, so that the `stage` argument here could be omitted. fn write_semantic( &mut self, - binding: &crate::Binding, + binding: &Option<crate::Binding>, stage: Option<(ShaderStage, Io)>, ) -> BackendResult { match *binding { - crate::Binding::BuiltIn(builtin) => { + Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => { let builtin_str = builtin.to_hlsl_str()?; write!(self.out, " : {builtin_str}")?; } - crate::Binding::Location { + Some(crate::Binding::Location { second_blend_source: true, .. - } => { + }) => { write!(self.out, " : SV_Target1")?; } - crate::Binding::Location { + Some(crate::Binding::Location { location, second_blend_source: false, .. - } => { + }) => { if stage == Some((crate::ShaderStage::Fragment, Io::Output)) { write!(self.out, " : SV_Target{location}")?; } else { write!(self.out, " : {LOCATION_SEMANTIC}{location}")?; } } + _ => {} } Ok(()) @@ -442,17 +477,30 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "struct {struct_name}")?; writeln!(self.out, " {{")?; for m in members.iter() { + if is_subgroup_builtin_binding(&m.binding) { + continue; + } write!(self.out, "{}", back::INDENT)?; if let Some(ref binding) = m.binding { self.write_modifier(binding)?; } self.write_type(module, m.ty)?; write!(self.out, " {}", &m.name)?; - if let Some(ref binding) = m.binding { - self.write_semantic(binding, Some(shader_stage))?; - } + self.write_semantic(&m.binding, Some(shader_stage))?; writeln!(self.out, ";")?; } + if members.iter().any(|arg| { + matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) + ) + }) { + writeln!( + self.out, + "{}uint __local_invocation_index : SV_GroupIndex;", + back::INDENT + )?; + } writeln!(self.out, "}};")?; writeln!(self.out)?; @@ -553,8 +601,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } /// Writes special interface structures for an entry point. The special structures have - /// all the fields flattened into them and sorted by binding. They are only needed for - /// VS outputs and FS inputs, so that these interfaces match. + /// all the fields flattened into them and sorted by binding. They are needed to emulate + /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match. fn write_ep_interface( &mut self, module: &Module, @@ -563,7 +611,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { ep_name: &str, ) -> Result<EntryPointInterface, Error> { Ok(EntryPointInterface { - input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment { + input: if !func.arguments.is_empty() + && (stage == ShaderStage::Fragment + || func + .arguments + .iter() + .any(|arg| is_subgroup_builtin_binding(&arg.binding))) + { Some(self.write_ep_input_struct(module, func, stage, ep_name)?) } else { None @@ -577,6 +631,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }) } + fn write_ep_argument_initialization( + &mut self, + ep: &crate::EntryPoint, + ep_input: &EntryPointBinding, + fake_member: &EpStructMember, + ) -> BackendResult { + match fake_member.binding { + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => { + write!(self.out, "WaveGetLaneCount()")? + } + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => { + write!(self.out, "WaveGetLaneIndex()")? + } + Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!( + self.out, + "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()", + ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2] + )?, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => { + write!( + self.out, + "{}.__local_invocation_index / WaveGetLaneCount()", + ep_input.arg_name + )?; + } + _ => { + write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?; + } + } + Ok(()) + } + /// Write an entry point preface that initializes the arguments as specified in IR. fn write_ep_arguments_initialization( &mut self, @@ -584,6 +670,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { func: &crate::Function, ep_index: u16, ) -> BackendResult { + let ep = &module.entry_points[ep_index as usize]; let ep_input = match self.entry_point_io[ep_index as usize].input.take() { Some(ep_input) => ep_input, None => return Ok(()), @@ -597,8 +684,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { match module.types[arg.ty].inner { TypeInner::Array { base, size, .. } => { self.write_array_size(module, base, size)?; - let fake_member = fake_iter.next().unwrap(); - writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + write!(self.out, " = ")?; + self.write_ep_argument_initialization( + ep, + &ep_input, + fake_iter.next().unwrap(), + )?; + writeln!(self.out, ";")?; } TypeInner::Struct { ref members, .. } => { write!(self.out, " = {{ ")?; @@ -606,14 +698,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if index != 0 { write!(self.out, ", ")?; } - let fake_member = fake_iter.next().unwrap(); - write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?; + self.write_ep_argument_initialization( + ep, + &ep_input, + fake_iter.next().unwrap(), + )?; } writeln!(self.out, " }};")?; } _ => { - let fake_member = fake_iter.next().unwrap(); - writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + write!(self.out, " = ")?; + self.write_ep_argument_initialization( + ep, + &ep_input, + fake_iter.next().unwrap(), + )?; + writeln!(self.out, ";")?; } } } @@ -928,9 +1028,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } - if let Some(ref binding) = member.binding { - self.write_semantic(binding, shader_stage)?; - }; + self.write_semantic(&member.binding, shader_stage)?; writeln!(self.out, ";")?; } @@ -1143,7 +1241,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } back::FunctionType::EntryPoint(ep_index) => { if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input { - write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; + write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?; } else { let stage = module.entry_points[ep_index as usize].stage; for (index, arg) in func.arguments.iter().enumerate() { @@ -1160,17 +1258,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_array_size(module, base, size)?; } - if let Some(ref binding) = arg.binding { - self.write_semantic(binding, Some((stage, Io::Input)))?; - } + self.write_semantic(&arg.binding, Some((stage, Io::Input)))?; } - - if need_workgroup_variables_initialization { - if !func.arguments.is_empty() { - write!(self.out, ", ")?; - } - write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; + } + if need_workgroup_variables_initialization { + if self.entry_point_io[ep_index as usize].input.is_some() + || !func.arguments.is_empty() + { + write!(self.out, ", ")?; } + write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; } } } @@ -1180,11 +1277,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // Write semantic if it present if let back::FunctionType::EntryPoint(index) = func_ctx.ty { let stage = module.entry_points[index as usize].stage; - if let Some(crate::FunctionResult { - binding: Some(ref binding), - .. - }) = func.result - { + if let Some(crate::FunctionResult { ref binding, .. }) = func.result { self.write_semantic(binding, Some((stage, Io::Output)))?; } } @@ -1984,6 +2077,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "const uint4 {name} = ")?; + self.named_expressions.insert(result, name); + + write!(self.out, "WaveActiveBallot(")?; + match predicate { + Some(predicate) => self.write_expr(module, predicate, func_ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -1997,7 +2213,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_possibly_const_expression( module, expr, - &module.const_expressions, + &module.global_expressions, |writer, expr| writer.write_const_expression(module, expr), ) } @@ -2039,7 +2255,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_const_expression(module, constant.init)?; } } - Expression::ZeroValue(ty) => self.write_default_init(module, ty)?, + Expression::ZeroValue(ty) => { + self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?; + write!(self.out, "()")?; + } Expression::Compose { ty, ref components } => { match module.types[ty].inner { TypeInner::Struct { .. } | TypeInner::Array { .. } => { @@ -2140,6 +2359,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => return Err(Error::Override), // All of the multiplication can be expressed as `mul`, // except vector * vector, which needs to use the "*" operator. Expression::Binary { @@ -2588,7 +2808,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { true } None => { - if inner.scalar_width() == Some(64) { + if inner.scalar_width() == Some(8) { false } else { write!(self.out, "{}(", kind.to_hlsl_cast(),)?; @@ -3129,7 +3349,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } - | Expression::RayQueryProceedResult => {} + | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } => {} } if !closing_bracket.is_empty() { @@ -3179,7 +3401,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } /// Helper function that write default zero initialization - fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult { + pub(super) fn write_default_init( + &mut self, + module: &Module, + ty: Handle<crate::Type>, + ) -> BackendResult { write!(self.out, "(")?; self.write_type(module, ty)?; if let TypeInner::Array { base, size, .. } = module.types[ty].inner { @@ -3196,6 +3422,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } + if barrier.contains(crate::Barrier::SUB_GROUP) { + // Does not exist in DirectX + } Ok(()) } } diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs index c8f091decb..0c9c5e4761 100644 --- a/third_party/rust/naga/src/back/mod.rs +++ b/third_party/rust/naga/src/back/mod.rs @@ -16,6 +16,14 @@ pub mod spv; #[cfg(feature = "wgsl-out")] pub mod wgsl; +#[cfg(any( + feature = "hlsl-out", + feature = "msl-out", + feature = "spv-out", + feature = "glsl-out" +))] +pub mod pipeline_constants; + /// Names of vector components. pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; /// Indent for backends. @@ -26,6 +34,15 @@ pub const BAKE_PREFIX: &str = "_e"; /// Expressions that need baking. pub type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>; +/// Specifies the values of pipeline-overridable constants in the shader module. +/// +/// If an `@id` attribute was specified on the declaration, +/// the key must be the pipeline constant ID as a decimal ASCII number; if not, +/// the key must be the constant's identifier name. +/// +/// The value may represent any of WGSL's concrete scalar types. +pub type PipelineConstants = std::collections::HashMap<String, f64>; + /// Indentation level. #[derive(Clone, Copy)] pub struct Level(pub usize); diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs index 68e5b79906..8b03e20376 100644 --- a/third_party/rust/naga/src/back/msl/mod.rs +++ b/third_party/rust/naga/src/back/msl/mod.rs @@ -143,6 +143,8 @@ pub enum Error { UnsupportedArrayOfType(Handle<crate::Type>), #[error("ray tracing is not supported prior to MSL 2.3")] UnsupportedRayTracing, + #[error("overrides should not be present at this stage")] + Override, } #[derive(Clone, Debug, PartialEq, thiserror::Error)] @@ -221,7 +223,7 @@ impl Default for Options { } /// A subset of options that are meant to be changed per pipeline. -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { @@ -434,6 +436,11 @@ impl ResolvedBinding { Bi::WorkGroupId => "threadgroup_position_in_grid", Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", + // subgroup + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } @@ -536,3 +543,21 @@ fn test_error_size() { use std::mem::size_of; assert_eq!(size_of::<Error>(), 32); } + +impl crate::AtomicFunction { + fn to_msl(self) -> Result<&'static str, Error> { + Ok(match self { + Self::Add => "fetch_add", + Self::Subtract => "fetch_sub", + Self::And => "fetch_and", + Self::InclusiveOr => "fetch_or", + Self::ExclusiveOr => "fetch_xor", + Self::Min => "fetch_min", + Self::Max => "fetch_max", + Self::Exchange { compare: None } => "exchange", + Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented( + "atomic CompareExchange".to_string(), + ))?, + }) + } +} diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs index 5227d8e7db..e250d0b72c 100644 --- a/third_party/rust/naga/src/back/msl/writer.rs +++ b/third_party/rust/naga/src/back/msl/writer.rs @@ -1131,21 +1131,10 @@ impl<W: Write> Writer<W> { Ok(()) } - fn put_atomic_fetch( - &mut self, - pointer: Handle<crate::Expression>, - key: &str, - value: Handle<crate::Expression>, - context: &ExpressionContext, - ) -> BackendResult { - self.put_atomic_operation(pointer, "fetch_", key, value, context) - } - fn put_atomic_operation( &mut self, pointer: Handle<crate::Expression>, - key1: &str, - key2: &str, + key: &str, value: Handle<crate::Expression>, context: &ExpressionContext, ) -> BackendResult { @@ -1163,7 +1152,7 @@ impl<W: Write> Writer<W> { write!( self.out, - "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}" + "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}" )?; self.put_access_chain(pointer, policy, context)?; write!(self.out, ", ")?; @@ -1248,7 +1237,7 @@ impl<W: Write> Writer<W> { ) -> BackendResult { self.put_possibly_const_expression( expr_handle, - &module.const_expressions, + &module.global_expressions, module, mod_info, &(module, mod_info), @@ -1431,6 +1420,7 @@ impl<W: Write> Writer<W> { |writer, context, expr| writer.put_expression(expr, context, true), )?; } + crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. @@ -1944,7 +1934,7 @@ impl<W: Write> Writer<W> { // // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))) - let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; write!(self.out, "{NAMESPACE}::extract_bits(")?; self.put_expression(arg, context, true)?; @@ -1960,7 +1950,7 @@ impl<W: Write> Writer<W> { // // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))) - let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; write!(self.out, "{NAMESPACE}::insert_bits(")?; self.put_expression(arg, context, true)?; @@ -2041,6 +2031,8 @@ impl<W: Write> Writer<W> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -2994,43 +2986,8 @@ impl<W: Write> Writer<W> { let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); self.start_baking_expression(result, &context.expression, &res_name)?; self.named_expressions.insert(result, res_name); - match *fun { - crate::AtomicFunction::Add => { - self.put_atomic_fetch(pointer, "add", value, &context.expression)?; - } - crate::AtomicFunction::Subtract => { - self.put_atomic_fetch(pointer, "sub", value, &context.expression)?; - } - crate::AtomicFunction::And => { - self.put_atomic_fetch(pointer, "and", value, &context.expression)?; - } - crate::AtomicFunction::InclusiveOr => { - self.put_atomic_fetch(pointer, "or", value, &context.expression)?; - } - crate::AtomicFunction::ExclusiveOr => { - self.put_atomic_fetch(pointer, "xor", value, &context.expression)?; - } - crate::AtomicFunction::Min => { - self.put_atomic_fetch(pointer, "min", value, &context.expression)?; - } - crate::AtomicFunction::Max => { - self.put_atomic_fetch(pointer, "max", value, &context.expression)?; - } - crate::AtomicFunction::Exchange { compare: None } => { - self.put_atomic_operation( - pointer, - "exchange", - "", - value, - &context.expression, - )?; - } - crate::AtomicFunction::Exchange { .. } => { - return Err(Error::FeatureNotImplemented( - "atomic CompareExchange".to_string(), - )); - } - } + let fun_str = fun.to_msl()?; + self.put_atomic_operation(pointer, fun_str, value, &context.expression)?; // done writeln!(self.out, ";")?; } @@ -3144,6 +3101,121 @@ impl<W: Write> Writer<W> { } } } + crate::Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; + } + writeln!(self.out, "), 0, 0, 0);")?; + } + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; + } } } @@ -3220,6 +3292,10 @@ impl<W: Write> Writer<W> { options: &Options, pipeline_options: &PipelineOptions, ) -> Result<TranslationInfo, Error> { + if !module.overrides.is_empty() { + return Err(Error::Override); + } + self.names.clear(); self.namer.reset( module, @@ -4487,6 +4563,12 @@ impl<W: Write> Writer<W> { "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; + } Ok(()) } } @@ -4757,8 +4839,8 @@ fn test_stack_size() { } let stack_size = addresses_end - addresses_start; // check the size (in debug only) - // last observed macOS value: 19152 (CI) - if !(9000..=20000).contains(&stack_size) { + // last observed macOS value: 22256 (CI) + if !(15000..=25000).contains(&stack_size) { panic!("`put_block` stack size {stack_size} has changed!"); } } diff --git a/third_party/rust/naga/src/back/pipeline_constants.rs b/third_party/rust/naga/src/back/pipeline_constants.rs new file mode 100644 index 0000000000..0dbe9cf4e8 --- /dev/null +++ b/third_party/rust/naga/src/back/pipeline_constants.rs @@ -0,0 +1,957 @@ +use super::PipelineConstants; +use crate::{ + proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, + valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, + Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, + Span, Statement, TypeInner, WithSpan, +}; +use std::{borrow::Cow, collections::HashSet, mem}; +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum PipelineConstantError { + #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")] + MissingValue(String), + #[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")] + SrcNeedsToBeFinite, + #[error("Source f64 value doesn't fit in destination")] + DstRangeTooSmall, + #[error(transparent)] + ConstantEvaluatorError(#[from] ConstantEvaluatorError), + #[error(transparent)] + ValidationError(#[from] WithSpan<ValidationError>), +} + +/// Replace all overrides in `module` with constants. +/// +/// If no changes are needed, this just returns `Cow::Borrowed` +/// references to `module` and `module_info`. Otherwise, it clones +/// `module`, edits its [`global_expressions`] arena to contain only +/// fully-evaluated expressions, and returns `Cow::Owned` values +/// holding the simplified module and its validation results. +/// +/// In either case, the module returned has an empty `overrides` +/// arena, and the `global_expressions` arena contains only +/// fully-evaluated expressions. +/// +/// [`global_expressions`]: Module::global_expressions +pub fn process_overrides<'a>( + module: &'a Module, + module_info: &'a ModuleInfo, + pipeline_constants: &PipelineConstants, +) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> { + if module.overrides.is_empty() { + return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info))); + } + + let mut module = module.clone(); + + // A map from override handles to the handles of the constants + // we've replaced them with. + let mut override_map = Vec::with_capacity(module.overrides.len()); + + // A map from `module`'s original global expression handles to + // handles in the new, simplified global expression arena. + let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len()); + + // The set of constants whose initializer handles we've already + // updated to refer to the newly built global expression arena. + // + // All constants in `module` must have their `init` handles + // updated to point into the new, simplified global expression + // arena. Some of these we can most easily handle as a side effect + // during the simplification process, but we must handle the rest + // in a final fixup pass, guided by `adjusted_global_expressions`. We + // add their handles to this set, so that the final fixup step can + // leave them alone. + let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); + + let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + + // An iterator through the original overrides table, consumed in + // approximate tandem with the global expressions. + let mut override_iter = module.overrides.drain(); + + // Do two things in tandem: + // + // - Rebuild the global expression arena from scratch, fully + // evaluating all expressions, and replacing each `Override` + // expression in `module.global_expressions` with a `Constant` + // expression. + // + // - Build a new `Constant` in `module.constants` to take the + // place of each `Override`. + // + // Build a map from old global expression handles to their + // fully-evaluated counterparts in `adjusted_global_expressions` as we + // go. + // + // Why in tandem? Overrides refer to expressions, and expressions + // refer to overrides, so we can't disentangle the two into + // separate phases. However, we can take advantage of the fact + // that the overrides and expressions must form a DAG, and work + // our way from the leaves to the roots, replacing and evaluating + // as we go. + // + // Although the two loops are nested, this is really two + // alternating phases: we adjust and evaluate constant expressions + // until we hit an `Override` expression, at which point we switch + // to building `Constant`s for `Overrides` until we've handled the + // one used by the expression. Then we switch back to processing + // expressions. Because we know they form a DAG, we know the + // `Override` expressions we encounter can only have initializers + // referring to global expressions we've already simplified. + for (old_h, expr, span) in module.global_expressions.drain() { + let mut expr = match expr { + Expression::Override(h) => { + let c_h = if let Some(new_h) = override_map.get(h.index()) { + *new_h + } else { + let mut new_h = None; + for entry in override_iter.by_ref() { + let stop = entry.0 == h; + new_h = Some(process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_global_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?); + if stop { + break; + } + } + new_h.unwrap() + }; + Expression::Constant(c_h) + } + Expression::Constant(c_h) => { + if adjusted_constant_initializers.insert(c_h) { + let init = &mut module.constants[c_h].init; + *init = adjusted_global_expressions[init.index()]; + } + expr + } + expr => expr, + }; + let mut evaluator = ConstantEvaluator::for_wgsl_module( + &mut module, + &mut global_expression_kind_tracker, + false, + ); + adjust_expr(&adjusted_global_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_global_expressions.len()); + adjusted_global_expressions.push(h); + } + + // Finish processing any overrides we didn't visit in the loop above. + for entry in override_iter { + process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_global_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?; + } + + // Update the initialization expression handles of all `Constant`s + // and `GlobalVariable`s. Skip `Constant`s we'd already updated en + // passant. + for (_, c) in module + .constants + .iter_mut() + .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h)) + { + c.init = adjusted_global_expressions[c.init.index()]; + } + + for (_, v) in module.global_variables.iter_mut() { + if let Some(ref mut init) = v.init { + *init = adjusted_global_expressions[init.index()]; + } + } + + let mut functions = mem::take(&mut module.functions); + for (_, function) in functions.iter_mut() { + process_function(&mut module, &override_map, function)?; + } + module.functions = functions; + + let mut entry_points = mem::take(&mut module.entry_points); + for ep in entry_points.iter_mut() { + process_function(&mut module, &override_map, &mut ep.function)?; + } + module.entry_points = entry_points; + + // Now that we've rewritten all the expressions, we need to + // recompute their types and other metadata. For the time being, + // do a full re-validation. + let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); + let module_info = validator.validate_no_overrides(&module)?; + + Ok((Cow::Owned(module), Cow::Owned(module_info))) +} + +/// Add a [`Constant`] to `module` for the override `old_h`. +/// +/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. +fn process_override( + (old_h, override_, span): (Handle<Override>, Override, Span), + pipeline_constants: &PipelineConstants, + module: &mut Module, + override_map: &mut Vec<Handle<Constant>>, + adjusted_global_expressions: &[Handle<Expression>], + adjusted_constant_initializers: &mut HashSet<Handle<Constant>>, + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, +) -> Result<Handle<Constant>, PipelineConstantError> { + // Determine which key to use for `override_` in `pipeline_constants`. + let key = if let Some(id) = override_.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = override_.name { + Cow::Borrowed(name) + } else { + unreachable!(); + }; + + // Generate a global expression for `override_`'s value, either + // from the provided `pipeline_constants` table or its initializer + // in the module. + let init = if let Some(value) = pipeline_constants.get::<str>(&key) { + let literal = match module.types[override_.ty].inner { + TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, + _ => unreachable!(), + }; + let expr = module + .global_expressions + .append(Expression::Literal(literal), Span::UNDEFINED); + global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); + expr + } else if let Some(init) = override_.init { + adjusted_global_expressions[init.index()] + } else { + return Err(PipelineConstantError::MissingValue(key.to_string())); + }; + + // Generate a new `Constant` to represent the override's value. + let constant = Constant { + name: override_.name, + ty: override_.ty, + init, + }; + let h = module.constants.append(constant, span); + debug_assert_eq!(old_h.index(), override_map.len()); + override_map.push(h); + adjusted_constant_initializers.insert(h); + Ok(h) +} + +/// Replace all override expressions in `function` with fully-evaluated constants. +/// +/// Replace all `Expression::Override`s in `function`'s expression arena with +/// the corresponding `Expression::Constant`s, as given in `override_map`. +/// Replace any expressions whose values are now known with their fully +/// evaluated form. +/// +/// If `h` is a `Handle<Override>`, then `override_map[h.index()]` is the +/// `Handle<Constant>` for the override's final value. +fn process_function( + module: &mut Module, + override_map: &[Handle<Constant>], + function: &mut Function, +) -> Result<(), ConstantEvaluatorError> { + // A map from original local expression handles to + // handles in the new, local expression arena. + let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len()); + + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + + let mut expressions = mem::take(&mut function.expressions); + + // Dummy `emitter` and `block` for the constant evaluator. + // We can ignore the concept of emitting expressions here since + // expressions have already been covered by a `Statement::Emit` + // in the frontend. + // The only thing we might have to do is remove some expressions + // that have been covered by a `Statement::Emit`. See the docs of + // `filter_emits_in_block` for the reasoning. + let mut emitter = Emitter::default(); + let mut block = Block::new(); + + let mut evaluator = ConstantEvaluator::for_wgsl_function( + module, + &mut function.expressions, + &mut local_expression_kind_tracker, + &mut emitter, + &mut block, + ); + + for (old_h, mut expr, span) in expressions.drain() { + if let Expression::Override(h) = expr { + expr = Expression::Constant(override_map[h.index()]); + } + adjust_expr(&adjusted_local_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_local_expressions.len()); + adjusted_local_expressions.push(h); + } + + adjust_block(&adjusted_local_expressions, &mut function.body); + + filter_emits_in_block(&mut function.body, &function.expressions); + + // Update local expression initializers. + for (_, local) in function.local_variables.iter_mut() { + if let &mut Some(ref mut init) = &mut local.init { + *init = adjusted_local_expressions[init.index()]; + } + } + + // We've changed the keys of `function.named_expression`, so we have to + // rebuild it from scratch. + let named_expressions = mem::take(&mut function.named_expressions); + for (expr_h, name) in named_expressions { + function + .named_expressions + .insert(adjusted_local_expressions[expr_h.index()], name); + } + + Ok(()) +} + +/// Replace every expression handle in `expr` with its counterpart +/// given by `new_pos`. +fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) { + let adjust = |expr: &mut Handle<Expression>| { + *expr = new_pos[expr.index()]; + }; + match *expr { + Expression::Compose { + ref mut components, + ty: _, + } => { + for c in components.iter_mut() { + adjust(c); + } + } + Expression::Access { + ref mut base, + ref mut index, + } => { + adjust(base); + adjust(index); + } + Expression::AccessIndex { + ref mut base, + index: _, + } => { + adjust(base); + } + Expression::Splat { + ref mut value, + size: _, + } => { + adjust(value); + } + Expression::Swizzle { + ref mut vector, + size: _, + pattern: _, + } => { + adjust(vector); + } + Expression::Load { ref mut pointer } => { + adjust(pointer); + } + Expression::ImageSample { + ref mut image, + ref mut sampler, + ref mut coordinate, + ref mut array_index, + ref mut offset, + ref mut level, + ref mut depth_ref, + gather: _, + } => { + adjust(image); + adjust(sampler); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = offset.as_mut() { + adjust(e); + } + match *level { + crate::SampleLevel::Exact(ref mut expr) + | crate::SampleLevel::Bias(ref mut expr) => { + adjust(expr); + } + crate::SampleLevel::Gradient { + ref mut x, + ref mut y, + } => { + adjust(x); + adjust(y); + } + _ => {} + } + if let Some(e) = depth_ref.as_mut() { + adjust(e); + } + } + Expression::ImageLoad { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut sample, + ref mut level, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = sample.as_mut() { + adjust(e); + } + if let Some(e) = level.as_mut() { + adjust(e); + } + } + Expression::ImageQuery { + ref mut image, + ref mut query, + } => { + adjust(image); + match *query { + crate::ImageQuery::Size { ref mut level } => { + if let Some(e) = level.as_mut() { + adjust(e); + } + } + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => {} + } + } + Expression::Unary { + ref mut expr, + op: _, + } => { + adjust(expr); + } + Expression::Binary { + ref mut left, + ref mut right, + op: _, + } => { + adjust(left); + adjust(right); + } + Expression::Select { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust(accept); + adjust(reject); + } + Expression::Derivative { + ref mut expr, + axis: _, + ctrl: _, + } => { + adjust(expr); + } + Expression::Relational { + ref mut argument, + fun: _, + } => { + adjust(argument); + } + Expression::Math { + ref mut arg, + ref mut arg1, + ref mut arg2, + ref mut arg3, + fun: _, + } => { + adjust(arg); + if let Some(e) = arg1.as_mut() { + adjust(e); + } + if let Some(e) = arg2.as_mut() { + adjust(e); + } + if let Some(e) = arg3.as_mut() { + adjust(e); + } + } + Expression::As { + ref mut expr, + kind: _, + convert: _, + } => { + adjust(expr); + } + Expression::ArrayLength(ref mut expr) => { + adjust(expr); + } + Expression::RayQueryGetIntersection { + ref mut query, + committed: _, + } => { + adjust(query); + } + Expression::Literal(_) + | Expression::FunctionArgument(_) + | Expression::GlobalVariable(_) + | Expression::LocalVariable(_) + | Expression::CallResult(_) + | Expression::RayQueryProceedResult + | Expression::Constant(_) + | Expression::Override(_) + | Expression::ZeroValue(_) + | Expression::AtomicResult { + ty: _, + comparison: _, + } + | Expression::WorkGroupUniformLoadResult { ty: _ } + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } => {} + } +} + +/// Replace every expression handle in `block` with its counterpart +/// given by `new_pos`. +fn adjust_block(new_pos: &[Handle<Expression>], block: &mut Block) { + for stmt in block.iter_mut() { + adjust_stmt(new_pos, stmt); + } +} + +/// Replace every expression handle in `stmt` with its counterpart +/// given by `new_pos`. +fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) { + let adjust = |expr: &mut Handle<Expression>| { + *expr = new_pos[expr.index()]; + }; + match *stmt { + Statement::Emit(ref mut range) => { + if let Some((mut first, mut last)) = range.first_and_last() { + adjust(&mut first); + adjust(&mut last); + *range = Range::new_from_bounds(first, last); + } + } + Statement::Block(ref mut block) => { + adjust_block(new_pos, block); + } + Statement::If { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust_block(new_pos, accept); + adjust_block(new_pos, reject); + } + Statement::Switch { + ref mut selector, + ref mut cases, + } => { + adjust(selector); + for case in cases.iter_mut() { + adjust_block(new_pos, &mut case.body); + } + } + Statement::Loop { + ref mut body, + ref mut continuing, + ref mut break_if, + } => { + adjust_block(new_pos, body); + adjust_block(new_pos, continuing); + if let Some(e) = break_if.as_mut() { + adjust(e); + } + } + Statement::Return { ref mut value } => { + if let Some(e) = value.as_mut() { + adjust(e); + } + } + Statement::Store { + ref mut pointer, + ref mut value, + } => { + adjust(pointer); + adjust(value); + } + Statement::ImageStore { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut value, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + adjust(value); + } + crate::Statement::Atomic { + ref mut pointer, + ref mut value, + ref mut result, + ref mut fun, + } => { + adjust(pointer); + adjust(value); + adjust(result); + match *fun { + crate::AtomicFunction::Exchange { + compare: Some(ref mut compare), + } => { + adjust(compare); + } + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max + | crate::AtomicFunction::Exchange { compare: None } => {} + } + } + Statement::WorkGroupUniformLoad { + ref mut pointer, + ref mut result, + } => { + adjust(pointer); + adjust(result); + } + Statement::SubgroupBallot { + ref mut result, + ref mut predicate, + } => { + if let Some(ref mut predicate) = *predicate { + adjust(predicate); + } + adjust(result); + } + Statement::SubgroupCollectiveOperation { + ref mut argument, + ref mut result, + .. + } => { + adjust(argument); + adjust(result); + } + Statement::SubgroupGather { + ref mut mode, + ref mut argument, + ref mut result, + } => { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) + | crate::GatherMode::Shuffle(ref mut index) + | crate::GatherMode::ShuffleDown(ref mut index) + | crate::GatherMode::ShuffleUp(ref mut index) + | crate::GatherMode::ShuffleXor(ref mut index) => { + adjust(index); + } + } + adjust(argument); + adjust(result) + } + Statement::Call { + ref mut arguments, + ref mut result, + function: _, + } => { + for argument in arguments.iter_mut() { + adjust(argument); + } + if let Some(e) = result.as_mut() { + adjust(e); + } + } + Statement::RayQuery { + ref mut query, + ref mut fun, + } => { + adjust(query); + match *fun { + crate::RayQueryFunction::Initialize { + ref mut acceleration_structure, + ref mut descriptor, + } => { + adjust(acceleration_structure); + adjust(descriptor); + } + crate::RayQueryFunction::Proceed { ref mut result } => { + adjust(result); + } + crate::RayQueryFunction::Terminate => {} + } + } + Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} + } +} + +/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced. +/// +/// According to validation, [`Emit`] statements must not cover any expressions +/// for which [`Expression::needs_pre_emit`] returns true. All expressions built +/// by successful constant evaluation fall into that category, meaning that +/// `process_function` will usually rewrite [`Override`] expressions and those +/// that use their values into pre-emitted expressions, leaving any [`Emit`] +/// statements that cover them invalid. +/// +/// This function rewrites all [`Emit`] statements into zero or more new +/// [`Emit`] statements covering only those expressions in the original range +/// that are not pre-emitted. +/// +/// [`Emit`]: Statement::Emit +/// [`needs_pre_emit`]: Expression::needs_pre_emit +/// [`Override`]: Expression::Override +fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) { + let original = std::mem::replace(block, Block::with_capacity(block.len())); + for (stmt, span) in original.span_into_iter() { + match stmt { + Statement::Emit(range) => { + let mut current = None; + for expr_h in range { + if expressions[expr_h].needs_pre_emit() { + if let Some((first, last)) = current { + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); + } + + current = None; + } else if let Some((_, ref mut last)) = current { + *last = expr_h; + } else { + current = Some((expr_h, expr_h)); + } + } + if let Some((first, last)) = current { + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); + } + } + Statement::Block(mut child) => { + filter_emits_in_block(&mut child, expressions); + block.push(Statement::Block(child), span); + } + Statement::If { + condition, + mut accept, + mut reject, + } => { + filter_emits_in_block(&mut accept, expressions); + filter_emits_in_block(&mut reject, expressions); + block.push( + Statement::If { + condition, + accept, + reject, + }, + span, + ); + } + Statement::Switch { + selector, + mut cases, + } => { + for case in &mut cases { + filter_emits_in_block(&mut case.body, expressions); + } + block.push(Statement::Switch { selector, cases }, span); + } + Statement::Loop { + mut body, + mut continuing, + break_if, + } => { + filter_emits_in_block(&mut body, expressions); + filter_emits_in_block(&mut continuing, expressions); + block.push( + Statement::Loop { + body, + continuing, + break_if, + }, + span, + ); + } + stmt => block.push(stmt.clone(), span), + } + } +} + +fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> { + // note that in rust 0.0 == -0.0 + match scalar { + Scalar::BOOL => { + // https://webidl.spec.whatwg.org/#js-boolean + let value = value != 0.0 && !value.is_nan(); + Ok(Literal::Bool(value)) + } + Scalar::I32 => { + // https://webidl.spec.whatwg.org/#js-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as i32; + Ok(Literal::I32(value)) + } + Scalar::U32 => { + // https://webidl.spec.whatwg.org/#js-unsigned-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as u32; + Ok(Literal::U32(value)) + } + Scalar::F32 => { + // https://webidl.spec.whatwg.org/#js-float + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value as f32; + if !value.is_finite() { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + Ok(Literal::F32(value)) + } + Scalar::F64 => { + // https://webidl.spec.whatwg.org/#js-double + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + Ok(Literal::F64(value)) + } + _ => unreachable!(), + } +} + +#[test] +fn test_map_value_to_literal() { + let bool_test_cases = [ + (0.0, false), + (-0.0, false), + (f64::NAN, false), + (1.0, true), + (f64::INFINITY, true), + (f64::NEG_INFINITY, true), + ]; + for (value, out) in bool_test_cases { + let res = Ok(Literal::Bool(out)); + assert_eq!(map_value_to_literal(value, Scalar::BOOL), res); + } + + for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] { + for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + let res = Err(PipelineConstantError::SrcNeedsToBeFinite); + assert_eq!(map_value_to_literal(value, scalar), res); + } + } + + // i32 + assert_eq!( + map_value_to_literal(f64::from(i32::MIN), Scalar::I32), + Ok(Literal::I32(i32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX), Scalar::I32), + Ok(Literal::I32(i32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // u32 + assert_eq!( + map_value_to_literal(f64::from(u32::MIN), Scalar::U32), + Ok(Literal::U32(u32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX), Scalar::U32), + Ok(Literal::U32(u32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f32 + assert_eq!( + map_value_to_literal(f64::from(f32::MIN), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(f32::MAX), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f64 + assert_eq!( + map_value_to_literal(f64::MIN, Scalar::F64), + Ok(Literal::F64(f64::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::MAX, Scalar::F64), + Ok(Literal::F64(f64::MAX)) + ); +} diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs index 81f2fc10e0..120d60fc40 100644 --- a/third_party/rust/naga/src/back/spv/block.rs +++ b/third_party/rust/naga/src/back/spv/block.rs @@ -239,6 +239,7 @@ impl<'w> BlockContext<'w> { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init.index()] } + crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); @@ -1072,7 +1073,7 @@ impl<'w> BlockContext<'w> { // // bitfieldExtract(x, o, c) - let bit_width = arg_ty.scalar_width().unwrap(); + let bit_width = arg_ty.scalar_width().unwrap() * 8; let width_constant = self .writer .get_constant_scalar(crate::Literal::U32(bit_width as u32)); @@ -1128,7 +1129,7 @@ impl<'w> BlockContext<'w> { Mf::InsertBits => { // The behavior of InsertBits has the same undefined behavior as ExtractBits. - let bit_width = arg_ty.scalar_width().unwrap(); + let bit_width = arg_ty.scalar_width().unwrap() * 8; let width_constant = self .writer .get_constant_scalar(crate::Literal::U32(bit_width as u32)); @@ -1184,7 +1185,7 @@ impl<'w> BlockContext<'w> { } Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb), Mf::FindMsb => { - if arg_ty.scalar_width() == Some(32) { + if arg_ty.scalar_width() == Some(4) { let thing = match arg_scalar_kind { Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, @@ -1278,7 +1279,9 @@ impl<'w> BlockContext<'w> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } - | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2489,6 +2492,27 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; + } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; + } + crate::Statement::SubgroupGather { + ref mode, + argument, + result, + } => { + self.write_subgroup_gather(mode, argument, result, &mut block)?; + } } } diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs index 5b6226db85..1fb447e384 100644 --- a/third_party/rust/naga/src/back/spv/helpers.rs +++ b/third_party/rust/naga/src/back/spv/helpers.rs @@ -10,8 +10,12 @@ pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> { pub(super) fn string_to_words(input: &str) -> Vec<Word> { let bytes = input.as_bytes(); - let mut words = bytes_to_words(bytes); + str_bytes_to_words(bytes) +} + +pub(super) fn str_bytes_to_words(bytes: &[u8]) -> Vec<Word> { + let mut words = bytes_to_words(bytes); if bytes.len() % 4 == 0 { // nul-termination words.push(0x0u32); @@ -20,6 +24,21 @@ pub(super) fn string_to_words(input: &str) -> Vec<Word> { words } +/// split a string into chunks and keep utf8 valid +#[allow(unstable_name_collisions)] +pub(super) fn string_to_byte_chunks(input: &str, limit: usize) -> Vec<&[u8]> { + let mut offset: usize = 0; + let mut start: usize = 0; + let mut words = vec![]; + while offset < input.len() { + offset = input.floor_char_boundary(offset + limit); + words.push(input[start..offset].as_bytes()); + start = offset; + } + + words +} + pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass { match space { crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant, @@ -107,3 +126,35 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab _ => true, } } + +///HACK: this is taken from std unstable, remove it when std's floor_char_boundary is stable +trait U8Internal { + fn is_utf8_char_boundary(&self) -> bool; +} + +impl U8Internal for u8 { + fn is_utf8_char_boundary(&self) -> bool { + // This is bit magic equivalent to: b < 128 || b >= 192 + (*self as i8) >= -0x40 + } +} + +trait StrUnstable { + fn floor_char_boundary(&self, index: usize) -> usize; +} + +impl StrUnstable for str { + fn floor_char_boundary(&self, index: usize) -> usize { + if index >= self.len() { + self.len() + } else { + let lower_bound = index.saturating_sub(3); + let new_index = self.as_bytes()[lower_bound..=index] + .iter() + .rposition(|b| b.is_utf8_char_boundary()); + + // SAFETY: we know that the character boundary will be within four bytes + unsafe { lower_bound + new_index.unwrap_unchecked() } + } + } +} diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs index b963793ad3..df2774ab9c 100644 --- a/third_party/rust/naga/src/back/spv/instructions.rs +++ b/third_party/rust/naga/src/back/spv/instructions.rs @@ -43,6 +43,42 @@ impl super::Instruction { instruction } + pub(super) fn source_continued(source: &[u8]) -> Self { + let mut instruction = Self::new(Op::SourceContinued); + instruction.add_operands(helpers::str_bytes_to_words(source)); + instruction + } + + pub(super) fn source_auto_continued( + source_language: spirv::SourceLanguage, + version: u32, + source: &Option<DebugInfoInner>, + ) -> Vec<Self> { + let mut instructions = vec![]; + + let with_continue = source.as_ref().and_then(|debug_info| { + (debug_info.source_code.len() > u16::MAX as usize).then_some(debug_info) + }); + if let Some(debug_info) = with_continue { + let mut instruction = Self::new(Op::Source); + instruction.add_operand(source_language as u32); + instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes())); + + let words = helpers::string_to_byte_chunks(debug_info.source_code, u16::MAX as usize); + instruction.add_operand(debug_info.source_file_id); + instruction.add_operands(helpers::str_bytes_to_words(words[0])); + instructions.push(instruction); + for word_bytes in words[1..].iter() { + let instruction_continue = Self::source_continued(word_bytes); + instructions.push(instruction_continue); + } + } else { + let instruction = Self::source(source_language, version, source); + instructions.push(instruction); + } + instructions + } + pub(super) fn name(target_id: Word, name: &str) -> Self { let mut instruction = Self::new(Op::Name); instruction.add_operand(target_id); @@ -1037,6 +1073,73 @@ impl super::Instruction { instruction.add_operand(semantics_id); instruction } + + // Group Instructions + + pub(super) fn group_non_uniform_ballot( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + predicate: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBallot); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(predicate); + + instruction + } + pub(super) fn group_non_uniform_broadcast_first( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + + instruction + } + pub(super) fn group_non_uniform_gather( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_arithmetic( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + group_op: Option<spirv::GroupOperation>, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + if let Some(group_op) = group_op { + instruction.add_operand(group_op as u32); + } + instruction.add_operand(value); + + instruction + } } impl From<crate::StorageFormat> for spirv::ImageFormat { diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs index eb29e3cd8b..38a87049e6 100644 --- a/third_party/rust/naga/src/back/spv/mod.rs +++ b/third_party/rust/naga/src/back/spv/mod.rs @@ -13,6 +13,7 @@ mod layout; mod ray; mod recyclable; mod selection; +mod subgroup; mod writer; pub use spirv::Capability; @@ -70,6 +71,8 @@ pub enum Error { FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), + #[error("overrides should not be present at this stage")] + Override, } #[derive(Default)] @@ -245,7 +248,7 @@ impl LocalImageType { /// this, by converting everything possible to a `LocalType` before inspecting /// it. /// -/// ## `Localtype` equality and SPIR-V `OpType` uniqueness +/// ## `LocalType` equality and SPIR-V `OpType` uniqueness /// /// The definition of `Eq` on `LocalType` is carefully chosen to help us follow /// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...` @@ -454,7 +457,7 @@ impl recyclable::Recyclable for CachedExpressions { #[derive(Eq, Hash, PartialEq)] enum CachedConstant { - Literal(crate::Literal), + Literal(crate::proc::HashableLiteral), Composite { ty: LookupType, constituent_ids: Vec<Word>, @@ -527,6 +530,42 @@ struct FunctionArgument { handle_id: Word, } +/// Tracks the expressions for which the backend emits the following instructions: +/// - OpConstantTrue +/// - OpConstantFalse +/// - OpConstant +/// - OpConstantComposite +/// - OpConstantNull +struct ExpressionConstnessTracker { + inner: bit_set::BitSet, +} + +impl ExpressionConstnessTracker { + fn from_arena(arena: &crate::Arena<crate::Expression>) -> Self { + let mut inner = bit_set::BitSet::new(); + for (handle, expr) in arena.iter() { + let insert = match *expr { + crate::Expression::Literal(_) + | crate::Expression::ZeroValue(_) + | crate::Expression::Constant(_) => true, + crate::Expression::Compose { ref components, .. } => { + components.iter().all(|h| inner.contains(h.index())) + } + crate::Expression::Splat { value, .. } => inner.contains(value.index()), + _ => false, + }; + if insert { + inner.insert(handle.index()); + } + } + Self { inner } + } + + fn is_const(&self, value: Handle<crate::Expression>) -> bool { + self.inner.contains(value.index()) + } +} + /// General information needed to emit SPIR-V for Naga statements. struct BlockContext<'w> { /// The writer handling the module to which this code belongs. @@ -552,7 +591,7 @@ struct BlockContext<'w> { temp_list: Vec<Word>, /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions` - expression_constness: crate::proc::ExpressionConstnessTracker, + expression_constness: ExpressionConstnessTracker, } impl BlockContext<'_> { @@ -725,7 +764,7 @@ impl<'a> Default for Options<'a> { } // A subset of options meant to be changed per pipeline. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { diff --git a/third_party/rust/naga/src/back/spv/subgroup.rs b/third_party/rust/naga/src/back/spv/subgroup.rs new file mode 100644 index 0000000000..c952cb11a7 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/subgroup.rs @@ -0,0 +1,207 @@ +use super::{Block, BlockContext, Error, Instruction}; +use crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; + +impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option<Handle<crate::Expression>>, + result: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + use crate::SubgroupOperation as sg; + match *op { + sg::All | sg::Any => { + self.writer.require_any( + "GroupNonUniformVote", + &[spirv::Capability::GroupNonUniformVote], + )?; + } + _ => { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[spirv::Capability::GroupNonUniformArithmetic], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + let result_ty_inner = result_ty.inner_with(&self.ir_module.types); + + let (is_scalar, scalar) = match *result_ty_inner { + TypeInner::Scalar(kind) => (true, kind), + TypeInner::Vector { scalar: kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + let spirv_op = match (scalar.kind, *op) { + (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, + (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, + (_, sg::All | sg::Any) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd, + (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd, + (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul, + (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul, + (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax, + (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax, + (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax, + (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin, + (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin, + (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin, + (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd, + (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr, + (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor, + (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd, + (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr, + (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor, + (_, sg::And | sg::Or | sg::Xor) => unimplemented!(), + }; + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + use crate::CollectiveOperation as c; + let group_op = match *op { + sg::All | sg::Any => None, + _ => Some(match *collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }), + }; + + let arg_id = self.cached[argument]; + block.body.push(Instruction::group_non_uniform_arithmetic( + spirv_op, + result_type_id, + id, + exec_scope_id, + group_op, + arg_id, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_gather( + &mut self, + mode: &crate::GatherMode, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + match *mode { + crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + } + crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => { + self.writer.require_any( + "GroupNonUniformShuffle", + &[spirv::Capability::GroupNonUniformShuffle], + )?; + } + crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { + self.writer.require_any( + "GroupNonUniformShuffleRelative", + &[spirv::Capability::GroupNonUniformShuffleRelative], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + match *mode { + crate::GatherMode::BroadcastFirst => { + block + .body + .push(Instruction::group_non_uniform_broadcast_first( + result_type_id, + id, + exec_scope_id, + arg_id, + )); + } + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + // Use shuffle to emit broadcast to allow the index to + // be dynamically uniform on Vulkan 1.1. The argument to + // OpGroupNonUniformBroadcast must be a constant pre SPIR-V + // 1.5 (vulkan 1.2) + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } + } + self.cached[result] = id; + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs index a5065e0623..73a16c273e 100644 --- a/third_party/rust/naga/src/back/spv/writer.rs +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -615,7 +615,7 @@ impl Writer { // Steal the Writer's temp list for a bit. temp_list: std::mem::take(&mut self.temp_list), writer: self, - expression_constness: crate::proc::ExpressionConstnessTracker::from_arena( + expression_constness: super::ExpressionConstnessTracker::from_arena( &ir_function.expressions, ), }; @@ -970,6 +970,11 @@ impl Writer { handle: Handle<crate::Type>, ) -> Result<Word, Error> { let ty = &arena[handle]; + // If it's a type that needs SPIR-V capabilities, request them now. + // This needs to happen regardless of the LocalType lookup succeeding, + // because some types which map to the same LocalType have different + // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569 + self.request_type_capabilities(&ty.inner)?; let id = if let Some(local) = make_local(&ty.inner) { // This type can be represented as a `LocalType`, so check if we've // already written an instruction for it. If not, do so now, with @@ -985,10 +990,6 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's a type that needs SPIR-V capabilities, request them now, - // so write_type_declaration_local can stay infallible. - self.request_type_capabilities(&ty.inner)?; - id } } @@ -1150,7 +1151,7 @@ impl Writer { } pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word { - let scalar = CachedConstant::Literal(value); + let scalar = CachedConstant::Literal(value.into()); if let Some(&id) = self.cached_constants.get(&scalar) { return id; } @@ -1258,7 +1259,7 @@ impl Writer { ir_module: &crate::Module, mod_info: &ModuleInfo, ) -> Result<Word, Error> { - let id = match ir_module.const_expressions[handle] { + let id = match ir_module.global_expressions[handle] { crate::Expression::Literal(literal) => self.get_constant_scalar(literal), crate::Expression::Constant(constant) => { let constant = &ir_module.constants[constant]; @@ -1272,7 +1273,7 @@ impl Writer { let component_ids: Vec<_> = crate::proc::flatten_compose( ty, components, - &ir_module.const_expressions, + &ir_module.global_expressions, &ir_module.types, ) .map(|component| self.constant_ids[component.index()]) @@ -1310,7 +1311,11 @@ impl Writer { spirv::MemorySemantics::WORKGROUP_MEMORY, flags.contains(crate::Barrier::WORK_GROUP), ); - let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) { + self.get_index_constant(spirv::Scope::Subgroup as u32) + } else { + self.get_index_constant(spirv::Scope::Workgroup as u32) + }; let mem_scope_id = self.get_index_constant(memory_scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); block.body.push(Instruction::control_barrier( @@ -1585,6 +1590,41 @@ impl Writer { Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + // Subgroup + Bi::NumSubgroups => { + self.require_any( + "`num_subgroups` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId + } + Bi::SubgroupSize => { + self.require_any( + "`subgroup_size` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupSize + } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); @@ -1899,7 +1939,7 @@ impl Writer { source_code: debug_info.source_code, source_file_id, }); - self.debugs.push(Instruction::source( + self.debugs.append(&mut Instruction::source_auto_continued( spirv::SourceLanguage::Unknown, 0, &debug_info_inner, @@ -1914,8 +1954,8 @@ impl Writer { // write all const-expressions as constants self.constant_ids - .resize(ir_module.const_expressions.len(), 0); - for (handle, _) in ir_module.const_expressions.iter() { + .resize(ir_module.global_expressions.len(), 0); + for (handle, _) in ir_module.global_expressions.iter() { self.write_constant_expr(handle, ir_module, mod_info)?; } debug_assert!(self.constant_ids.iter().all(|&id| id != 0)); @@ -2029,6 +2069,10 @@ impl Writer { debug_info: &Option<DebugInfo>, words: &mut Vec<Word>, ) -> Result<(), Error> { + if !ir_module.overrides.is_empty() { + return Err(Error::Override); + } + self.reset(); // Try to find the entry point and corresponding index diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs index 3039cbbbe4..789f6f62bf 100644 --- a/third_party/rust/naga/src/back/wgsl/writer.rs +++ b/third_party/rust/naga/src/back/wgsl/writer.rs @@ -106,6 +106,12 @@ impl<W: Write> Writer<W> { } pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { + if !module.overrides.is_empty() { + return Err(Error::Unimplemented( + "Pipeline constants are not yet supported for this back-end".to_string(), + )); + } + self.reset(module); // Save all ep result types @@ -918,8 +924,124 @@ impl<W: Write> Writer<W> { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + write!(self.out, "subgroupBallot(")?; + if let Some(predicate) = predicate { + self.write_expr(module, predicate, func_ctx)?; + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -1070,7 +1192,7 @@ impl<W: Write> Writer<W> { self.write_possibly_const_expression( module, expr, - &module.const_expressions, + &module.global_expressions, |writer, expr| writer.write_const_expression(module, expr), ) } @@ -1199,6 +1321,7 @@ impl<W: Write> Writer<W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => unreachable!(), Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; @@ -1691,6 +1814,8 @@ impl<W: Write> Writer<W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} } @@ -1792,6 +1917,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", + Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/third_party/rust/naga/src/block.rs b/third_party/rust/naga/src/block.rs index 0abda9da7c..2e86a928f1 100644 --- a/third_party/rust/naga/src/block.rs +++ b/third_party/rust/naga/src/block.rs @@ -65,6 +65,12 @@ impl Block { self.span_info.splice(range.clone(), other.span_info); self.body.splice(range, other.body); } + + pub fn span_into_iter(self) -> impl Iterator<Item = (Statement, Span)> { + let Block { body, span_info } = self; + body.into_iter().zip(span_info) + } + pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> { let span_iter = self.span_info.iter(); self.body.iter().zip(span_iter) diff --git a/third_party/rust/naga/src/compact/expressions.rs b/third_party/rust/naga/src/compact/expressions.rs index 301bbe3240..a418bde301 100644 --- a/third_party/rust/naga/src/compact/expressions.rs +++ b/third_party/rust/naga/src/compact/expressions.rs @@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle}; pub struct ExpressionTracer<'tracer> { pub constants: &'tracer Arena<crate::Constant>, + pub overrides: &'tracer Arena<crate::Override>, /// The arena in which we are currently tracing expressions. pub expressions: &'tracer Arena<crate::Expression>, @@ -20,11 +21,11 @@ pub struct ExpressionTracer<'tracer> { /// the module's constant expression arena. pub expressions_used: &'tracer mut HandleSet<crate::Expression>, - /// The used set for the module's `const_expressions` arena. + /// The used set for the module's `global_expressions` arena. /// /// If `None`, we are already tracing the constant expressions, /// and `expressions_used` already refers to their handle set. - pub const_expressions_used: Option<&'tracer mut HandleSet<crate::Expression>>, + pub global_expressions_used: Option<&'tracer mut HandleSet<crate::Expression>>, } impl<'tracer> ExpressionTracer<'tracer> { @@ -39,11 +40,11 @@ impl<'tracer> ExpressionTracer<'tracer> { /// marked. /// /// [fe]: crate::Function::expressions - /// [ce]: crate::Module::const_expressions + /// [ce]: crate::Module::global_expressions pub fn trace_expressions(&mut self) { log::trace!( "entering trace_expression of {}", - if self.const_expressions_used.is_some() { + if self.global_expressions_used.is_some() { "function expressions" } else { "const expressions" @@ -71,6 +72,7 @@ impl<'tracer> ExpressionTracer<'tracer> { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} Ex::Constant(handle) => { @@ -83,11 +85,16 @@ impl<'tracer> ExpressionTracer<'tracer> { // and the constant refers to the initializer, it must // precede `expr` in the arena. let init = self.constants[handle].init; - match self.const_expressions_used { + match self.global_expressions_used { Some(ref mut used) => used.insert(init), None => self.expressions_used.insert(init), } } + Ex::Override(_) => { + // All overrides are considered used by definition. We mark + // their types and initialization expressions as used in + // `compact::compact`, so we have no more work to do here. + } Ex::ZeroValue(ty) => self.types_used.insert(ty), Ex::Compose { ty, ref components } => { self.types_used.insert(ty); @@ -116,7 +123,7 @@ impl<'tracer> ExpressionTracer<'tracer> { self.expressions_used .insert_iter([image, sampler, coordinate]); self.expressions_used.insert_iter(array_index); - match self.const_expressions_used { + match self.global_expressions_used { Some(ref mut used) => used.insert_iter(offset), None => self.expressions_used.insert_iter(offset), } @@ -186,6 +193,7 @@ impl<'tracer> ExpressionTracer<'tracer> { Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty), Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty), Ex::ArrayLength(expr) => self.expressions_used.insert(expr), + Ex::SubgroupOperationResult { ty } => self.types_used.insert(ty), Ex::RayQueryGetIntersection { query, committed: _, @@ -217,8 +225,12 @@ impl ModuleMap { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} + // All overrides are retained, so their handles never change. + Ex::Override(_) => {} + // Expressions that contain handles that need to be adjusted. Ex::Constant(ref mut constant) => self.constants.adjust(constant), Ex::ZeroValue(ref mut ty) => self.types.adjust(ty), @@ -267,7 +279,7 @@ impl ModuleMap { adjust(coordinate); operand_map.adjust_option(array_index); if let Some(ref mut offset) = *offset { - self.const_expressions.adjust(offset); + self.global_expressions.adjust(offset); } self.adjust_sample_level(level, operand_map); operand_map.adjust_option(depth_ref); @@ -344,6 +356,7 @@ impl ModuleMap { comparison: _, } => self.types.adjust(ty), Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty), + Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty), Ex::ArrayLength(ref mut expr) => adjust(expr), Ex::RayQueryGetIntersection { ref mut query, diff --git a/third_party/rust/naga/src/compact/functions.rs b/third_party/rust/naga/src/compact/functions.rs index b0d08c7e96..4ac2223eb7 100644 --- a/third_party/rust/naga/src/compact/functions.rs +++ b/third_party/rust/naga/src/compact/functions.rs @@ -4,10 +4,11 @@ use super::{FunctionMap, ModuleMap}; pub struct FunctionTracer<'a> { pub function: &'a crate::Function, pub constants: &'a crate::Arena<crate::Constant>, + pub overrides: &'a crate::Arena<crate::Override>, pub types_used: &'a mut HandleSet<crate::Type>, pub constants_used: &'a mut HandleSet<crate::Constant>, - pub const_expressions_used: &'a mut HandleSet<crate::Expression>, + pub global_expressions_used: &'a mut HandleSet<crate::Expression>, /// Function-local expressions used. pub expressions_used: HandleSet<crate::Expression>, @@ -47,12 +48,13 @@ impl<'a> FunctionTracer<'a> { fn as_expression(&mut self) -> super::expressions::ExpressionTracer { super::expressions::ExpressionTracer { constants: self.constants, + overrides: self.overrides, expressions: &self.function.expressions, types_used: self.types_used, constants_used: self.constants_used, expressions_used: &mut self.expressions_used, - const_expressions_used: Some(&mut self.const_expressions_used), + global_expressions_used: Some(&mut self.global_expressions_used), } } } diff --git a/third_party/rust/naga/src/compact/mod.rs b/third_party/rust/naga/src/compact/mod.rs index b4e57ed5c9..0d7a37b579 100644 --- a/third_party/rust/naga/src/compact/mod.rs +++ b/third_party/rust/naga/src/compact/mod.rs @@ -38,7 +38,7 @@ pub fn compact(module: &mut crate::Module) { log::trace!("tracing global {:?}", global.name); module_tracer.types_used.insert(global.ty); if let Some(init) = global.init { - module_tracer.const_expressions_used.insert(init); + module_tracer.global_expressions_used.insert(init); } } } @@ -50,7 +50,15 @@ pub fn compact(module: &mut crate::Module) { for (handle, constant) in module.constants.iter() { if constant.name.is_some() { module_tracer.constants_used.insert(handle); - module_tracer.const_expressions_used.insert(constant.init); + module_tracer.global_expressions_used.insert(constant.init); + } + } + + // We treat all overrides as used by definition. + for (_, override_) in module.overrides.iter() { + module_tracer.types_used.insert(override_.ty); + if let Some(init) = override_.init { + module_tracer.global_expressions_used.insert(init); } } @@ -137,9 +145,9 @@ pub fn compact(module: &mut crate::Module) { // Drop unused constant expressions, reusing existing storage. log::trace!("adjusting constant expressions"); - module.const_expressions.retain_mut(|handle, expr| { - if module_map.const_expressions.used(handle) { - module_map.adjust_expression(expr, &module_map.const_expressions); + module.global_expressions.retain_mut(|handle, expr| { + if module_map.global_expressions.used(handle) { + module_map.adjust_expression(expr, &module_map.global_expressions); true } else { false @@ -151,20 +159,29 @@ pub fn compact(module: &mut crate::Module) { module.constants.retain_mut(|handle, constant| { if module_map.constants.used(handle) { module_map.types.adjust(&mut constant.ty); - module_map.const_expressions.adjust(&mut constant.init); + module_map.global_expressions.adjust(&mut constant.init); true } else { false } }); + // Adjust override types and initializers. + log::trace!("adjusting overrides"); + for (_, override_) in module.overrides.iter_mut() { + module_map.types.adjust(&mut override_.ty); + if let Some(init) = override_.init.as_mut() { + module_map.global_expressions.adjust(init); + } + } + // Adjust global variables' types and initializers. log::trace!("adjusting global variables"); for (_, global) in module.global_variables.iter_mut() { log::trace!("adjusting global {:?}", global.name); module_map.types.adjust(&mut global.ty); if let Some(ref mut init) = global.init { - module_map.const_expressions.adjust(init); + module_map.global_expressions.adjust(init); } } @@ -193,7 +210,7 @@ struct ModuleTracer<'module> { module: &'module crate::Module, types_used: HandleSet<crate::Type>, constants_used: HandleSet<crate::Constant>, - const_expressions_used: HandleSet<crate::Expression>, + global_expressions_used: HandleSet<crate::Expression>, } impl<'module> ModuleTracer<'module> { @@ -202,7 +219,7 @@ impl<'module> ModuleTracer<'module> { module, types_used: HandleSet::for_arena(&module.types), constants_used: HandleSet::for_arena(&module.constants), - const_expressions_used: HandleSet::for_arena(&module.const_expressions), + global_expressions_used: HandleSet::for_arena(&module.global_expressions), } } @@ -233,12 +250,13 @@ impl<'module> ModuleTracer<'module> { fn as_const_expression(&mut self) -> expressions::ExpressionTracer { expressions::ExpressionTracer { - expressions: &self.module.const_expressions, + expressions: &self.module.global_expressions, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, - expressions_used: &mut self.const_expressions_used, - const_expressions_used: None, + expressions_used: &mut self.global_expressions_used, + global_expressions_used: None, } } @@ -249,9 +267,10 @@ impl<'module> ModuleTracer<'module> { FunctionTracer { function, constants: &self.module.constants, + overrides: &self.module.overrides, types_used: &mut self.types_used, constants_used: &mut self.constants_used, - const_expressions_used: &mut self.const_expressions_used, + global_expressions_used: &mut self.global_expressions_used, expressions_used: HandleSet::for_arena(&function.expressions), } } @@ -260,7 +279,7 @@ impl<'module> ModuleTracer<'module> { struct ModuleMap { types: HandleMap<crate::Type>, constants: HandleMap<crate::Constant>, - const_expressions: HandleMap<crate::Expression>, + global_expressions: HandleMap<crate::Expression>, } impl From<ModuleTracer<'_>> for ModuleMap { @@ -268,7 +287,7 @@ impl From<ModuleTracer<'_>> for ModuleMap { ModuleMap { types: HandleMap::from_set(used.types_used), constants: HandleMap::from_set(used.constants_used), - const_expressions: HandleMap::from_set(used.const_expressions_used), + global_expressions: HandleMap::from_set(used.global_expressions_used), } } } diff --git a/third_party/rust/naga/src/compact/statements.rs b/third_party/rust/naga/src/compact/statements.rs index 0698b57258..a124281bc1 100644 --- a/third_party/rust/naga/src/compact/statements.rs +++ b/third_party/rust/naga/src/compact/statements.rs @@ -97,6 +97,39 @@ impl FunctionTracer<'_> { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } + St::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.expressions_used.insert(predicate) + } + self.expressions_used.insert(result) + } + St::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result, + } => { + self.expressions_used.insert(argument); + self.expressions_used.insert(result) + } + St::SubgroupGather { + mode, + argument, + result, + } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.expressions_used.insert(index) + } + } + self.expressions_used.insert(argument); + self.expressions_used.insert(result) + } // Trivial statements. St::Break @@ -250,6 +283,40 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::SubgroupBallot { + ref mut result, + ref mut predicate, + } => { + if let Some(ref mut predicate) = *predicate { + adjust(predicate); + } + adjust(result); + } + St::SubgroupCollectiveOperation { + op: _, + collective_op: _, + ref mut argument, + ref mut result, + } => { + adjust(argument); + adjust(result); + } + St::SubgroupGather { + ref mut mode, + ref mut argument, + ref mut result, + } => { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) + | crate::GatherMode::Shuffle(ref mut index) + | crate::GatherMode::ShuffleDown(ref mut index) + | crate::GatherMode::ShuffleUp(ref mut index) + | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), + } + adjust(argument); + adjust(result); + } // Trivial statements. St::Break diff --git a/third_party/rust/naga/src/error.rs b/third_party/rust/naga/src/error.rs new file mode 100644 index 0000000000..5f2e28360b --- /dev/null +++ b/third_party/rust/naga/src/error.rs @@ -0,0 +1,74 @@ +use std::{error::Error, fmt}; + +#[derive(Clone, Debug)] +pub struct ShaderError<E> { + /// The source code of the shader. + pub source: String, + pub label: Option<String>, + pub inner: Box<E>, +} + +#[cfg(feature = "wgsl-in")] +impl fmt::Display for ShaderError<crate::front::wgsl::ParseError> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let label = self.label.as_deref().unwrap_or_default(); + let string = self.inner.emit_to_string(&self.source); + write!(f, "\nShader '{label}' parsing {string}") + } +} +#[cfg(feature = "glsl-in")] +impl fmt::Display for ShaderError<crate::front::glsl::ParseErrors> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let label = self.label.as_deref().unwrap_or_default(); + let string = self.inner.emit_to_string(&self.source); + write!(f, "\nShader '{label}' parsing {string}") + } +} +#[cfg(feature = "spv-in")] +impl fmt::Display for ShaderError<crate::front::spv::Error> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let label = self.label.as_deref().unwrap_or_default(); + let string = self.inner.emit_to_string(&self.source); + write!(f, "\nShader '{label}' parsing {string}") + } +} +impl fmt::Display for ShaderError<crate::WithSpan<crate::valid::ValidationError>> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use codespan_reporting::{ + diagnostic::{Diagnostic, Label}, + files::SimpleFile, + term, + }; + + let label = self.label.as_deref().unwrap_or_default(); + let files = SimpleFile::new(label, &self.source); + let config = term::Config::default(); + let mut writer = term::termcolor::NoColor::new(Vec::new()); + + let diagnostic = Diagnostic::error().with_labels( + self.inner + .spans() + .map(|&(span, ref desc)| { + Label::primary((), span.to_range().unwrap()).with_message(desc.to_owned()) + }) + .collect(), + ); + + term::emit(&mut writer, &config, &files, &diagnostic).expect("cannot write error"); + + write!( + f, + "\nShader validation {}", + String::from_utf8_lossy(&writer.into_inner()) + ) + } +} +impl<E> Error for ShaderError<E> +where + ShaderError<E>: fmt::Display, + E: Error + 'static, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.inner) + } +} diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs index f26c57965d..6ba7df593a 100644 --- a/third_party/rust/naga/src/front/glsl/context.rs +++ b/third_party/rust/naga/src/front/glsl/context.rs @@ -77,12 +77,19 @@ pub struct Context<'a> { pub body: Block, pub module: &'a mut crate::Module, pub is_const: bool, - /// Tracks the constness of `Expression`s residing in `self.expressions` - pub expression_constness: crate::proc::ExpressionConstnessTracker, + /// Tracks the expression kind of `Expression`s residing in `self.expressions` + pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker, + /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions` + pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker, } impl<'a> Context<'a> { - pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> { + pub fn new( + frontend: &Frontend, + module: &'a mut crate::Module, + is_const: bool, + global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker, + ) -> Result<Self> { let mut this = Context { expressions: Arena::new(), locals: Arena::new(), @@ -101,7 +108,8 @@ impl<'a> Context<'a> { body: Block::new(), module, is_const: false, - expression_constness: crate::proc::ExpressionConstnessTracker::new(), + local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(), + global_expression_kind_tracker, }; this.emit_start(); @@ -249,40 +257,24 @@ impl<'a> Context<'a> { pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> { let mut eval = if self.is_const { - crate::proc::ConstantEvaluator::for_glsl_module(self.module) + crate::proc::ConstantEvaluator::for_glsl_module( + self.module, + self.global_expression_kind_tracker, + ) } else { crate::proc::ConstantEvaluator::for_glsl_function( self.module, &mut self.expressions, - &mut self.expression_constness, + &mut self.local_expression_kind_tracker, &mut self.emitter, &mut self.body, ) }; - let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error { + eval.try_eval_and_append(expr, meta).map_err(|e| Error { kind: e.into(), meta, - }); - - match res { - Ok(expr) => Ok(expr), - Err(e) => { - if self.is_const { - Err(e) - } else { - let needs_pre_emit = expr.needs_pre_emit(); - if needs_pre_emit { - self.body.extend(self.emitter.finish(&self.expressions)); - } - let h = self.expressions.append(expr, meta); - if needs_pre_emit { - self.emitter.start(&self.expressions); - } - Ok(h) - } - } - } + }) } /// Add variable to current scope @@ -1479,7 +1471,7 @@ impl Index<Handle<Expression>> for Context<'_> { fn index(&self, index: Handle<Expression>) -> &Self::Output { if self.is_const { - &self.module.const_expressions[index] + &self.module.global_expressions[index] } else { &self.expressions[index] } diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs index bd16ee30bc..e0771437e6 100644 --- a/third_party/rust/naga/src/front/glsl/error.rs +++ b/third_party/rust/naga/src/front/glsl/error.rs @@ -1,4 +1,5 @@ use super::token::TokenValue; +use crate::SourceLocation; use crate::{proc::ConstantEvaluatorError, Span}; use codespan_reporting::diagnostic::{Diagnostic, Label}; use codespan_reporting::files::SimpleFile; @@ -137,14 +138,21 @@ pub struct Error { pub meta: Span, } +impl Error { + /// Returns a [`SourceLocation`] for the error message. + pub fn location(&self, source: &str) -> Option<SourceLocation> { + Some(self.meta.location(source)) + } +} + /// A collection of errors returned during shader parsing. #[derive(Clone, Debug)] #[cfg_attr(test, derive(PartialEq))] -pub struct ParseError { +pub struct ParseErrors { pub errors: Vec<Error>, } -impl ParseError { +impl ParseErrors { pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) { self.emit_to_writer_with_path(writer, source, "glsl"); } @@ -172,19 +180,19 @@ impl ParseError { } } -impl std::fmt::Display for ParseError { +impl std::fmt::Display for ParseErrors { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { self.errors.iter().try_for_each(|e| write!(f, "{e:?}")) } } -impl std::error::Error for ParseError { +impl std::error::Error for ParseErrors { fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { None } } -impl From<Vec<Error>> for ParseError { +impl From<Vec<Error>> for ParseErrors { fn from(errors: Vec<Error>) -> Self { Self { errors } } diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs index 01846eb814..fa1bbef56b 100644 --- a/third_party/rust/naga/src/front/glsl/functions.rs +++ b/third_party/rust/naga/src/front/glsl/functions.rs @@ -1236,6 +1236,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1256,6 +1258,8 @@ impl Frontend { let value = ctx .expressions .append(Expression::FunctionArgument(idx), Default::default()); + ctx.local_expression_kind_tracker + .insert(value, crate::proc::ExpressionKind::Runtime); ctx.body .push(Statement::Store { pointer, value }, Default::default()); }, @@ -1285,6 +1289,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1307,6 +1313,8 @@ impl Frontend { let load = ctx .expressions .append(Expression::Load { pointer }, Default::default()); + ctx.local_expression_kind_tracker + .insert(load, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), @@ -1329,6 +1337,8 @@ impl Frontend { let res = ctx .expressions .append(Expression::Compose { ty, components }, Default::default()); + ctx.local_expression_kind_tracker + .insert(res, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs index 75f3929db4..ea202b2445 100644 --- a/third_party/rust/naga/src/front/glsl/mod.rs +++ b/third_party/rust/naga/src/front/glsl/mod.rs @@ -13,7 +13,7 @@ To begin, take a look at the documentation for the [`Frontend`]. */ pub use ast::{Precision, Profile}; -pub use error::{Error, ErrorKind, ExpectedToken, ParseError}; +pub use error::{Error, ErrorKind, ExpectedToken, ParseErrors}; pub use token::TokenValue; use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type}; @@ -196,7 +196,7 @@ impl Frontend { &mut self, options: &Options, source: &str, - ) -> std::result::Result<Module, ParseError> { + ) -> std::result::Result<Module, ParseErrors> { self.reset(options.stage); let lexer = lex::Lexer::new(source, &options.defines); diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs index 851d2e1d79..28e0808063 100644 --- a/third_party/rust/naga/src/front/glsl/parser.rs +++ b/third_party/rust/naga/src/front/glsl/parser.rs @@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> { pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> { let mut module = Module::default(); + let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); // Body and expression arena for global initialization - let mut ctx = Context::new(frontend, &mut module, false)?; + let mut ctx = Context::new( + frontend, + &mut module, + false, + &mut global_expression_kind_tracker, + )?; while self.peek(frontend).is_some() { self.parse_external_declaration(frontend, &mut ctx)?; @@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, ) -> Result<(u32, Span)> { - let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); @@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, module: &mut Module, + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, ) -> Result<(Handle<Expression>, Span)> { - let mut ctx = Context::new(frontend, module, true)?; + let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?; let mut stmt_ctx = ctx.stmt_ctx(); let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs index f5e38fb016..2d253a378d 100644 --- a/third_party/rust/naga/src/front/glsl/parser/declarations.rs +++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs @@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> { init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok()); late_initializer = None; } else if let Some(init) = init { - if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) { + if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) { decl_initializer = None; late_initializer = Some(init); } else { @@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> { let result = ty.map(|ty| FunctionResult { ty, binding: None }); - let mut context = Context::new(frontend, ctx.module, false)?; + let mut context = Context::new( + frontend, + ctx.module, + false, + ctx.global_expression_kind_tracker, + )?; self.parse_function_args(frontend, &mut context)?; diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs index d428d74761..d0c889e4d3 100644 --- a/third_party/rust/naga/src/front/glsl/parser/functions.rs +++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs @@ -192,10 +192,13 @@ impl<'source> ParsingContext<'source> { TokenValue::Case => { self.bump(frontend)?; - let (const_expr, meta) = - self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; - match ctx.module.const_expressions[const_expr] { + match ctx.module.global_expressions[const_expr] { Expression::Literal(Literal::I32(value)) => match uint { // This unchecked cast isn't good, but since // we only reach this code when the selector diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs index 259052cd27..135765ca58 100644 --- a/third_party/rust/naga/src/front/glsl/parser_tests.rs +++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs @@ -1,7 +1,7 @@ use super::{ ast::Profile, error::ExpectedToken, - error::{Error, ErrorKind, ParseError}, + error::{Error, ErrorKind, ParseErrors}, token::TokenValue, Frontend, Options, Span, }; @@ -21,7 +21,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidVersion(99000), meta: Span::new(9, 14) @@ -37,7 +37,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidVersion(449), meta: Span::new(9, 12) @@ -53,7 +53,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::InvalidProfile("smart".into()), meta: Span::new(13, 18), @@ -69,7 +69,7 @@ fn version() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![ Error { kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,), @@ -455,7 +455,7 @@ fn functions() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Function already defined".into()), meta: Span::new(134, 152), @@ -539,7 +539,7 @@ fn constants() { let mut types = module.types.iter(); let mut constants = module.constants.iter(); - let mut const_expressions = module.const_expressions.iter(); + let mut global_expressions = module.global_expressions.iter(); let (ty_handle, ty) = types.next().unwrap(); assert_eq!( @@ -550,14 +550,13 @@ fn constants() { } ); - let (init_handle, init) = const_expressions.next().unwrap(); + let (init_handle, init) = global_expressions.next().unwrap(); assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0))); assert_eq!( constants.next().unwrap().1, &Constant { name: Some("a".to_owned()), - r#override: crate::Override::None, ty: ty_handle, init: init_handle } @@ -567,7 +566,6 @@ fn constants() { constants.next().unwrap().1, &Constant { name: Some("b".to_owned()), - r#override: crate::Override::None, ty: ty_handle, init: init_handle } @@ -636,7 +634,7 @@ fn implicit_conversions() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Unknown function \'test\'".into()), meta: Span::new(156, 165), @@ -660,7 +658,7 @@ fn implicit_conversions() { ) .err() .unwrap(), - ParseError { + ParseErrors { errors: vec![Error { kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()), meta: Span::new(158, 165), diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs index e87d76fffc..f6836169c0 100644 --- a/third_party/rust/naga/src/front/glsl/types.rs +++ b/third_party/rust/naga/src/front/glsl/types.rs @@ -233,7 +233,7 @@ impl Context<'_> { }; let expressions = if self.is_const { - &self.module.const_expressions + &self.module.global_expressions } else { &self.expressions }; @@ -330,23 +330,25 @@ impl Context<'_> { expr: Handle<Expression>, ) -> Result<Handle<Expression>> { let meta = self.expressions.get_span(expr); - Ok(match self.expressions[expr] { + let h = match self.expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) - | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta), + | Expression::ZeroValue(_)) => { + self.module.global_expressions.append(expr.clone(), meta) + } Expression::Compose { ty, ref components } => { let mut components = components.clone(); for component in &mut components { *component = self.lift_up_const_expression(*component)?; } self.module - .const_expressions + .global_expressions .append(Expression::Compose { ty, components }, meta) } Expression::Splat { size, value } => { let value = self.lift_up_const_expression(value)?; self.module - .const_expressions + .global_expressions .append(Expression::Splat { size, value }, meta) } _ => { @@ -355,6 +357,9 @@ impl Context<'_> { meta, }) } - }) + }; + self.global_expression_kind_tracker + .insert(h, crate::proc::ExpressionKind::Const); + Ok(h) } } diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs index 9d2e7a0e7b..0725fbd94f 100644 --- a/third_party/rust/naga/src/front/glsl/variables.rs +++ b/third_party/rust/naga/src/front/glsl/variables.rs @@ -472,7 +472,6 @@ impl Frontend { let constant = Constant { name: name.clone(), - r#override: crate::Override::None, ty, init, }; diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs index f0a714fbeb..a6bf0e0451 100644 --- a/third_party/rust/naga/src/front/spv/convert.rs +++ b/third_party/rust/naga/src/front/spv/convert.rs @@ -153,6 +153,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::B Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId, Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize, Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups, + // subgroup + Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups, + Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId, + Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize, + Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnsupportedBuiltIn(word)), }) } diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs index af025636c0..44beadce98 100644 --- a/third_party/rust/naga/src/front/spv/error.rs +++ b/third_party/rust/naga/src/front/spv/error.rs @@ -5,7 +5,7 @@ use codespan_reporting::files::SimpleFile; use codespan_reporting::term; use termcolor::{NoColor, WriteColor}; -#[derive(Debug, thiserror::Error)] +#[derive(Clone, Debug, thiserror::Error)] pub enum Error { #[error("invalid header")] InvalidHeader, @@ -58,6 +58,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group operation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -118,8 +120,8 @@ pub enum Error { ControlFlowGraphCycle(crate::front::spv::BlockId), #[error("recursive function call %{0}")] FunctionCallCycle(spirv::Word), - #[error("invalid array size {0:?}")] - InvalidArraySize(Handle<crate::Constant>), + #[error("invalid array size %{0}")] + InvalidArraySize(spirv::Word), #[error("invalid barrier scope %{0}")] InvalidBarrierScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] @@ -130,6 +132,8 @@ pub enum Error { come from a binding)" )] NonBindingArrayOfImageOrSamplers, + #[error("naga only supports specialization constant IDs up to 65535 but was given {0}")] + SpecIdTooHigh(u32), } impl Error { diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs index e81ecf5c9b..113ca56313 100644 --- a/third_party/rust/naga/src/front/spv/function.rs +++ b/third_party/rust/naga/src/front/spv/function.rs @@ -59,8 +59,11 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> { }) }, local_variables: Arena::new(), - expressions: self - .make_expression_storage(&module.global_variables, &module.constants), + expressions: self.make_expression_storage( + &module.global_variables, + &module.constants, + &module.overrides, + ), named_expressions: crate::NamedExpressions::default(), body: crate::Block::new(), } @@ -128,7 +131,8 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> { expressions: &mut fun.expressions, local_arena: &mut fun.local_variables, const_arena: &mut module.constants, - const_expressions: &mut module.const_expressions, + overrides: &mut module.overrides, + global_expressions: &mut module.global_expressions, type_arena: &module.types, global_arena: &module.global_variables, arguments: &fun.arguments, @@ -581,7 +585,8 @@ impl<'function> BlockContext<'function> { crate::proc::GlobalCtx { types: self.type_arena, constants: self.const_arena, - const_expressions: self.const_expressions, + overrides: self.overrides, + global_expressions: self.global_expressions, } } diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs index 0f25dd626b..284c4cf7fd 100644 --- a/third_party/rust/naga/src/front/spv/image.rs +++ b/third_party/rust/naga/src/front/spv/image.rs @@ -507,11 +507,14 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> { } spirv::ImageOperands::CONST_OFFSET => { let offset_constant = self.next()?; - let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle; - let offset_handle = ctx.const_expressions.append( - crate::Expression::Constant(offset_handle), - Default::default(), - ); + let offset_expr = self + .lookup_constant + .lookup(offset_constant)? + .inner + .to_expr(); + let offset_handle = ctx + .global_expressions + .append(offset_expr, Default::default()); offset = Some(offset_handle); words_left -= 1; } diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs index b793448597..7ac5a18cd6 100644 --- a/third_party/rust/naga/src/front/spv/mod.rs +++ b/third_party/rust/naga/src/front/spv/mod.rs @@ -196,7 +196,7 @@ struct Decoration { location: Option<spirv::Word>, desc_set: Option<spirv::Word>, desc_index: Option<spirv::Word>, - specialization: Option<spirv::Word>, + specialization_constant_id: Option<spirv::Word>, storage_buffer: bool, offset: Option<spirv::Word>, array_stride: Option<NonZeroU32>, @@ -216,11 +216,6 @@ impl Decoration { } } - fn specialization(&self) -> crate::Override { - self.specialization - .map_or(crate::Override::None, crate::Override::ByNameOrId) - } - const fn resource_binding(&self) -> Option<crate::ResourceBinding> { match *self { Decoration { @@ -284,8 +279,23 @@ struct LookupType { } #[derive(Debug)] +enum Constant { + Constant(Handle<crate::Constant>), + Override(Handle<crate::Override>), +} + +impl Constant { + const fn to_expr(&self) -> crate::Expression { + match *self { + Self::Constant(c) => crate::Expression::Constant(c), + Self::Override(o) => crate::Expression::Override(o), + } + } +} + +#[derive(Debug)] struct LookupConstant { - handle: Handle<crate::Constant>, + inner: Constant, type_id: spirv::Word, } @@ -537,7 +547,8 @@ struct BlockContext<'function> { local_arena: &'function mut Arena<crate::LocalVariable>, /// Constants arena of the module being processed const_arena: &'function mut Arena<crate::Constant>, - const_expressions: &'function mut Arena<crate::Expression>, + overrides: &'function mut Arena<crate::Override>, + global_expressions: &'function mut Arena<crate::Expression>, /// Type arena of the module being processed type_arena: &'function UniqueArena<crate::Type>, /// Global arena of the module being processed @@ -757,7 +768,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { dec.matrix_major = Some(Majority::Row); } spirv::Decoration::SpecId => { - dec.specialization = Some(self.next()?); + dec.specialization_constant_id = Some(self.next()?); } other => { log::warn!("Unknown decoration {:?}", other); @@ -1393,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { inst.expect(5)?; let init_id = self.next()?; let lconst = self.lookup_constant.lookup(init_id)?; - Some( - ctx.expressions - .append(crate::Expression::Constant(lconst.handle), span), - ) + Some(ctx.expressions.append(lconst.inner.to_expr(), span)) } else { None }; @@ -3650,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let semantics_const = self.lookup_constant.lookup(semantics_id)?; - let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; - let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) + let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; if exec_scope == spirv::Scope::Workgroup as u32 { @@ -3692,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(5)?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| match predicate_const.inner { + Constant::Constant(constant) => matches!( + ctx.gctx().global_expressions[ctx.gctx().constants[constant].init], + crate::Expression::Literal(crate::Literal::Bool(true)), + ), + Constant::Override(_) => false, + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + let result_handle = ctx + .expressions + .append(crate::Expression::SubgroupBallotResult, span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + emitter.start(ctx.expressions); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + block.extend(emitter.finish(ctx.expressions)); + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 5 + } else { + 6 + }, + )?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 5 + } else { + 6 + }, + )?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3713,6 +3969,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { &mut self, globals: &Arena<crate::GlobalVariable>, constants: &Arena<crate::Constant>, + overrides: &Arena<crate::Override>, ) -> Arena<crate::Expression> { let mut expressions = Arena::new(); #[allow(clippy::panic)] @@ -3737,8 +3994,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> { } // register constants for (&id, con) in self.lookup_constant.iter() { - let span = constants.get_span(con.handle); - let handle = expressions.append(crate::Expression::Constant(con.handle), span); + let (expr, span) = match con.inner { + Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)), + Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)), + }; + let handle = expressions.append(expr, span); self.lookup_expression.insert( id, LookupExpression { @@ -3812,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, @@ -3944,10 +4207,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> { Op::TypeSampledImage => self.parse_type_sampled_image(inst), Op::TypeSampler => self.parse_type_sampler(inst, &mut module), Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), - Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), + Op::ConstantComposite | Op::SpecConstantComposite => { + self.parse_composite_constant(inst, &mut module) + } Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), - Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), - Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module), + Op::ConstantTrue | Op::SpecConstantTrue => { + self.parse_bool_constant(inst, true, &mut module) + } + Op::ConstantFalse | Op::SpecConstantFalse => { + self.parse_bool_constant(inst, false, &mut module) + } Op::Variable => self.parse_global_variable(inst, &mut module), Op::Function => { self.switch(ModuleState::Function, inst.op)?; @@ -4504,9 +4773,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let length_id = self.next()?; let length_const = self.lookup_constant.lookup(length_id)?; - let size = resolve_constant(module.to_ctx(), length_const.handle) + let size = resolve_constant(module.to_ctx(), &length_const.inner) .and_then(NonZeroU32::new) - .ok_or(Error::InvalidArraySize(length_const.handle))?; + .ok_or(Error::InvalidArraySize(length_id))?; let decor = self.future_decor.remove(&id).unwrap_or_default(); let base = self.lookup_type.lookup(type_id)?.handle; @@ -4919,29 +5188,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> { _ => return Err(Error::UnsupportedType(type_lookup.handle)), }; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Literal(literal), span); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_composite_constant( @@ -4965,34 +5218,18 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let span = self.span_from_with_op(start); let constant = self.lookup_constant.lookup(component_id)?; let expr = module - .const_expressions - .append(crate::Expression::Constant(constant.handle), span); + .global_expressions + .append(constant.inner.to_expr(), span); components.push(expr); } - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_null_constant( @@ -5010,23 +5247,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let init = module - .const_expressions + .global_expressions .append(crate::Expression::ZeroValue(ty), span); - let handle = module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ); - self.lookup_constant - .insert(id, LookupConstant { handle, type_id }); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_bool_constant( @@ -5045,27 +5270,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - - let init = module.const_expressions.append( + let init = module.global_expressions.append( crate::Expression::Literal(crate::Literal::Bool(value)), span, ); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); + + self.insert_parsed_constant(module, id, type_id, ty, init, span) + } + + fn insert_parsed_constant( + &mut self, + module: &mut crate::Module, + id: u32, + type_id: u32, + ty: Handle<crate::Type>, + init: Handle<crate::Expression>, + span: crate::Span, + ) -> Result<(), Error> { + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let inner = if let Some(id) = decor.specialization_constant_id { + let o = crate::Override { + name: decor.name, + id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?), + ty, + init: Some(init), + }; + Constant::Override(module.overrides.append(o, span)) + } else { + let c = crate::Constant { + name: decor.name, + ty, + init, + }; + Constant::Constant(module.constants.append(c, span)) + }; + + self.lookup_constant + .insert(id, LookupConstant { inner, type_id }); Ok(()) } @@ -5087,8 +5329,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let span = self.span_from_with_op(start); let lconst = self.lookup_constant.lookup(init_id)?; let expr = module - .const_expressions - .append(crate::Expression::Constant(lconst.handle), span); + .global_expressions + .append(lconst.inner.to_expr(), span); Some(expr) } else { None @@ -5209,7 +5451,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { match null::generate_default_built_in( Some(built_in), ty, - &mut module.const_expressions, + &mut module.global_expressions, span, ) { Ok(handle) => Some(handle), @@ -5231,14 +5473,14 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let handle = null::generate_default_built_in( built_in, member.ty, - &mut module.const_expressions, + &mut module.global_expressions, span, )?; components.push(handle); } Some( module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span), ) } @@ -5303,11 +5545,12 @@ fn make_index_literal( Ok(expr) } -fn resolve_constant( - gctx: crate::proc::GlobalCtx, - constant: Handle<crate::Constant>, -) -> Option<u32> { - match gctx.const_expressions[gctx.constants[constant].init] { +fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> { + let constant = match *constant { + Constant::Constant(constant) => constant, + Constant::Override(_) => return None, + }; + match gctx.global_expressions[gctx.constants[constant].init] { crate::Expression::Literal(crate::Literal::U32(id)) => Some(id), crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32), _ => None, diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs index 42cccca80a..c7d3776841 100644 --- a/third_party/rust/naga/src/front/spv/null.rs +++ b/third_party/rust/naga/src/front/spv/null.rs @@ -5,14 +5,14 @@ use crate::arena::{Arena, Handle}; pub fn generate_default_built_in( built_in: Option<crate::BuiltIn>, ty: Handle<crate::Type>, - const_expressions: &mut Arena<crate::Expression>, + global_expressions: &mut Arena<crate::Expression>, span: crate::Span, ) -> Result<Handle<crate::Expression>, Error> { let expr = match built_in { Some(crate::BuiltIn::Position { .. }) => { - let zero = const_expressions + let zero = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(0.0)), span); - let one = const_expressions + let one = global_expressions .append(crate::Expression::Literal(crate::Literal::F32(1.0)), span); crate::Expression::Compose { ty, @@ -27,5 +27,5 @@ pub fn generate_default_built_in( // Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path _ => crate::Expression::ZeroValue(ty), }; - Ok(const_expressions.append(expr, span)) + Ok(global_expressions.append(expr, span)) } diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs index 54aa8296b1..dc1339521c 100644 --- a/third_party/rust/naga/src/front/wgsl/error.rs +++ b/third_party/rust/naga/src/front/wgsl/error.rs @@ -13,6 +13,7 @@ use thiserror::Error; #[derive(Clone, Debug)] pub struct ParseError { message: String, + // The first span should be the primary span, and the other ones should be complementary. labels: Vec<(Span, Cow<'static, str>)>, notes: Vec<String>, } @@ -190,7 +191,7 @@ pub enum Error<'a> { expected: String, got: String, }, - MissingType(Span), + DeclMissingTypeAndInit(Span), MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), @@ -269,6 +270,11 @@ pub enum Error<'a> { scalar: String, inner: ConstantEvaluatorError, }, + ExceededLimitForNestedBraces { + span: Span, + limit: u8, + }, + PipelineConstantIDValue(Span), } impl<'a> Error<'a> { @@ -518,11 +524,11 @@ impl<'a> Error<'a> { notes: vec![], } } - Error::MissingType(name_span) => ParseError { - message: format!("variable `{}` needs a type", &source[name_span]), + Error::DeclMissingTypeAndInit(name_span) => ParseError { + message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]), labels: vec![( name_span, - format!("definition of `{}`", &source[name_span]).into(), + "needs a type specifier or initializer".into(), )], notes: vec![], }, @@ -770,6 +776,21 @@ impl<'a> Error<'a> { format!("the expression should have been converted to have {} scalar type", scalar), ] }, + Error::ExceededLimitForNestedBraces { span, limit } => ParseError { + message: "brace nesting limit reached".into(), + labels: vec![(span, "limit reached at this brace".into())], + notes: vec![ + format!("nesting limit is currently set to {limit}"), + ], + }, + Error::PipelineConstantIDValue(span) => ParseError { + message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(), + labels: vec![( + span, + "must be between 0 and 65535 inclusive".into(), + )], + notes: vec![], + }, } } } diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs index a5524fe8f1..593405508f 100644 --- a/third_party/rust/naga/src/front/wgsl/index.rs +++ b/third_party/rust/naga/src/front/wgsl/index.rs @@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> { ast::GlobalDeclKind::Fn(ref f) => f.name, ast::GlobalDeclKind::Var(ref v) => v.name, ast::GlobalDeclKind::Const(ref c) => c.name, + ast::GlobalDeclKind::Override(ref o) => o.name, ast::GlobalDeclKind::Struct(ref s) => s.name, ast::GlobalDeclKind::Type(ref t) => t.name, } diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs index 2ca6c182b7..e7cce17723 100644 --- a/third_party/rust/naga/src/front/wgsl/lower/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs @@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> { module: &'out mut crate::Module, const_typifier: &'temp mut Typifier, + + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } impl<'source> GlobalContext<'source, '_, '_> { @@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> { module: self.module, const_typifier: self.const_typifier, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, + } + } + + fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Override, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -164,7 +179,8 @@ pub struct StatementContext<'source, 'temp, 'out> { /// with the form of the expressions; it is also tracking whether WGSL says /// we should consider them to be const. See the use of `force_non_const` in /// the code for lowering `let` bindings. - expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, + local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { @@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, module: self.module, expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { local_table: self.local_table, @@ -188,7 +205,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { block, emitter, typifier: self.typifier, - expression_constness: self.expression_constness, + local_expression_kind_tracker: self.local_expression_kind_tracker, }), } } @@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -232,8 +250,8 @@ pub struct RuntimeExpressionContext<'temp, 'out> { /// Which `Expression`s in `self.naga_expressions` are const expressions, in /// the WGSL sense. /// - /// See [`StatementContext::expression_constness`] for details. - expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, + /// See [`StatementContext::local_expression_kind_tracker`] for details. + local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, } /// The type of Naga IR expression we are lowering an [`ast::Expression`] to. @@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> { /// available in the [`ExpressionContext`], so this variant /// carries no further information. Constant, + + /// We are lowering to an override expression, to be included in the module's + /// constant expression arena. + /// + /// Everything override expressions are allowed to refer to is + /// available in the [`ExpressionContext`], so this variant + /// carries no further information. + Override, } /// State for lowering an [`ast::Expression`] to Naga IR. @@ -307,10 +333,11 @@ pub struct ExpressionContext<'source, 'temp, 'out> { /// [`Module`]: crate::Module module: &'out mut crate::Module, - /// Type judgments for [`module::const_expressions`]. + /// Type judgments for [`module::global_expressions`]. /// - /// [`module::const_expressions`]: crate::Module::const_expressions + /// [`module::global_expressions`]: crate::Module::global_expressions const_typifier: &'temp mut Typifier, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker, /// Whether we are lowering a constant expression or a general /// runtime expression, and the data needed in each case. @@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { const_typifier: self.const_typifier, module: self.module, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -344,11 +373,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function( self.module, &mut rctx.function.expressions, - rctx.expression_constness, + rctx.local_expression_kind_tracker, rctx.emitter, rctx.block, ), - ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module), + ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + false, + ), + ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + true, + ), } } @@ -358,24 +396,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { span: Span, ) -> Result<Handle<crate::Expression>, Error<'source>> { let mut eval = self.as_const_evaluator(); - match eval.try_eval_and_append(&expr, span) { - Ok(expr) => Ok(expr), - - // `expr` is not a constant expression. This is fine as - // long as we're not building `Module::const_expressions`. - Err(err) => match self.expr_type { - ExpressionContextType::Runtime(ref mut rctx) => { - Ok(rctx.function.expressions.append(expr, span)) - } - ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)), - }, - } + eval.try_eval_and_append(expr, span) + .map_err(|e| Error::ConstantEvaluatorError(e, span)) } fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => { - if !ctx.expression_constness.is_const(handle) { + if !ctx.local_expression_kind_tracker.is_const(handle) { return None; } @@ -385,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .ok() } ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + ExpressionContextType::Override => None, } } fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), - ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.module.global_expressions.get_span(handle) + } } } fn typifier(&self) -> &Typifier { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.const_typifier + } } } @@ -408,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { match self.expr_type { ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), - ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), + ExpressionContextType::Constant | ExpressionContextType::Override => { + Err(Error::UnexpectedOperationInConstContext(span)) + } } } @@ -420,7 +455,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<crate::SwizzleComponent, Error<'source>> { match self.expr_type { ExpressionContextType::Runtime(ref rctx) => { - if !rctx.expression_constness.is_const(expr) { + if !rctx.local_expression_kind_tracker.is_const(expr) { return Err(Error::ExpectedConstExprConcreteIntegerScalar( component_span, )); @@ -445,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { } // This means a `gather` operation appeared in a constant expression. // This error refers to the `gather` itself, not its "component" argument. - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { Err(Error::UnexpectedOperationInConstContext(gather_span)) } } @@ -471,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { // to also borrow self.module.types mutably below. let typifier = match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => &*self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + &*self.const_typifier + } }; Ok(typifier.register_type(handle, &mut self.module.types)) } @@ -514,10 +551,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { typifier = &mut *ctx.typifier; expressions = &ctx.function.expressions; } - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); typifier = self.const_typifier; - expressions = &self.module.const_expressions; + expressions = &self.module.global_expressions; } }; typifier @@ -610,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } let result = self.append_expression(expression, span); match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => { rctx.emitter.start(&rctx.function.expressions); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } result } @@ -786,6 +823,7 @@ enum LoweredGlobalDecl { Function(Handle<crate::Function>), Var(Handle<crate::GlobalVariable>), Const(Handle<crate::Constant>), + Override(Handle<crate::Override>), Type(Handle<crate::Type>), EntryPoint, } @@ -836,6 +874,29 @@ impl Texture { } } +enum SubgroupGather { + BroadcastFirst, + Broadcast, + Shuffle, + ShuffleDown, + ShuffleUp, + ShuffleXor, +} + +impl SubgroupGather { + pub fn map(word: &str) -> Option<Self> { + Some(match word { + "subgroupBroadcastFirst" => Self::BroadcastFirst, + "subgroupBroadcast" => Self::Broadcast, + "subgroupShuffle" => Self::Shuffle, + "subgroupShuffleDown" => Self::ShuffleDown, + "subgroupShuffleUp" => Self::ShuffleUp, + "subgroupShuffleXor" => Self::ShuffleXor, + _ => return None, + }) + } +} + pub struct Lowerer<'source, 'temp> { index: &'temp Index<'source>, layouter: Layouter, @@ -861,6 +922,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { types: &tu.types, module: &mut module, const_typifier: &mut Typifier::new(), + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(), }; for decl_handle in self.index.visit_ordered() { @@ -877,7 +939,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let init; if let Some(init_ast) = v.init { - let mut ectx = ctx.as_const(); + let mut ectx = ctx.as_override(); let lowered = self.expression_for_abstract(init_ast, &mut ectx)?; let ty_res = crate::proc::TypeResolution::Handle(ty); let converted = ectx @@ -956,7 +1018,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let handle = ctx.module.constants.append( crate::Constant { name: Some(c.name.name.to_string()), - r#override: crate::Override::None, ty, init, }, @@ -966,6 +1027,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.globals .insert(c.name.name, LoweredGlobalDecl::Const(handle)); } + ast::GlobalDeclKind::Override(ref o) => { + let init = o + .init + .map(|init| self.expression(init, &mut ctx.as_override())) + .transpose()?; + let inferred_type = init + .map(|init| ctx.as_const().register_type(init)) + .transpose()?; + + let explicit_ty = + o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) + .transpose()?; + + let id = + o.id.map(|id| self.const_u32(id, &mut ctx.as_const())) + .transpose()?; + + let id = if let Some((id, id_span)) = id { + Some( + u16::try_from(id) + .map_err(|_| Error::PipelineConstantIDValue(id_span))?, + ) + } else { + None + }; + + let ty = match (explicit_ty, inferred_type) { + (Some(explicit_ty), Some(inferred_type)) => { + if explicit_ty == inferred_type { + explicit_ty + } else { + let gctx = ctx.module.to_ctx(); + return Err(Error::InitializationTypeMismatch { + name: o.name.span, + expected: explicit_ty.to_wgsl(&gctx), + got: inferred_type.to_wgsl(&gctx), + }); + } + } + (Some(explicit_ty), None) => explicit_ty, + (None, Some(inferred_type)) => inferred_type, + (None, None) => { + return Err(Error::DeclMissingTypeAndInit(o.name.span)); + } + }; + + let handle = ctx.module.overrides.append( + crate::Override { + name: Some(o.name.name.to_string()), + id, + ty, + init, + }, + span, + ); + + ctx.globals + .insert(o.name.name, LoweredGlobalDecl::Override(handle)); + } ast::GlobalDeclKind::Struct(ref s) => { let handle = self.r#struct(s, span, &mut ctx)?; ctx.globals @@ -1000,6 +1120,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut local_table = FastHashMap::default(); let mut expressions = Arena::new(); let mut named_expressions = FastIndexMap::default(); + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); let arguments = f .arguments @@ -1011,6 +1132,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); local_table.insert(arg.handle, Typed::Plain(expr)); named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); + local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime); Ok(crate::FunctionArgument { name: Some(arg.name.name.to_string()), @@ -1053,7 +1175,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { named_expressions: &mut named_expressions, types: ctx.types, module: ctx.module, - expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + local_expression_kind_tracker: &mut local_expression_kind_tracker, + global_expression_kind_tracker: ctx.global_expression_kind_tracker, }; let mut body = self.block(&f.body, false, &mut stmt_ctx)?; ensure_block_returns(&mut body); @@ -1132,7 +1255,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // affects when errors must be reported, so we can't even // treat suitable `let` bindings as constant as an // optimization. - ctx.expression_constness.force_non_const(value); + ctx.local_expression_kind_tracker.force_non_const(value); let explicit_ty = l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global())) @@ -1203,7 +1326,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty = explicit_ty; initializer = None; } - (None, None) => return Err(Error::MissingType(v.name.span)), + (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)), } let (const_initializer, initializer) = { @@ -1216,7 +1339,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // - the initialization is not a constant // expression, so its value depends on the // state at the point of initialization. - if is_inside_loop || !ctx.expression_constness.is_const(init) { + if is_inside_loop + || !ctx.local_expression_kind_tracker.is_const_or_override(init) + { (None, Some(init)) } else { (Some(init), None) @@ -1469,6 +1594,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .function .expressions .append(crate::Expression::Binary { op, left, right }, stmt.span); + rctx.local_expression_kind_tracker + .insert(left, crate::proc::ExpressionKind::Runtime); + rctx.local_expression_kind_tracker + .insert(value, crate::proc::ExpressionKind::Runtime); block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Store { @@ -1562,7 +1691,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { LoweredGlobalDecl::Const(handle) => { Typed::Plain(crate::Expression::Constant(handle)) } - _ => { + LoweredGlobalDecl::Override(handle) => { + Typed::Plain(crate::Expression::Override(handle)) + } + LoweredGlobalDecl::Function(_) + | LoweredGlobalDecl::Type(_) + | LoweredGlobalDecl::EntryPoint => { return Err(Error::Unexpected(span, ExpectedToken::Variable)); } }; @@ -1819,9 +1953,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; Ok(Some(handle)) } - Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => { - Err(Error::Unexpected(function.span, ExpectedToken::Function)) - } + Some( + &LoweredGlobalDecl::Const(_) + | &LoweredGlobalDecl::Override(_) + | &LoweredGlobalDecl::Var(_), + ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)), Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)), Some(&LoweredGlobalDecl::Function(function)) => { let arguments = arguments @@ -1835,9 +1971,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); let result = has_result.then(|| { - rctx.function + let result = rctx + .function .expressions - .append(crate::Expression::CallResult(function), span) + .append(crate::Expression::CallResult(function), span); + rctx.local_expression_kind_tracker + .insert(result, crate::proc::ExpressionKind::Runtime); + result }); rctx.emitter.start(&rctx.function.expressions); rctx.block.push( @@ -1937,6 +2077,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx)? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = SubgroupGather::map(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); + } else if let Some(fun) = crate::AtomicFunction::map(function.name) { + return Ok(Some(self.atomic_helper(span, fun, arguments, ctx)?)); } else { match function.name { "select" => { @@ -1982,70 +2132,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Store { pointer, value }, span); return Ok(None); } - "atomicAdd" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Add, - arguments, - ctx, - )?)) - } - "atomicSub" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Subtract, - arguments, - ctx, - )?)) - } - "atomicAnd" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::And, - arguments, - ctx, - )?)) - } - "atomicOr" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::InclusiveOr, - arguments, - ctx, - )?)) - } - "atomicXor" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::ExclusiveOr, - arguments, - ctx, - )?)) - } - "atomicMin" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Min, - arguments, - ctx, - )?)) - } - "atomicMax" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Max, - arguments, - ctx, - )?)) - } - "atomicExchange" => { - return Ok(Some(self.atomic_helper( - span, - crate::AtomicFunction::Exchange { compare: None }, - arguments, - ctx, - )?)) - } "atomicCompareExchangeWeak" => { let mut args = ctx.prepare_args(arguments, 3, span); @@ -2104,6 +2190,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); } + "subgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span); + return Ok(None); + } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = args.next()?; @@ -2311,6 +2405,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; return Ok(Some(handle)); } + "subgroupBallot" => { + let mut args = ctx.prepare_args(arguments, 0, span); + let predicate = if arguments.len() == 1 { + Some(self.expression(args.next()?, ctx)?) + } else { + None + }; + args.finish()?; + + let result = ctx + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::SubgroupBallot { result, predicate }, span); + return Ok(Some(result)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2502,6 +2612,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }) } + fn subgroup_operation_helper( + &mut self, + span: Span, + op: crate::SubgroupOperation, + collective_op: crate::CollectiveOperation, + arguments: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + }, + span, + ); + Ok(result) + } + + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: SubgroupGather, + arguments: &[Handle<ast::Expression<'source>>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result<Handle<crate::Expression>, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx)?; + + use SubgroupGather as Sg; + let mode = if let Sg::BroadcastFirst = mode { + crate::GatherMode::BroadcastFirst + } else { + let index = self.expression(args.next()?, ctx)?; + match mode { + Sg::Broadcast => crate::GatherMode::Broadcast(index), + Sg::Shuffle => crate::GatherMode::Shuffle(index), + Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index), + Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index), + Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index), + Sg::BroadcastFirst => unreachable!(), + } + }; + + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, @@ -2760,3 +2944,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } } + +impl crate::AtomicFunction { + pub fn map(word: &str) -> Option<Self> { + Some(match word { + "atomicAdd" => crate::AtomicFunction::Add, + "atomicSub" => crate::AtomicFunction::Subtract, + "atomicAnd" => crate::AtomicFunction::And, + "atomicOr" => crate::AtomicFunction::InclusiveOr, + "atomicXor" => crate::AtomicFunction::ExclusiveOr, + "atomicMin" => crate::AtomicFunction::Min, + "atomicMax" => crate::AtomicFunction::Max, + "atomicExchange" => crate::AtomicFunction::Exchange { compare: None }, + _ => return None, + }) + } +} diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs index b6151fe1c0..aec1e657fc 100644 --- a/third_party/rust/naga/src/front/wgsl/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/mod.rs @@ -44,6 +44,17 @@ impl Frontend { } } +/// <div class="warning"> +// NOTE: Keep this in sync with `wgpu::Device::create_shader_module`! +// NOTE: Keep this in sync with `wgpu_core::Global::device_create_shader_module`! +/// +/// This function may consume a lot of stack space. Compiler-enforced limits for parsing recursion +/// exist; if shader compilation runs into them, it will return an error gracefully. However, on +/// some build profiles and platforms, the default stack size for a thread may be exceeded before +/// this limit is reached during parsing. Callers should ensure that there is enough stack space +/// for this, particularly if calls to this method are exposed to user input. +/// +/// </div> pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> { Frontend::new().parse(source) } diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs index dbaac523cb..ea8013ee7c 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/ast.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs @@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> { Fn(Function<'a>), Var(GlobalVariable<'a>), Const(Const<'a>), + Override(Override<'a>), Struct(Struct<'a>), Type(TypeAlias<'a>), } @@ -200,6 +201,14 @@ pub struct Const<'a> { pub init: Handle<Expression<'a>>, } +#[derive(Debug)] +pub struct Override<'a> { + pub name: Ident<'a>, + pub id: Option<Handle<Expression<'a>>>, + pub ty: Option<Handle<Type<'a>>>, + pub init: Option<Handle<Expression<'a>>>, +} + /// The size of an [`Array`] or [`BindingArray`]. /// /// [`Array`]: Type::Array diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs index 1a4911a3bd..207f0eda41 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/conv.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs @@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>> "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, + // subgroup + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, + "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -260,3 +265,26 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs index 51fc2f013b..79ea1ae609 100644 --- a/third_party/rust/naga/src/front/wgsl/parse/mod.rs +++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs @@ -1619,22 +1619,21 @@ impl Parser { lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, block: &mut ast::Block<'a>, + brace_nesting_level: u8, ) -> Result<(), Error<'a>> { self.push_rule_span(Rule::Statement, lexer); match lexer.peek() { (Token::Separator(';'), _) => { let _ = lexer.next(); self.pop_rule_span(lexer); - return Ok(()); } (Token::Paren('{'), _) => { - let (inner, span) = self.block(lexer, ctx)?; + let (inner, span) = self.block(lexer, ctx, brace_nesting_level)?; block.stmts.push(ast::Statement { kind: ast::StatementKind::Block(inner), span, }); self.pop_rule_span(lexer); - return Ok(()); } (Token::Word(word), _) => { let kind = match word { @@ -1711,7 +1710,7 @@ impl Parser { let _ = lexer.next(); let condition = self.general_expression(lexer, ctx)?; - let accept = self.block(lexer, ctx)?.0; + let accept = self.block(lexer, ctx, brace_nesting_level)?.0; let mut elsif_stack = Vec::new(); let mut elseif_span_start = lexer.start_byte_offset(); @@ -1722,12 +1721,12 @@ impl Parser { if !lexer.skip(Token::Word("if")) { // ... else { ... } - break self.block(lexer, ctx)?.0; + break self.block(lexer, ctx, brace_nesting_level)?.0; } // ... else if (...) { ... } let other_condition = self.general_expression(lexer, ctx)?; - let other_block = self.block(lexer, ctx)?; + let other_block = self.block(lexer, ctx, brace_nesting_level)?; elsif_stack.push((elseif_span_start, other_condition, other_block)); elseif_span_start = lexer.start_byte_offset(); }; @@ -1759,7 +1758,9 @@ impl Parser { "switch" => { let _ = lexer.next(); let selector = self.general_expression(lexer, ctx)?; - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = + Self::increase_brace_nesting(brace_nesting_level, brace_span)?; let mut cases = Vec::new(); loop { @@ -1784,7 +1785,7 @@ impl Parser { }); }; - let body = self.block(lexer, ctx)?.0; + let body = self.block(lexer, ctx, brace_nesting_level)?.0; cases.push(ast::SwitchCase { value, @@ -1794,7 +1795,7 @@ impl Parser { } (Token::Word("default"), _) => { lexer.skip(Token::Separator(':')); - let body = self.block(lexer, ctx)?.0; + let body = self.block(lexer, ctx, brace_nesting_level)?.0; cases.push(ast::SwitchCase { value: ast::SwitchValue::Default, body, @@ -1810,7 +1811,7 @@ impl Parser { ast::StatementKind::Switch { selector, cases } } - "loop" => self.r#loop(lexer, ctx)?, + "loop" => self.r#loop(lexer, ctx, brace_nesting_level)?, "while" => { let _ = lexer.next(); let mut body = ast::Block::default(); @@ -1834,7 +1835,7 @@ impl Parser { span, }); - let (block, span) = self.block(lexer, ctx)?; + let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1857,7 +1858,9 @@ impl Parser { let (_, span) = { let ctx = &mut *ctx; let block = &mut *block; - lexer.capture_span(|lexer| self.statement(lexer, ctx, block))? + lexer.capture_span(|lexer| { + self.statement(lexer, ctx, block, brace_nesting_level) + })? }; if block.stmts.len() != num_statements { @@ -1902,7 +1905,7 @@ impl Parser { lexer.expect(Token::Paren(')'))?; } - let (block, span) = self.block(lexer, ctx)?; + let (block, span) = self.block(lexer, ctx, brace_nesting_level)?; body.stmts.push(ast::Statement { kind: ast::StatementKind::Block(block), span, @@ -1964,13 +1967,15 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, + brace_nesting_level: u8, ) -> Result<ast::StatementKind<'a>, Error<'a>> { let _ = lexer.next(); let mut body = ast::Block::default(); let mut continuing = ast::Block::default(); let mut break_if = None; - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?; ctx.local_table.push_scope(); @@ -1980,7 +1985,9 @@ impl Parser { // the last thing in the loop body // Expect a opening brace to start the continuing block - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = + Self::increase_brace_nesting(brace_nesting_level, brace_span)?; loop { if lexer.skip(Token::Word("break")) { // Branch for the `break if` statement, this statement @@ -2009,7 +2016,7 @@ impl Parser { break; } else { // Otherwise try to parse a statement - self.statement(lexer, ctx, &mut continuing)?; + self.statement(lexer, ctx, &mut continuing, brace_nesting_level)?; } } // Since the continuing block must be the last part of the loop body, @@ -2023,7 +2030,7 @@ impl Parser { break; } // Otherwise try to parse a statement - self.statement(lexer, ctx, &mut body)?; + self.statement(lexer, ctx, &mut body, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2040,15 +2047,17 @@ impl Parser { &mut self, lexer: &mut Lexer<'a>, ctx: &mut ExpressionContext<'a, '_, '_>, + brace_nesting_level: u8, ) -> Result<(ast::Block<'a>, Span), Error<'a>> { self.push_rule_span(Rule::Block, lexer); ctx.local_table.push_scope(); - lexer.expect(Token::Paren('{'))?; + let brace_span = lexer.expect_span(Token::Paren('{'))?; + let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?; let mut block = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, ctx, &mut block)?; + self.statement(lexer, ctx, &mut block, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2135,9 +2144,10 @@ impl Parser { // do not use `self.block` here, since we must not push a new scope lexer.expect(Token::Paren('{'))?; + let brace_nesting_level = 1; let mut body = ast::Block::default(); while !lexer.skip(Token::Paren('}')) { - self.statement(lexer, &mut ctx, &mut body)?; + self.statement(lexer, &mut ctx, &mut body, brace_nesting_level)?; } ctx.local_table.pop_scope(); @@ -2170,6 +2180,7 @@ impl Parser { let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); + let mut id = ParsedAttribute::default(); let mut dependencies = FastIndexSet::default(); let mut ctx = ExpressionContext { @@ -2193,6 +2204,11 @@ impl Parser { bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?; lexer.expect(Token::Paren(')'))?; } + ("id", name_span) => { + lexer.expect(Token::Paren('('))?; + id.set(self.general_expression(lexer, &mut ctx)?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } ("vertex", name_span) => { stage.set(crate::ShaderStage::Vertex, name_span)?; } @@ -2283,6 +2299,30 @@ impl Parser { Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init })) } + (Token::Word("override"), _) => { + let name = lexer.next_ident()?; + + let ty = if lexer.skip(Token::Separator(':')) { + Some(self.type_decl(lexer, &mut ctx)?) + } else { + None + }; + + let init = if lexer.skip(Token::Operation('=')) { + Some(self.general_expression(lexer, &mut ctx)?) + } else { + None + }; + + lexer.expect(Token::Separator(';'))?; + + Some(ast::GlobalDeclKind::Override(ast::Override { + name, + id: id.value, + ty, + init, + })) + } (Token::Word("var"), _) => { let mut var = self.variable_decl(lexer, &mut ctx)?; var.binding = binding.take(); @@ -2347,4 +2387,30 @@ impl Parser { Ok(tu) } + + const fn increase_brace_nesting( + brace_nesting_level: u8, + brace_span: Span, + ) -> Result<u8, Error<'static>> { + // From [spec.](https://gpuweb.github.io/gpuweb/wgsl/#limits): + // + // > § 2.4. Limits + // > + // > … + // > + // > Maximum nesting depth of brace-enclosed statements in a function[:] 127 + // + // _However_, we choose 64 instead because (a) it avoids stack overflows in CI and + // (b) we expect the limit to be decreased to 63 based on this conversation in + // WebGPU CTS upstream: + // <https://github.com/gpuweb/cts/pull/3389#discussion_r1543742701> + const BRACE_NESTING_MAXIMUM: u8 = 64; + if brace_nesting_level + 1 > BRACE_NESTING_MAXIMUM { + return Err(Error::ExceededLimitForNestedBraces { + span: brace_span, + limit: BRACE_NESTING_MAXIMUM, + }); + } + Ok(brace_nesting_level + 1) + } } diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs index c8331ace09..63bc9f7317 100644 --- a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs +++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs @@ -226,7 +226,8 @@ mod tests { let gctx = crate::proc::GlobalCtx { types: &types, constants: &crate::Arena::new(), - const_expressions: &crate::Arena::new(), + overrides: &crate::Arena::new(), + global_expressions: &crate::Arena::new(), }; let array = crate::TypeInner::Array { base: mytype1, diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs index 4b45174300..24e1b02c76 100644 --- a/third_party/rust/naga/src/lib.rs +++ b/third_party/rust/naga/src/lib.rs @@ -34,9 +34,6 @@ with optional span info, representing a series of statements executed in order. `EntryPoint`s or `Function` is a `Block`, and `Statement` has a [`Block`][Statement::Block] variant. -If the `clone` feature is enabled, [`Arena`], [`UniqueArena`], [`Type`], [`TypeInner`], -[`Constant`], [`Function`], [`EntryPoint`] and [`Module`] can be cloned. - ## Arenas To improve translator performance and reduce memory usage, most structures are @@ -175,7 +172,7 @@ tree. A Naga *constant expression* is one of the following [`Expression`] variants, whose operands (if any) are also constant expressions: - [`Literal`] -- [`Constant`], for [`Constant`s][const_type] whose [`override`] is [`None`] +- [`Constant`], for [`Constant`]s - [`ZeroValue`], for fixed-size types - [`Compose`] - [`Access`] @@ -194,8 +191,7 @@ A constant expression can be evaluated at module translation time. ## Override expressions A Naga *override expression* is the same as a [constant expression], -except that it is also allowed to refer to [`Constant`s][const_type] -whose [`override`] is something other than [`None`]. +except that it is also allowed to reference other [`Override`]s. An override expression can be evaluated at pipeline creation time. @@ -238,10 +234,6 @@ An override expression can be evaluated at pipeline creation time. [`Math`]: Expression::Math [`As`]: Expression::As -[const_type]: Constant -[`override`]: Constant::override -[`None`]: Override::None - [constant expression]: index.html#constant-expressions */ @@ -282,6 +274,7 @@ pub mod back; mod block; #[cfg(feature = "compact")] pub mod compact; +pub mod error; pub mod front; pub mod keywords; pub mod proc; @@ -439,6 +432,11 @@ pub enum BuiltIn { WorkGroupId, WorkGroupSize, NumWorkGroups, + // subgroup + NumSubgroups, + SubgroupId, + SubgroupSize, + SubgroupInvocationId, } /// Number of bytes per scalar. @@ -874,7 +872,7 @@ pub enum TypeInner { BindingArray { base: Handle<Type>, size: ArraySize }, } -#[derive(Debug, Clone, Copy, PartialOrd)] +#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -892,41 +890,37 @@ pub enum Literal { AbstractFloat(f64), } -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "clone", derive(Clone))] +/// Pipeline-overridable constant. +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum Override { - None, - ByName, - ByNameOrId(u32), +pub struct Override { + pub name: Option<String>, + /// Pipeline Constant ID. + pub id: Option<u16>, + pub ty: Handle<Type>, + + /// The default value of the pipeline-overridable constant. + /// + /// This [`Handle`] refers to [`Module::global_expressions`], not + /// any [`Function::expressions`] arena. + pub init: Option<Handle<Expression>>, } /// Constant value. -#[derive(Debug, PartialEq)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Constant { pub name: Option<String>, - pub r#override: Override, pub ty: Handle<Type>, /// The value of the constant. /// - /// This [`Handle`] refers to [`Module::const_expressions`], not + /// This [`Handle`] refers to [`Module::global_expressions`], not /// any [`Function::expressions`] arena. - /// - /// If [`override`] is [`None`], then this must be a Naga - /// [constant expression]. Otherwise, this may be a Naga - /// [override expression] or [constant expression]. - /// - /// [`override`]: Constant::override - /// [`None`]: Override::None - /// [constant expression]: index.html#constant-expressions - /// [override expression]: index.html#override-expressions pub init: Handle<Expression>, } @@ -992,7 +986,7 @@ pub struct GlobalVariable { pub ty: Handle<Type>, /// Initial value for this variable. /// - /// Expression handle lives in const_expressions + /// Expression handle lives in global_expressions pub init: Option<Handle<Expression>>, } @@ -1010,7 +1004,7 @@ pub struct LocalVariable { /// /// This handle refers to this `LocalVariable`'s function's /// [`expressions`] arena, but it is required to be an evaluated - /// constant expression. + /// override expression. /// /// [`expressions`]: Function::expressions pub init: Option<Handle<Expression>>, @@ -1289,6 +1283,51 @@ pub enum SwizzleComponent { W = 3, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum GatherMode { + /// All gather from the active lane with the smallest index + BroadcastFirst, + /// All gather from the same lane at the index given by the expression + Broadcast(Handle<Expression>), + /// Each gathers from a different lane at the index given by the expression + Shuffle(Handle<Expression>), + /// Each gathers from their lane plus the shift given by the expression + ShuffleDown(Handle<Expression>), + /// Each gathers from their lane minus the shift given by the expression + ShuffleUp(Handle<Expression>), + /// Each gathers from their lane xored with the given by the expression + ShuffleXor(Handle<Expression>), +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SubgroupOperation { + All = 0, + Any = 1, + Add = 2, + Mul = 3, + Min = 4, + Max = 5, + And = 6, + Or = 7, + Xor = 8, +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CollectiveOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, +} + bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -1297,9 +1336,11 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub struct Barrier: u32 { /// Barrier affects all `AddressSpace::Storage` accesses. - const STORAGE = 0x1; + const STORAGE = 1 << 0; /// Barrier affects all `AddressSpace::WorkGroup` accesses. - const WORK_GROUP = 0x2; + const WORK_GROUP = 1 << 1; + /// Barrier synchronizes execution across all invocations within a subgroup that exectue this instruction. + const SUB_GROUP = 1 << 2; } } @@ -1315,6 +1356,8 @@ pub enum Expression { Literal(Literal), /// Constant value. Constant(Handle<Constant>), + /// Pipeline-overridable constant. + Override(Handle<Override>), /// Zero value of a type. ZeroValue(Handle<Type>), /// Composite expression. @@ -1440,7 +1483,7 @@ pub enum Expression { gather: Option<SwizzleComponent>, coordinate: Handle<Expression>, array_index: Option<Handle<Expression>>, - /// Expression handle lives in const_expressions + /// Expression handle lives in global_expressions offset: Option<Handle<Expression>>, level: SampleLevel, depth_ref: Option<Handle<Expression>>, @@ -1598,6 +1641,15 @@ pub enum Expression { query: Handle<Expression>, committed: bool, }, + /// Result of a [`SubgroupBallot`] statement. + /// + /// [`SubgroupBallot`]: Statement::SubgroupBallot + SubgroupBallotResult, + /// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement. + /// + /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation + /// [`SubgroupGather`]: Statement::SubgroupGather + SubgroupOperationResult { ty: Handle<Type> }, } pub use block::Block; @@ -1882,6 +1934,39 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + /// Calculate a bitmask using a boolean from each active thread in the subgroup + SubgroupBallot { + /// The [`SubgroupBallotResult`] expression representing this load's result. + /// + /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult + result: Handle<Expression>, + /// The value from this thread to store in the ballot + predicate: Option<Handle<Expression>>, + }, + /// Gather a value from another active thread in the subgroup + SubgroupGather { + /// Specifies which thread to gather from + mode: GatherMode, + /// The value to broadcast over + argument: Handle<Expression>, + /// The [`SubgroupOperationResult`] expression representing this load's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle<Expression>, + }, + /// Compute a collective operation across all active threads in the subgroup + SubgroupCollectiveOperation { + /// What operation to compute + op: SubgroupOperation, + /// How to combine the results + collective_op: CollectiveOperation, + /// The value to compute over + argument: Handle<Expression>, + /// The [`SubgroupOperationResult`] expression representing this load's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle<Expression>, + }, } /// A function argument. @@ -1913,8 +1998,7 @@ pub struct FunctionResult { } /// A function defined in the module. -#[derive(Debug, Default)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -1978,8 +2062,7 @@ pub struct Function { /// [`Location`]: Binding::Location /// [`function`]: EntryPoint::function /// [`stage`]: EntryPoint::stage -#[derive(Debug)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2003,8 +2086,7 @@ pub struct EntryPoint { /// These cannot be spelled in WGSL source. /// /// Stored in [`SpecialTypes::predeclared_types`] and created by [`Module::generate_predeclared_type`]. -#[derive(Debug, PartialEq, Eq, Hash)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2021,8 +2103,7 @@ pub enum PredeclaredType { } /// Set of special types that can be optionally generated by the frontends. -#[derive(Debug, Default)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2057,8 +2138,7 @@ pub struct SpecialTypes { /// Alternatively, you can load an existing shader using one of the [available front ends][front]. /// /// When finished, you can export modules using one of the [available backends][back]. -#[derive(Debug, Default)] -#[cfg_attr(feature = "clone", derive(Clone))] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2069,6 +2149,8 @@ pub struct Module { pub special_types: SpecialTypes, /// Arena for the constants defined in this module. pub constants: Arena<Constant>, + /// Arena for the pipeline-overridable constants defined in this module. + pub overrides: Arena<Override>, /// Arena for the global variables defined in this module. pub global_variables: Arena<GlobalVariable>, /// [Constant expressions] and [override expressions] used by this module. @@ -2078,7 +2160,7 @@ pub struct Module { /// /// [Constant expressions]: index.html#constant-expressions /// [override expressions]: index.html#override-expressions - pub const_expressions: Arena<Expression>, + pub global_expressions: Arena<Expression>, /// Arena for the functions defined in this module. /// /// Each function must appear in this arena strictly before all its callers. diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs index 983af3718c..ead3d00980 100644 --- a/third_party/rust/naga/src/proc/constant_evaluator.rs +++ b/third_party/rust/naga/src/proc/constant_evaluator.rs @@ -4,8 +4,8 @@ use arrayvec::ArrayVec; use crate::{ arena::{Arena, Handle, UniqueArena}, - ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner, - UnaryOperator, + ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type, + TypeInner, UnaryOperator, }; /// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating @@ -253,9 +253,20 @@ gen_component_wise_extractor! { } #[derive(Debug)] -enum Behavior { - Wgsl, - Glsl, +enum Behavior<'a> { + Wgsl(WgslRestrictions<'a>), + Glsl(GlslRestrictions<'a>), +} + +impl Behavior<'_> { + /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions. + const fn has_runtime_restrictions(&self) -> bool { + matches!( + self, + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)) + ) + } } /// A context for evaluating constant expressions. @@ -278,7 +289,7 @@ enum Behavior { #[derive(Debug)] pub struct ConstantEvaluator<'a> { /// Which language's evaluation rules we should follow. - behavior: Behavior, + behavior: Behavior<'a>, /// The module's type arena. /// @@ -291,71 +302,155 @@ pub struct ConstantEvaluator<'a> { /// The module's constant arena. constants: &'a Arena<Constant>, + /// The module's override arena. + overrides: &'a Arena<Override>, + /// The arena to which we are contributing expressions. expressions: &'a mut Arena<Expression>, - /// When `self.expressions` refers to a function's local expression - /// arena, this needs to be populated - function_local_data: Option<FunctionLocalData<'a>>, + /// Tracks the constness of expressions residing in [`Self::expressions`] + expression_kind_tracker: &'a mut ExpressionKindTracker, +} + +#[derive(Debug)] +enum WgslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + Override, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), +} + +#[derive(Debug)] +enum GlslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), } #[derive(Debug)] struct FunctionLocalData<'a> { /// Global constant expressions - const_expressions: &'a Arena<Expression>, - /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions` - expression_constness: &'a mut ExpressionConstnessTracker, + global_expressions: &'a Arena<Expression>, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum ExpressionKind { + Const, + Override, + Runtime, +} + #[derive(Debug)] -pub struct ExpressionConstnessTracker { - inner: bit_set::BitSet, +pub struct ExpressionKindTracker { + inner: Vec<ExpressionKind>, } -impl ExpressionConstnessTracker { - pub fn new() -> Self { - Self { - inner: bit_set::BitSet::new(), - } +impl ExpressionKindTracker { + pub const fn new() -> Self { + Self { inner: Vec::new() } } /// Forces the the expression to not be const pub fn force_non_const(&mut self, value: Handle<Expression>) { - self.inner.remove(value.index()); + self.inner[value.index()] = ExpressionKind::Runtime; } - fn insert(&mut self, value: Handle<Expression>) { - self.inner.insert(value.index()); + pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) { + assert_eq!(self.inner.len(), value.index()); + self.inner.push(expr_type); + } + pub fn is_const(&self, h: Handle<Expression>) -> bool { + matches!(self.type_of(h), ExpressionKind::Const) + } + + pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool { + matches!( + self.type_of(h), + ExpressionKind::Const | ExpressionKind::Override + ) } - pub fn is_const(&self, value: Handle<Expression>) -> bool { - self.inner.contains(value.index()) + fn type_of(&self, value: Handle<Expression>) -> ExpressionKind { + self.inner[value.index()] } pub fn from_arena(arena: &Arena<Expression>) -> Self { - let mut tracker = Self::new(); - for (handle, expr) in arena.iter() { - let insert = match *expr { - crate::Expression::Literal(_) - | crate::Expression::ZeroValue(_) - | crate::Expression::Constant(_) => true, - crate::Expression::Compose { ref components, .. } => { - components.iter().all(|h| tracker.is_const(*h)) - } - crate::Expression::Splat { value, .. } => tracker.is_const(value), - _ => false, - }; - if insert { - tracker.insert(handle); - } + let mut tracker = Self { + inner: Vec::with_capacity(arena.len()), + }; + for (_, expr) in arena.iter() { + tracker.inner.push(tracker.type_of_with_expr(expr)); } tracker } + + fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind { + match *expr { + Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { + ExpressionKind::Const + } + Expression::Override(_) => ExpressionKind::Override, + Expression::Compose { ref components, .. } => { + let mut expr_type = ExpressionKind::Const; + for component in components { + expr_type = expr_type.max(self.type_of(*component)) + } + expr_type + } + Expression::Splat { value, .. } => self.type_of(value), + Expression::AccessIndex { base, .. } => self.type_of(base), + Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)), + Expression::Swizzle { vector, .. } => self.type_of(vector), + Expression::Unary { expr, .. } => self.type_of(expr), + Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)), + Expression::Math { + arg, + arg1, + arg2, + arg3, + .. + } => self + .type_of(arg) + .max( + arg1.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg2.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg3.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ), + Expression::As { expr, .. } => self.type_of(expr), + Expression::Select { + condition, + accept, + reject, + } => self + .type_of(condition) + .max(self.type_of(accept)) + .max(self.type_of(reject)), + Expression::Relational { argument, .. } => self.type_of(argument), + Expression::ArrayLength(expr) => self.type_of(expr), + _ => ExpressionKind::Runtime, + } + } } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantEvaluatorError { #[error("Constants cannot access function arguments")] FunctionArg, @@ -381,6 +476,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -432,6 +529,12 @@ pub enum ConstantEvaluatorError { ShiftedMoreThan32Bits, #[error(transparent)] Literal(#[from] crate::valid::LiteralError), + #[error("Can't use pipeline-overridable constants in const-expressions")] + Override, + #[error("Unexpected runtime-expression")] + RuntimeExpr, + #[error("Unexpected override-expression")] + OverrideExpr, } impl<'a> ConstantEvaluator<'a> { @@ -439,25 +542,49 @@ impl<'a> ConstantEvaluator<'a> { /// constant expression arena. /// /// Report errors according to WGSL's rules for constant evaluation. - pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Wgsl, module) + pub fn for_wgsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + in_override_ctx: bool, + ) -> Self { + Self::for_module( + Behavior::Wgsl(if in_override_ctx { + WgslRestrictions::Override + } else { + WgslRestrictions::Const + }), + module, + global_expression_kind_tracker, + ) } /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s /// constant expression arena. /// /// Report errors according to GLSL's rules for constant evaluation. - pub fn for_glsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Glsl, module) + pub fn for_glsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + ) -> Self { + Self::for_module( + Behavior::Glsl(GlslRestrictions::Const), + module, + global_expression_kind_tracker, + ) } - fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self { + fn for_module( + behavior: Behavior<'a>, + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionKindTracker, + ) -> Self { Self { behavior, types: &mut module.types, constants: &module.constants, - expressions: &mut module.const_expressions, - function_local_data: None, + overrides: &module.overrides, + expressions: &mut module.global_expressions, + expression_kind_tracker: global_expression_kind_tracker, } } @@ -468,18 +595,22 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_wgsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena<Expression>, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { - Self::for_function( - Behavior::Wgsl, - module, + Self { + behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData { + global_expressions: &module.global_expressions, + emitter, + block, + })), + types: &mut module.types, + constants: &module.constants, + overrides: &module.overrides, expressions, - expression_constness, - emitter, - block, - ) + expression_kind_tracker: local_expression_kind_tracker, + } } /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s @@ -489,39 +620,21 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_glsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena<Expression>, - expression_constness: &'a mut ExpressionConstnessTracker, - emitter: &'a mut super::Emitter, - block: &'a mut crate::Block, - ) -> Self { - Self::for_function( - Behavior::Glsl, - module, - expressions, - expression_constness, - emitter, - block, - ) - } - - fn for_function( - behavior: Behavior, - module: &'a mut crate::Module, - expressions: &'a mut Arena<Expression>, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionKindTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { Self { - behavior, + behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData { + global_expressions: &module.global_expressions, + emitter, + block, + })), types: &mut module.types, constants: &module.constants, + overrides: &module.overrides, expressions, - function_local_data: Some(FunctionLocalData { - const_expressions: &module.const_expressions, - expression_constness, - emitter, - block, - }), + expression_kind_tracker: local_expression_kind_tracker, } } @@ -529,19 +642,18 @@ impl<'a> ConstantEvaluator<'a> { crate::proc::GlobalCtx { types: self.types, constants: self.constants, - const_expressions: match self.function_local_data { - Some(ref data) => data.const_expressions, + overrides: self.overrides, + global_expressions: match self.function_local_data() { + Some(data) => data.global_expressions, None => self.expressions, }, } } fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> { - if let Some(ref function_local_data) = self.function_local_data { - if !function_local_data.expression_constness.is_const(expr) { - log::debug!("check: SubexpressionsAreNotConstant"); - return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); - } + if !self.expression_kind_tracker.is_const(expr) { + log::debug!("check: SubexpressionsAreNotConstant"); + return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); } Ok(()) } @@ -554,11 +666,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::Constant(c) => { // Are we working in a function's expression arena, or the // module's constant expression arena? - if let Some(ref function_local_data) = self.function_local_data { + if let Some(function_local_data) = self.function_local_data() { // Deep-copy the constant's value into our arena. self.copy_from( self.constants[c].init, - function_local_data.const_expressions, + function_local_data.global_expressions, ) } else { // "See through" the constant and use its initializer. @@ -580,9 +692,11 @@ impl<'a> ConstantEvaluator<'a> { /// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena /// `self` contributes to. /// - /// If `expr`'s value cannot be determined at compile time, return a an - /// error. If it's acceptable to evaluate `expr` at runtime, this error can - /// be ignored, and the caller can append `expr` to the arena itself. + /// If `expr`'s value cannot be determined at compile time, and `self` is + /// contributing to some function's expression arena, then append `expr` to + /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be + /// contributing to the module's constant expression arena; since `expr`'s + /// value is not a constant, return an error. /// /// We only consider `expr` itself, without recursing into its operands. Its /// operands must all have been produced by prior calls to @@ -595,16 +709,81 @@ impl<'a> ConstantEvaluator<'a> { /// [`Swizzle`]: Expression::Swizzle pub fn try_eval_and_append( &mut self, + expr: Expression, + span: Span, + ) -> Result<Handle<Expression>, ConstantEvaluatorError> { + match self.expression_kind_tracker.type_of_with_expr(&expr) { + ExpressionKind::Const => { + let eval_result = self.try_eval_and_append_impl(&expr, span); + // We should be able to evaluate `Const` expressions at this + // point. If we failed to, then that probably means we just + // haven't implemented that part of constant evaluation. Work + // around this by simply emitting it as a run-time expression. + if self.behavior.has_runtime_restrictions() + && matches!( + eval_result, + Err(ConstantEvaluatorError::NotImplemented(_) + | ConstantEvaluatorError::InvalidBinaryOpArgs,) + ) + { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + eval_result + } + } + ExpressionKind::Override => match self.behavior { + Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => { + Ok(self.append_expr(expr, span, ExpressionKind::Override)) + } + Behavior::Wgsl(WgslRestrictions::Const) => { + Err(ConstantEvaluatorError::OverrideExpr) + } + Behavior::Glsl(_) => { + unreachable!() + } + }, + ExpressionKind::Runtime => { + if self.behavior.has_runtime_restrictions() { + Ok(self.append_expr(expr, span, ExpressionKind::Runtime)) + } else { + Err(ConstantEvaluatorError::RuntimeExpr) + } + } + } + } + + /// Is the [`Self::expressions`] arena the global module expression arena? + const fn is_global_arena(&self) -> bool { + matches!( + self.behavior, + Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override) + | Behavior::Glsl(GlslRestrictions::Const) + ) + } + + const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> { + match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => { + Some(function_local_data) + } + _ => None, + } + } + + fn try_eval_and_append_impl( + &mut self, expr: &Expression, span: Span, ) -> Result<Handle<Expression>, ConstantEvaluatorError> { log::trace!("try_eval_and_append: {:?}", expr); match *expr { - Expression::Constant(c) if self.function_local_data.is_none() => { + Expression::Constant(c) if self.is_global_arena() => { // "See through" the constant and use its initializer. // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } + Expression::Override(_) => Err(ConstantEvaluatorError::Override), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -685,8 +864,8 @@ impl<'a> ConstantEvaluator<'a> { format!("{fun:?} built-in function"), )), Expression::ArrayLength(expr) => match self.behavior { - Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength), - Behavior::Glsl => { + Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), + Behavior::Glsl(_) => { let expr = self.check_and_get(expr)?; self.array_length(expr, span) } @@ -707,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } } } @@ -765,10 +950,10 @@ impl<'a> ConstantEvaluator<'a> { pattern: [crate::SwizzleComponent; 4], ) -> Result<Handle<Expression>, ConstantEvaluatorError> { let mut get_dst_ty = |ty| match self.types[ty].inner { - crate::TypeInner::Vector { size: _, scalar } => Ok(self.types.insert( + TypeInner::Vector { size: _, scalar } => Ok(self.types.insert( Type { name: None, - inner: crate::TypeInner::Vector { size, scalar }, + inner: TypeInner::Vector { size, scalar }, }, span, )), @@ -1059,13 +1244,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => { match self.types[ty].inner { TypeInner::Array { size, .. } => match size { - crate::ArraySize::Constant(len) => { + ArraySize::Constant(len) => { let expr = Expression::Literal(Literal::U32(len.get())); self.register_evaluated_expr(expr, span) } - crate::ArraySize::Dynamic => { - Err(ConstantEvaluatorError::ArrayLengthDynamic) - } + ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic), }, _ => Err(ConstantEvaluatorError::InvalidArrayLengthArg), } @@ -1128,7 +1311,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::ZeroValue(ty) if matches!( self.types[ty].inner, - crate::TypeInner::Scalar(crate::Scalar { + TypeInner::Scalar(crate::Scalar { kind: ScalarKind::Uint, .. }) @@ -1443,7 +1626,7 @@ impl<'a> ConstantEvaluator<'a> { return self.cast(expr, target, span); }; - let crate::TypeInner::Array { + let TypeInner::Array { base: _, size, stride: _, @@ -1853,29 +2036,35 @@ impl<'a> ConstantEvaluator<'a> { crate::valid::check_literal_value(literal)?; } - if let Some(FunctionLocalData { - ref mut emitter, - ref mut block, - ref mut expression_constness, - .. - }) = self.function_local_data - { - let is_running = emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { - block.extend(emitter.finish(self.expressions)); - let h = self.expressions.append(expr, span); - emitter.start(self.expressions); - expression_constness.insert(h); - Ok(h) - } else { - let h = self.expressions.append(expr, span); - expression_constness.insert(h); - Ok(h) + Ok(self.append_expr(expr, span, ExpressionKind::Const)) + } + + fn append_expr( + &mut self, + expr: Expression, + span: Span, + expr_type: ExpressionKind, + ) -> Handle<Expression> { + let h = match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => { + let is_running = function_local_data.emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + function_local_data + .block + .extend(function_local_data.emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + function_local_data.emitter.start(self.expressions); + h + } else { + self.expressions.append(expr, span) + } } - } else { - Ok(self.expressions.append(expr, span)) - } + _ => self.expressions.append(expr, span), + }; + self.expression_kind_tracker.insert(h, expr_type); + h } fn resolve_type( @@ -2029,13 +2218,14 @@ mod tests { UniqueArena, VectorSize, }; - use super::{Behavior, ConstantEvaluator}; + use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions}; #[test] fn unary_op() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2059,9 +2249,8 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), @@ -2070,9 +2259,8 @@ mod tests { let h1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(8)), Default::default()), }, Default::default(), @@ -2081,9 +2269,8 @@ mod tests { let vec_h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec![constants[h].init, constants[h1].init], @@ -2094,8 +2281,8 @@ mod tests { Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); - let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default()); let expr2 = Expression::Unary { op: UnaryOperator::Negate, @@ -2112,35 +2299,37 @@ mod tests { expr: expr1, }; + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let res1 = solver - .try_eval_and_append(&expr2, Default::default()) + .try_eval_and_append(expr2, Default::default()) .unwrap(); let res2 = solver - .try_eval_and_append(&expr3, Default::default()) + .try_eval_and_append(expr3, Default::default()) .unwrap(); let res3 = solver - .try_eval_and_append(&expr4, Default::default()) + .try_eval_and_append(expr4, Default::default()) .unwrap(); assert_eq!( - const_expressions[res1], + global_expressions[res1], Expression::Literal(Literal::I32(-4)) ); assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::I32(!4)) ); - let res3_inner = &const_expressions[res3]; + let res3_inner = &global_expressions[res3]; match *res3_inner { Expression::Compose { @@ -2150,11 +2339,11 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!4)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::I32(!8)) ); assert!(components_iter.next().is_none()); @@ -2167,7 +2356,8 @@ mod tests { fn cast() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let scalar_ty = types.insert( Type { @@ -2180,15 +2370,14 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: scalar_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expr = global_expressions.append(Expression::Constant(h), Default::default()); let root = Expression::As { expr, @@ -2196,20 +2385,22 @@ mod tests { convert: Some(crate::BOOL_WIDTH), }; + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let res = solver - .try_eval_and_append(&root, Default::default()) + .try_eval_and_append(root, Default::default()) .unwrap(); assert_eq!( - const_expressions[res], + global_expressions[res], Expression::Literal(Literal::Bool(true)) ); } @@ -2218,7 +2409,8 @@ mod tests { fn access() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let matrix_ty = types.insert( Type { @@ -2247,7 +2439,7 @@ mod tests { let mut vec2_components = Vec::with_capacity(3); for i in 0..3 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2256,7 +2448,7 @@ mod tests { } for i in 3..6 { - let h = const_expressions.append( + let h = global_expressions.append( Expression::Literal(Literal::F32(i as f32)), Default::default(), ); @@ -2267,9 +2459,8 @@ mod tests { let vec1 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec1_components, @@ -2283,9 +2474,8 @@ mod tests { let vec2 = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: vec_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: vec_ty, components: vec2_components, @@ -2299,9 +2489,8 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: matrix_ty, - init: const_expressions.append( + init: global_expressions.append( Expression::Compose { ty: matrix_ty, components: vec![constants[vec1].init, constants[vec2].init], @@ -2312,20 +2501,22 @@ mod tests { Default::default(), ); - let base = const_expressions.append(Expression::Constant(h), Default::default()); + let base = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let root1 = Expression::AccessIndex { base, index: 1 }; let res1 = solver - .try_eval_and_append(&root1, Default::default()) + .try_eval_and_append(root1, Default::default()) .unwrap(); let root2 = Expression::AccessIndex { @@ -2334,10 +2525,10 @@ mod tests { }; let res2 = solver - .try_eval_and_append(&root2, Default::default()) + .try_eval_and_append(root2, Default::default()) .unwrap(); - match const_expressions[res1] { + match global_expressions[res1] { Expression::Compose { ref ty, ref components, @@ -2345,15 +2536,15 @@ mod tests { assert_eq!(*ty, vec_ty); let mut components_iter = components.iter().copied(); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(3.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(4.)) ); assert_eq!( - const_expressions[components_iter.next().unwrap()], + global_expressions[components_iter.next().unwrap()], Expression::Literal(Literal::F32(5.)) ); assert!(components_iter.next().is_none()); @@ -2362,7 +2553,7 @@ mod tests { } assert_eq!( - const_expressions[res2], + global_expressions[res2], Expression::Literal(Literal::F32(5.)) ); } @@ -2371,7 +2562,8 @@ mod tests { fn compose_of_constants() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2395,27 +2587,28 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let solved_compose = solver .try_eval_and_append( - &Expression::Compose { + Expression::Compose { ty: vec2_i32_ty, components: vec![h_expr, h_expr], }, @@ -2424,7 +2617,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, @@ -2432,11 +2625,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } @@ -2451,7 +2644,8 @@ mod tests { fn splat_of_constant() { let mut types = UniqueArena::new(); let mut constants = Arena::new(); - let mut const_expressions = Arena::new(); + let overrides = Arena::new(); + let mut global_expressions = Arena::new(); let i32_ty = types.insert( Type { @@ -2475,27 +2669,28 @@ mod tests { let h = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_ty, - init: const_expressions + init: global_expressions .append(Expression::Literal(Literal::I32(4)), Default::default()), }, Default::default(), ); - let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let h_expr = global_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, - expressions: &mut const_expressions, - function_local_data: None, + overrides: &overrides, + expressions: &mut global_expressions, + expression_kind_tracker, }; let solved_compose = solver .try_eval_and_append( - &Expression::Splat { + Expression::Splat { size: VectorSize::Bi, value: h_expr, }, @@ -2504,7 +2699,7 @@ mod tests { .unwrap(); let solved_negate = solver .try_eval_and_append( - &Expression::Unary { + Expression::Unary { op: UnaryOperator::Negate, expr: solved_compose, }, @@ -2512,11 +2707,11 @@ mod tests { ) .unwrap(); - let pass = match const_expressions[solved_negate] { + let pass = match global_expressions[solved_negate] { Expression::Compose { ty, ref components } => { ty == vec2_i32_ty && components.iter().all(|&component| { - let component = &const_expressions[component]; + let component = &global_expressions[component]; matches!(*component, Expression::Literal(Literal::I32(-4))) }) } diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs index af3221c0fe..e2c3de8eb0 100644 --- a/third_party/rust/naga/src/proc/index.rs +++ b/third_party/rust/naga/src/proc/index.rs @@ -239,7 +239,7 @@ pub enum GuardedIndex { pub fn find_checked_indexes( module: &crate::Module, function: &crate::Function, - info: &crate::valid::FunctionInfo, + info: &valid::FunctionInfo, policies: BoundsCheckPolicies, ) -> BitSet { use crate::Expression as Ex; @@ -321,7 +321,7 @@ pub fn access_needs_check( mut index: GuardedIndex, module: &crate::Module, function: &crate::Function, - info: &crate::valid::FunctionInfo, + info: &valid::FunctionInfo, ) -> Option<IndexableLength> { let base_inner = info[base].ty.inner_with(&module.types); // Unwrap safety: `Err` here indicates unindexable base types and invalid diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs index 46cbb6c3b3..93aac5b3e5 100644 --- a/third_party/rust/naga/src/proc/mod.rs +++ b/third_party/rust/naga/src/proc/mod.rs @@ -11,7 +11,7 @@ mod terminator; mod typifier; pub use constant_evaluator::{ - ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, + ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker, }; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; @@ -153,56 +153,31 @@ impl super::Scalar { } } -impl PartialEq for crate::Literal { - fn eq(&self, other: &Self) -> bool { - match (*self, *other) { - (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(), - (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), - (Self::U32(a), Self::U32(b)) => a == b, - (Self::I32(a), Self::I32(b)) => a == b, - (Self::U64(a), Self::U64(b)) => a == b, - (Self::I64(a), Self::I64(b)) => a == b, - (Self::Bool(a), Self::Bool(b)) => a == b, - _ => false, - } - } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum HashableLiteral { + F64(u64), + F32(u32), + U32(u32), + I32(i32), + U64(u64), + I64(i64), + Bool(bool), + AbstractInt(i64), + AbstractFloat(u64), } -impl Eq for crate::Literal {} -impl std::hash::Hash for crate::Literal { - fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) { - match *self { - Self::F64(v) | Self::AbstractFloat(v) => { - hasher.write_u8(0); - v.to_bits().hash(hasher); - } - Self::F32(v) => { - hasher.write_u8(1); - v.to_bits().hash(hasher); - } - Self::U32(v) => { - hasher.write_u8(2); - v.hash(hasher); - } - Self::I32(v) => { - hasher.write_u8(3); - v.hash(hasher); - } - Self::Bool(v) => { - hasher.write_u8(4); - v.hash(hasher); - } - Self::I64(v) => { - hasher.write_u8(5); - v.hash(hasher); - } - Self::U64(v) => { - hasher.write_u8(6); - v.hash(hasher); - } - Self::AbstractInt(v) => { - hasher.write_u8(7); - v.hash(hasher); - } + +impl From<crate::Literal> for HashableLiteral { + fn from(l: crate::Literal) -> Self { + match l { + crate::Literal::F64(v) => Self::F64(v.to_bits()), + crate::Literal::F32(v) => Self::F32(v.to_bits()), + crate::Literal::U32(v) => Self::U32(v), + crate::Literal::I32(v) => Self::I32(v), + crate::Literal::U64(v) => Self::U64(v), + crate::Literal::I64(v) => Self::I64(v), + crate::Literal::Bool(v) => Self::Bool(v), + crate::Literal::AbstractInt(v) => Self::AbstractInt(v), + crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()), } } } @@ -216,8 +191,8 @@ impl crate::Literal { (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), (value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)), (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)), - (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), - (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), + (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)), + (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)), _ => None, } } @@ -279,8 +254,9 @@ impl super::TypeInner { self.scalar().map(|scalar| scalar.kind) } + /// Returns the scalar width in bytes pub fn scalar_width(&self) -> Option<u8> { - self.scalar().map(|scalar| scalar.width * 8) + self.scalar().map(|scalar| scalar.width) } pub const fn pointer_space(&self) -> Option<crate::AddressSpace> { @@ -532,6 +508,7 @@ impl crate::Expression { match *self { Self::Literal(_) | Self::Constant(_) + | Self::Override(_) | Self::ZeroValue(_) | Self::FunctionArgument(_) | Self::GlobalVariable(_) @@ -553,13 +530,9 @@ impl crate::Expression { /// /// [`Access`]: crate::Expression::Access /// [`ResolveContext`]: crate::proc::ResolveContext - pub fn is_dynamic_index(&self, module: &crate::Module) -> bool { + pub const fn is_dynamic_index(&self) -> bool { match *self { - Self::Literal(_) | Self::ZeroValue(_) => false, - Self::Constant(handle) => { - let constant = &module.constants[handle]; - !matches!(constant.r#override, crate::Override::None) - } + Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false, _ => true, } } @@ -652,7 +625,8 @@ impl crate::Module { GlobalCtx { types: &self.types, constants: &self.constants, - const_expressions: &self.const_expressions, + overrides: &self.overrides, + global_expressions: &self.global_expressions, } } } @@ -667,17 +641,18 @@ pub(super) enum U32EvalError { pub struct GlobalCtx<'a> { pub types: &'a crate::UniqueArena<crate::Type>, pub constants: &'a crate::Arena<crate::Constant>, - pub const_expressions: &'a crate::Arena<crate::Expression>, + pub overrides: &'a crate::Arena<crate::Override>, + pub global_expressions: &'a crate::Arena<crate::Expression>, } impl GlobalCtx<'_> { - /// Try to evaluate the expression in `self.const_expressions` using its `handle` and return it as a `u32`. + /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`. #[allow(dead_code)] pub(super) fn eval_expr_to_u32( &self, handle: crate::Handle<crate::Expression>, ) -> Result<u32, U32EvalError> { - self.eval_expr_to_u32_from(handle, self.const_expressions) + self.eval_expr_to_u32_from(handle, self.global_expressions) } /// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`. @@ -700,7 +675,7 @@ impl GlobalCtx<'_> { &self, handle: crate::Handle<crate::Expression>, ) -> Option<crate::Literal> { - self.eval_expr_to_literal_from(handle, self.const_expressions) + self.eval_expr_to_literal_from(handle, self.global_expressions) } fn eval_expr_to_literal_from( @@ -724,7 +699,7 @@ impl GlobalCtx<'_> { } match arena[handle] { crate::Expression::Constant(c) => { - get(*self, self.constants[c].init, self.const_expressions) + get(*self, self.constants[c].init, self.global_expressions) } _ => get(*self, handle, arena), } diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs index a5239d4eca..5edf55cb73 100644 --- a/third_party/rust/naga/src/proc/terminator.rs +++ b/third_party/rust/naga/src/proc/terminator.rs @@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::RayQuery { .. } | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs index 9c4403445c..3936e7efbe 100644 --- a/third_party/rust/naga/src/proc/typifier.rs +++ b/third_party/rust/naga/src/proc/typifier.rs @@ -185,6 +185,7 @@ pub enum ResolveError { pub struct ResolveContext<'a> { pub constants: &'a Arena<crate::Constant>, + pub overrides: &'a Arena<crate::Override>, pub types: &'a UniqueArena<crate::Type>, pub special_types: &'a crate::SpecialTypes, pub global_vars: &'a Arena<crate::GlobalVariable>, @@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> { ) -> Self { Self { constants: &module.constants, + overrides: &module.overrides, types: &module.types, special_types: &module.special_types, global_vars: &module.global_variables, @@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> { }, crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()), crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty), + crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty), crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty), crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { @@ -595,6 +598,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty), crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { expr, .. } => past(expr)?.clone(), @@ -882,6 +886,10 @@ impl<'a> ResolveContext<'a> { .ok_or(ResolveError::MissingSpecialType)?; TypeResolution::Handle(result) } + crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector { + scalar: crate::Scalar::U32, + size: crate::VectorSize::Quad, + }), }) } } diff --git a/third_party/rust/naga/src/span.rs b/third_party/rust/naga/src/span.rs index 10744647e9..82cfbe5a4b 100644 --- a/third_party/rust/naga/src/span.rs +++ b/third_party/rust/naga/src/span.rs @@ -72,8 +72,8 @@ impl Span { pub fn location(&self, source: &str) -> SourceLocation { let prefix = &source[..self.start as usize]; let line_number = prefix.matches('\n').count() as u32 + 1; - let line_start = prefix.rfind('\n').map(|pos| pos + 1).unwrap_or(0); - let line_position = source[line_start..self.start as usize].chars().count() as u32 + 1; + let line_start = prefix.rfind('\n').map(|pos| pos + 1).unwrap_or(0) as u32; + let line_position = self.start - line_start + 1; SourceLocation { line_number, @@ -107,14 +107,14 @@ impl std::ops::Index<Span> for str { /// Roughly corresponds to the positional members of [`GPUCompilationMessage`][gcm] from /// the WebGPU specification, except /// - `offset` and `length` are in bytes (UTF-8 code units), instead of UTF-16 code units. -/// - `line_position` counts entire Unicode code points, instead of UTF-16 code units. +/// - `line_position` is in bytes (UTF-8 code units), instead of UTF-16 code units. /// /// [gcm]: https://www.w3.org/TR/webgpu/#gpucompilationmessage #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct SourceLocation { /// 1-based line number. pub line_number: u32, - /// 1-based column of the start of this span, counted in Unicode code points. + /// 1-based column in code units (in bytes) of the start of the span. pub line_position: u32, /// 0-based Offset in code units (in bytes) of the start of the span. pub offset: u32, @@ -136,7 +136,7 @@ impl<E> fmt::Display for WithSpan<E> where E: fmt::Display, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.inner.fmt(f) } } @@ -304,7 +304,7 @@ impl<E> WithSpan<E> { use term::termcolor::NoColor; let files = files::SimpleFile::new(path, source); - let config = codespan_reporting::term::Config::default(); + let config = term::Config::default(); let mut writer = NoColor::new(Vec::new()); term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error"); String::from_utf8(writer.into_inner()).unwrap() diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs index 03fbc4089b..6799e5db27 100644 --- a/third_party/rust/naga/src/valid/analyzer.rs +++ b/third_party/rust/naga/src/valid/analyzer.rs @@ -226,7 +226,7 @@ struct Sampling { sampler: GlobalOrArgument, } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct FunctionInfo { @@ -574,7 +574,7 @@ impl FunctionInfo { non_uniform_result: self.add_ref(vector), requirements: UniformityRequirements::empty(), }, - E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(), + E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(), E::Compose { ref components, .. } => { let non_uniform_result = components .iter() @@ -787,6 +787,14 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::SubgroupBallotResult => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::SubgroupOperationResult { .. } => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -827,7 +835,7 @@ impl FunctionInfo { let req = self.expressions[expr.index()].uniformity.requirements; if self .flags - .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY) + .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY) && !req.is_empty() { if let Some(cause) = disruptor { @@ -1029,6 +1037,42 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::SubgroupBallot { + result: _, + predicate, + } => { + if let Some(predicate) = predicate { + let _ = self.add_ref(predicate); + } + FunctionUniformity::new() + } + S::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } + S::SubgroupGather { + mode, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let _ = self.add_ref(index); + } + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); @@ -1047,7 +1091,7 @@ impl ModuleInfo { gctx: crate::proc::GlobalCtx, ) -> Result<(), super::ConstExpressionError> { self.const_expression_types[handle.index()] = - resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?; + resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?; Ok(()) } @@ -1186,6 +1230,7 @@ fn uniform_control_flow() { }; let resolve_context = ResolveContext { constants: &Arena::new(), + overrides: &Arena::new(), types: &type_arena, special_types: &crate::SpecialTypes::default(), global_vars: &global_var_arena, diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs index 838ecc4e27..525bd28c17 100644 --- a/third_party/rust/naga/src/valid/expression.rs +++ b/third_party/rust/naga/src/valid/expression.rs @@ -90,6 +90,8 @@ pub enum ExpressionError { sampler: bool, has_ref: bool, }, + #[error("Sample offset must be a const-expression")] + InvalidSampleOffsetExprType, #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>), #[error("Depth reference {0:?} is not a scalar float")] @@ -129,9 +131,12 @@ pub enum ExpressionError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstExpressionError { - #[error("The expression is not a constant expression")] - NonConst, + #[error("The expression is not a constant or override expression")] + NonConstOrOverride, + #[error("The expression is not a fully evaluated constant expression")] + NonFullyEvaluatedConst, #[error(transparent)] Compose(#[from] super::ComposeError), #[error("Splatting {0:?} can't be done")] @@ -184,10 +189,15 @@ impl super::Validator { handle: Handle<crate::Expression>, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), ConstExpressionError> { use crate::Expression as E; - match gctx.const_expressions[handle] { + if !global_expr_kind.is_const_or_override(handle) { + return Err(ConstExpressionError::NonConstOrOverride); + } + + match gctx.global_expressions[handle] { E::Literal(literal) => { self.validate_literal(literal)?; } @@ -201,14 +211,19 @@ impl super::Validator { } E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) { crate::TypeInner::Scalar { .. } => {} - _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), + _ => return Err(ConstExpressionError::InvalidSplatType(value)), }, - _ => return Err(super::ConstExpressionError::NonConst), + _ if global_expr_kind.is_const(handle) || !self.allow_overrides => { + return Err(ConstExpressionError::NonFullyEvaluatedConst) + } + // the constant evaluator will report errors about override-expressions + _ => {} } Ok(()) } + #[allow(clippy::too_many_arguments)] pub(super) fn validate_expression( &self, root: Handle<crate::Expression>, @@ -217,6 +232,7 @@ impl super::Validator { module: &crate::Module, info: &FunctionInfo, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<ShaderStages, ExpressionError> { use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti}; @@ -252,9 +268,7 @@ impl super::Validator { return Err(ExpressionError::InvalidIndexType(index)); } } - if dynamic_indexing_restricted - && function.expressions[index].is_dynamic_index(module) - { + if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() { return Err(ExpressionError::IndexMustBeConstant(base)); } @@ -347,7 +361,7 @@ impl super::Validator { self.validate_literal(literal)?; ShaderStages::all() } - E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(), + E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(), E::Compose { ref components, ty } => { validate_compose( ty, @@ -464,6 +478,10 @@ impl super::Validator { // check constant offset if let Some(const_expr) = offset { + if !global_expr_kind.is_const(const_expr) { + return Err(ExpressionError::InvalidSampleOffsetExprType); + } + match *mod_info[const_expr].inner_with(&module.types) { Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {} Ti::Vector { @@ -1623,6 +1641,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, + E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, }; Ok(stages) } @@ -1716,7 +1735,7 @@ fn validate_with_const_expression( use crate::span::Span; let mut module = crate::Module::default(); - module.const_expressions.append(expr, Span::default()); + module.global_expressions.append(expr, Span::default()); let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps); diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs index f0ca22cbda..71128fc86d 100644 --- a/third_party/rust/naga/src/valid/function.rs +++ b/third_party/rust/naga/src/valid/function.rs @@ -49,13 +49,26 @@ pub enum AtomicError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] +pub enum SubgroupError { + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle<crate::Expression>), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle<crate::Expression>), + #[error("Support for subgroup operation {0:?} is required")] + UnsupportedOperation(super::SubgroupOperationSet), + #[error("Unknown operation")] + UnknownOperation, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { #[error("Local variable has a type {0:?} that can't be stored in a local variable.")] InvalidType(Handle<crate::Type>), #[error("Initializer doesn't match the variable type")] InitializerType, - #[error("Initializer is not const")] - NonConstInitializer, + #[error("Initializer is not a const or override expression")] + NonConstOrOverrideInitializer, } #[derive(Clone, Debug, thiserror::Error)] @@ -135,6 +148,8 @@ pub enum FunctionError { InvalidRayDescriptor(Handle<crate::Expression>), #[error("Ray Query {0:?} does not have a matching type")] InvalidRayQueryType(Handle<crate::Type>), + #[error("Shader requires capability {0:?}")] + MissingCapability(super::Capabilities), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -155,6 +170,8 @@ pub enum FunctionError { WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>), #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")] WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>), + #[error("Subgroup operation is invalid")] + InvalidSubgroup(#[from] SubgroupError), } bitflags::bitflags! { @@ -399,6 +416,127 @@ impl super::Validator { } Ok(()) } + fn validate_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + let (is_scalar, scalar) = match *argument_inner { + crate::TypeInner::Scalar(scalar) => (true, scalar), + crate::TypeInner::Vector { scalar, .. } => (false, scalar), + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + match (scalar.kind, *op) { + (sk::Bool, sg::All | sg::Any) if is_scalar => {} + (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} + (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {} + + (_, _) => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::CollectiveOperation as co; + match (*collective_op, *op) { + ( + co::Reduce, + sg::All + | sg::Any + | sg::Add + | sg::Mul + | sg::Min + | sg::Max + | sg::And + | sg::Or + | sg::Xor, + ) => {} + (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {} + + (_, _) => { + return Err(SubgroupError::UnknownOperation.with_span().into_other()); + } + }; + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + fn validate_subgroup_gather( + &mut self, + mode: &crate::GatherMode, + argument: Handle<crate::Expression>, + result: Handle<crate::Expression>, + context: &BlockContext, + ) -> Result<(), WithSpan<FunctionError>> { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar(crate::Scalar::U32) => {} + _ => { + log::error!( + "Subgroup gather index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } + } + } + } + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + if !matches!(*argument_inner, + crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. } + if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup gather operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } fn validate_block_impl( &mut self, @@ -613,8 +751,30 @@ impl super::Validator { stages &= super::ShaderStages::FRAGMENT; finished = true; } - S::Barrier(_) => { + S::Barrier(barrier) => { stages &= super::ShaderStages::COMPUTE; + if barrier.contains(crate::Barrier::SUB_GROUP) { + if !self.capabilities.contains( + super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER, + ) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP + | super::Capabilities::SUBGROUP_BARRIER, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BASIC) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BASIC, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + } } S::Store { pointer, value } => { let mut current = pointer; @@ -904,6 +1064,86 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result, predicate } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BALLOT) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BALLOT, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + if let Some(predicate) = predicate { + let predicate_inner = + context.resolve_type(predicate, &self.valid_expression_set)?; + if !matches!( + *predicate_inner, + crate::TypeInner::Scalar(crate::Scalar::BOOL,) + ) { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); + } + } + self.emit_expression(result, context)?; + } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = op.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_operation(op, collective_op, argument, result, context)?; + } + S::SubgroupGather { + ref mode, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = mode.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_gather(mode, argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) @@ -927,7 +1167,7 @@ impl super::Validator { var: &crate::LocalVariable, gctx: crate::proc::GlobalCtx, fun_info: &FunctionInfo, - expression_constness: &crate::proc::ExpressionConstnessTracker, + local_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -945,8 +1185,8 @@ impl super::Validator { return Err(LocalVariableError::InitializerType); } - if !expression_constness.is_const(init) { - return Err(LocalVariableError::NonConstInitializer); + if !local_expr_kind.is_const_or_override(init) { + return Err(LocalVariableError::NonConstOrOverrideInitializer); } } @@ -959,14 +1199,14 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, entry_point: bool, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<FunctionInfo, WithSpan<FunctionError>> { let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; - let expression_constness = - crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); + let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions); for (var_handle, var) in fun.local_variables.iter() { - self.validate_local_var(var, module.to_ctx(), &info, &expression_constness) + self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind) .map_err(|source| { FunctionError::LocalVariable { handle: var_handle, @@ -1032,7 +1272,15 @@ impl super::Validator { self.valid_expression_set.insert(handle.index()); } if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { - match self.validate_expression(handle, expr, fun, module, &info, mod_info) { + match self.validate_expression( + handle, + expr, + fun, + module, + &info, + mod_info, + global_expr_kind, + ) { Ok(stages) => info.available_stages &= stages, Err(source) => { return Err(FunctionError::Expression { handle, source } diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs index e482f293bb..8f78204055 100644 --- a/third_party/rust/naga/src/valid/handles.rs +++ b/third_party/rust/naga/src/valid/handles.rs @@ -31,12 +31,13 @@ impl super::Validator { pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { let &crate::Module { ref constants, + ref overrides, ref entry_points, ref functions, ref global_variables, ref types, ref special_types, - ref const_expressions, + ref global_expressions, } = module; // NOTE: Types being first is important. All other forms of validation depend on this. @@ -67,23 +68,31 @@ impl super::Validator { } } - for handle_and_expr in const_expressions.iter() { - Self::validate_const_expression_handles(handle_and_expr, constants, types)?; + for handle_and_expr in global_expressions.iter() { + Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?; } let validate_type = |handle| Self::validate_type_handle(handle, types); let validate_const_expr = - |handle| Self::validate_expression_handle(handle, const_expressions); + |handle| Self::validate_expression_handle(handle, global_expressions); for (_handle, constant) in constants.iter() { - let &crate::Constant { + let &crate::Constant { name: _, ty, init } = constant; + validate_type(ty)?; + validate_const_expr(init)?; + } + + for (_handle, override_) in overrides.iter() { + let &crate::Override { name: _, - r#override: _, + id: _, ty, init, - } = constant; + } = override_; validate_type(ty)?; - validate_const_expr(init)?; + if let Some(init_expr) = init { + validate_const_expr(init_expr)?; + } } for (_handle, global_variable) in global_variables.iter() { @@ -140,7 +149,8 @@ impl super::Validator { Self::validate_expression_handles( handle_and_expr, constants, - const_expressions, + overrides, + global_expressions, types, local_variables, global_variables, @@ -186,6 +196,13 @@ impl super::Validator { handle.check_valid_for(constants).map(|_| ()) } + fn validate_override_handle( + handle: Handle<crate::Override>, + overrides: &Arena<crate::Override>, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(overrides).map(|_| ()) + } + fn validate_expression_handle( handle: Handle<crate::Expression>, expressions: &Arena<crate::Expression>, @@ -203,9 +220,11 @@ impl super::Validator { fn validate_const_expression_handles( (handle, expression): (Handle<crate::Expression>, &crate::Expression), constants: &Arena<crate::Constant>, + overrides: &Arena<crate::Override>, types: &UniqueArena<crate::Type>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -214,6 +233,12 @@ impl super::Validator { validate_constant(constant)?; handle.check_dep(constants[constant].init)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + if let Some(init) = overrides[override_].init { + handle.check_dep(init)?; + } + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -230,7 +255,8 @@ impl super::Validator { fn validate_expression_handles( (handle, expression): (Handle<crate::Expression>, &crate::Expression), constants: &Arena<crate::Constant>, - const_expressions: &Arena<crate::Expression>, + overrides: &Arena<crate::Override>, + global_expressions: &Arena<crate::Expression>, types: &UniqueArena<crate::Type>, local_variables: &Arena<crate::LocalVariable>, global_variables: &Arena<crate::GlobalVariable>, @@ -239,8 +265,9 @@ impl super::Validator { current_function: Option<Handle<crate::Function>>, ) -> Result<(), InvalidHandleError> { let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_override = |handle| Self::validate_override_handle(handle, overrides); let validate_const_expr = - |handle| Self::validate_expression_handle(handle, const_expressions); + |handle| Self::validate_expression_handle(handle, global_expressions); let validate_type = |handle| Self::validate_type_handle(handle, types); match *expression { @@ -260,6 +287,9 @@ impl super::Validator { crate::Expression::Constant(constant) => { validate_constant(constant)?; } + crate::Expression::Override(override_) => { + validate_override(override_)?; + } crate::Expression::ZeroValue(ty) => { validate_type(ty)?; } @@ -390,6 +420,8 @@ impl super::Validator { } crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -535,6 +567,38 @@ impl super::Validator { } Ok(()) } + crate::Statement::SubgroupBallot { result, predicate } => { + validate_expr_opt(predicate)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + validate_expr(argument)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, + } + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill @@ -562,6 +626,7 @@ impl From<BadRangeError> for ValidationError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum InvalidHandleError { #[error(transparent)] BadHandle(#[from] BadHandle), @@ -572,6 +637,7 @@ pub enum InvalidHandleError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] #[error( "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \ which has not been processed yet" @@ -664,6 +730,7 @@ fn constant_deps() { let mut const_exprs = Arena::new(); let mut fun_exprs = Arena::new(); let mut constants = Arena::new(); + let overrides = Arena::new(); let i32_handle = types.insert( Type { @@ -679,7 +746,6 @@ fn constant_deps() { let self_referential_const = constants.append( Constant { name: None, - r#override: crate::Override::None, ty: i32_handle, init: fun_expr, }, @@ -692,6 +758,7 @@ fn constant_deps() { assert!(super::Validator::validate_const_expression_handles( handle_and_expr, &constants, + &overrides, &types, ) .is_err()); diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs index 84c8b09ddb..db890ddbac 100644 --- a/third_party/rust/naga/src/valid/interface.rs +++ b/third_party/rust/naga/src/valid/interface.rs @@ -10,6 +10,7 @@ use bit_set::BitSet; const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum GlobalVariableError { #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), @@ -30,6 +31,8 @@ pub enum GlobalVariableError { Handle<crate::Type>, #[source] Disalignment, ), + #[error("Initializer must be an override-expression")] + InitializerExprType, #[error("Initializer doesn't match the variable type")] InitializerType, #[error("Initializer can't be used with address space {0:?}")] @@ -39,6 +42,7 @@ pub enum GlobalVariableError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum VaryingError { #[error("The type {0:?} does not match the varying")] InvalidType(Handle<crate::Type>), @@ -73,9 +77,12 @@ pub enum VaryingError { location: u32, attribute: &'static str, }, + #[error("Workgroup size is multi dimensional, @builtin(subgroup_id) and @builtin(subgroup_invocation_id) are not supported.")] + InvalidMultiDimensionalSubgroupBuiltIn, } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum EntryPointError { #[error("Multiple conflicting entry points")] Conflict, @@ -135,6 +142,7 @@ struct VaryingContext<'a> { impl VaryingContext<'_> { fn validate_impl( &mut self, + ep: &crate::EntryPoint, ty: Handle<crate::Type>, binding: &crate::Binding, ) -> Result<(), VaryingError> { @@ -162,12 +170,24 @@ impl VaryingContext<'_> { Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX, Bi::ViewIndex => Capabilities::MULTIVIEW, Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING, + Bi::NumSubgroups + | Bi::SubgroupId + | Bi::SubgroupSize + | Bi::SubgroupInvocationId => Capabilities::SUBGROUP, _ => Capabilities::empty(), }; if !self.capabilities.contains(required) { return Err(VaryingError::UnsupportedCapability(required)); } + if matches!( + built_in, + crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId + ) && ep.workgroup_size[1..].iter().any(|&s| s > 1) + { + return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn); + } + let (visible, type_good) = match built_in { Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( self.stage == St::Vertex && !self.output, @@ -249,6 +269,17 @@ impl VaryingContext<'_> { scalar: crate::Scalar::U32, }, ), + Bi::NumSubgroups | Bi::SubgroupId => ( + self.stage == St::Compute && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::SubgroupSize | Bi::SubgroupInvocationId => ( + match self.stage { + St::Compute | St::Fragment => !self.output, + St::Vertex => false, + }, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), }; if !visible { @@ -349,13 +380,14 @@ impl VaryingContext<'_> { fn validate( &mut self, + ep: &crate::EntryPoint, ty: Handle<crate::Type>, binding: Option<&crate::Binding>, ) -> Result<(), WithSpan<VaryingError>> { let span_context = self.types.get_span_context(ty); match binding { Some(binding) => self - .validate_impl(ty, binding) + .validate_impl(ep, ty, binding) .map_err(|e| e.with_span_context(span_context)), None => { match self.types[ty].inner { @@ -372,7 +404,7 @@ impl VaryingContext<'_> { } } Some(ref binding) => self - .validate_impl(member.ty, binding) + .validate_impl(ep, member.ty, binding) .map_err(|e| e.with_span_context(span_context))?, } } @@ -395,6 +427,7 @@ impl super::Validator { var: &crate::GlobalVariable, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<(), GlobalVariableError> { use super::TypeFlags; @@ -523,6 +556,10 @@ impl super::Validator { } } + if !global_expr_kind.is_const_or_override(init) { + return Err(GlobalVariableError::InitializerExprType); + } + let decl_ty = &gctx.types[var.ty].inner; let init_ty = mod_info[init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -538,6 +575,7 @@ impl super::Validator { ep: &crate::EntryPoint, module: &crate::Module, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionKindTracker, ) -> Result<FunctionInfo, WithSpan<EntryPointError>> { if ep.early_depth_test.is_some() { let required = Capabilities::EARLY_DEPTH_TEST; @@ -566,7 +604,7 @@ impl super::Validator { } let mut info = self - .validate_function(&ep.function, module, mod_info, true) + .validate_function(&ep.function, module, mod_info, true, global_expr_kind) .map_err(WithSpan::into_other)?; { @@ -598,7 +636,7 @@ impl super::Validator { capabilities: self.capabilities, flags: self.flags, }; - ctx.validate(fa.ty, fa.binding.as_ref()) + ctx.validate(ep, fa.ty, fa.binding.as_ref()) .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; } @@ -616,7 +654,7 @@ impl super::Validator { capabilities: self.capabilities, flags: self.flags, }; - ctx.validate(fr.ty, fr.binding.as_ref()) + ctx.validate(ep, fr.ty, fr.binding.as_ref()) .map_err_inner(|e| EntryPointError::Result(e).with_span())?; if ctx.second_blend_source { // Only the first location may be used when dual source blending diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs index 5459434f33..a0057f39ac 100644 --- a/third_party/rust/naga/src/valid/mod.rs +++ b/third_party/rust/naga/src/valid/mod.rs @@ -12,7 +12,7 @@ mod r#type; use crate::{ arena::Handle, - proc::{LayoutError, Layouter, TypeResolution}, + proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution}, FastHashSet, }; use bit_set::BitSet; @@ -77,7 +77,7 @@ bitflags::bitflags! { #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[derive(Clone, Copy, Debug, Eq, PartialEq)] - pub struct Capabilities: u16 { + pub struct Capabilities: u32 { /// Support for [`AddressSpace:PushConstant`]. const PUSH_CONSTANT = 0x1; /// Float values with width = 8. @@ -110,6 +110,10 @@ bitflags::bitflags! { const CUBE_ARRAY_TEXTURES = 0x4000; /// Support for 64-bit signed and unsigned integers. const SHADER_INT64 = 0x8000; + /// Support for subgroup operations. + const SUBGROUP = 0x10000; + /// Support for subgroup barriers. + const SUBGROUP_BARRIER = 0x20000; } } @@ -120,6 +124,57 @@ impl Default for Capabilities { } bitflags::bitflags! { + /// Supported subgroup operations + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct SubgroupOperationSet: u8 { + /// Elect, Barrier + const BASIC = 1 << 0; + /// Any, All + const VOTE = 1 << 1; + /// reductions, scans + const ARITHMETIC = 1 << 2; + /// ballot, broadcast + const BALLOT = 1 << 3; + /// shuffle, shuffle xor + const SHUFFLE = 1 << 4; + /// shuffle up, down + const SHUFFLE_RELATIVE = 1 << 5; + // We don't support these operations yet + // /// Clustered + // const CLUSTERED = 1 << 6; + // /// Quad supported + // const QUAD_FRAGMENT_COMPUTE = 1 << 7; + // /// Quad supported in all stages + // const QUAD_ALL_STAGES = 1 << 8; + } +} + +impl super::SubgroupOperation { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::All | Self::Any => S::VOTE, + Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => { + S::ARITHMETIC + } + } + } +} + +impl super::GatherMode { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, + Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, + Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + } + } +} + +bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] @@ -131,7 +186,7 @@ bitflags::bitflags! { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { @@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo { pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, + subgroup_stages: ShaderStages, + subgroup_operations: SubgroupOperationSet, types: Vec<r#type::TypeInfo>, layouter: Layouter, location_mask: BitSet, @@ -174,10 +231,15 @@ pub struct Validator { switch_values: FastHashSet<crate::SwitchValue>, valid_expression_list: Vec<Handle<crate::Expression>>, valid_expression_set: BitSet, + override_ids: FastHashSet<u16>, + allow_overrides: bool, } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantError { + #[error("Initializer must be a const-expression")] + InitializerExprType, #[error("The type doesn't match the constant")] InvalidType, #[error("The type is not constructible")] @@ -185,6 +247,26 @@ pub enum ConstantError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum OverrideError { + #[error("Override name and ID are missing")] + MissingNameAndID, + #[error("Override ID must be unique")] + DuplicateID, + #[error("Initializer must be a const-expression or override-expression")] + InitializerExprType, + #[error("The type doesn't match the override")] + InvalidType, + #[error("The type is not constructible")] + NonConstructibleType, + #[error("The type is not a scalar")] + TypeNotScalar, + #[error("Override declarations are not allowed")] + NotAllowed, +} + +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ValidationError { #[error(transparent)] InvalidHandle(#[from] InvalidHandleError), @@ -207,6 +289,12 @@ pub enum ValidationError { name: String, source: ConstantError, }, + #[error("Override {handle:?} '{name}' is invalid")] + Override { + handle: Handle<crate::Override>, + name: String, + source: OverrideError, + }, #[error("Global variable {handle:?} '{name}' is invalid")] GlobalVariable { handle: Handle<crate::GlobalVariable>, @@ -286,6 +374,8 @@ impl Validator { Validator { flags, capabilities, + subgroup_stages: ShaderStages::empty(), + subgroup_operations: SubgroupOperationSet::empty(), types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), @@ -293,9 +383,21 @@ impl Validator { switch_values: FastHashSet::default(), valid_expression_list: Vec::new(), valid_expression_set: BitSet::new(), + override_ids: FastHashSet::default(), + allow_overrides: true, } } + pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self { + self.subgroup_stages = stages; + self + } + + pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self { + self.subgroup_operations = operations; + self + } + /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); @@ -305,6 +407,7 @@ impl Validator { self.switch_values.clear(); self.valid_expression_list.clear(); self.valid_expression_set.clear(); + self.override_ids.clear(); } fn validate_constant( @@ -312,6 +415,7 @@ impl Validator { handle: Handle<crate::Constant>, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &ExpressionKindTracker, ) -> Result<(), ConstantError> { let con = &gctx.constants[handle]; @@ -320,6 +424,10 @@ impl Validator { return Err(ConstantError::NonConstructibleType); } + if !global_expr_kind.is_const(con.init) { + return Err(ConstantError::InitializerExprType); + } + let decl_ty = &gctx.types[con.ty].inner; let init_ty = mod_info[con.init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -329,11 +437,80 @@ impl Validator { Ok(()) } + fn validate_override( + &mut self, + handle: Handle<crate::Override>, + gctx: crate::proc::GlobalCtx, + mod_info: &ModuleInfo, + ) -> Result<(), OverrideError> { + if !self.allow_overrides { + return Err(OverrideError::NotAllowed); + } + + let o = &gctx.overrides[handle]; + + if o.name.is_none() && o.id.is_none() { + return Err(OverrideError::MissingNameAndID); + } + + if let Some(id) = o.id { + if !self.override_ids.insert(id) { + return Err(OverrideError::DuplicateID); + } + } + + let type_info = &self.types[o.ty.index()]; + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(OverrideError::NonConstructibleType); + } + + let decl_ty = &gctx.types[o.ty].inner; + match decl_ty { + &crate::TypeInner::Scalar(scalar) => match scalar { + crate::Scalar::BOOL + | crate::Scalar::I32 + | crate::Scalar::U32 + | crate::Scalar::F32 + | crate::Scalar::F64 => {} + _ => return Err(OverrideError::TypeNotScalar), + }, + _ => return Err(OverrideError::TypeNotScalar), + } + + if let Some(init) = o.init { + let init_ty = mod_info[init].inner_with(gctx.types); + if !decl_ty.equivalent(init_ty, gctx.types) { + return Err(OverrideError::InvalidType); + } + } + + Ok(()) + } + /// Check the given module to be valid. pub fn validate( &mut self, module: &crate::Module, ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.allow_overrides = true; + self.validate_impl(module) + } + + /// Check the given module to be valid. + /// + /// With the additional restriction that overrides are not present. + pub fn validate_no_overrides( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { + self.allow_overrides = false; + self.validate_impl(module) + } + + fn validate_impl( + &mut self, + module: &crate::Module, + ) -> Result<ModuleInfo, WithSpan<ValidationError>> { self.reset(); self.reset_types(module.types.len()); @@ -354,7 +531,7 @@ impl Validator { type_flags: Vec::with_capacity(module.types.len()), functions: Vec::with_capacity(module.functions.len()), entry_points: Vec::with_capacity(module.entry_points.len()), - const_expression_types: vec![placeholder; module.const_expressions.len()] + const_expression_types: vec![placeholder; module.global_expressions.len()] .into_boxed_slice(), }; @@ -376,27 +553,34 @@ impl Validator { { let t = crate::Arena::new(); let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]); - for (handle, _) in module.const_expressions.iter() { + for (handle, _) in module.global_expressions.iter() { mod_info .process_const_expression(handle, &resolve_context, module.to_ctx()) .map_err(|source| { ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) + .with_span_handle(handle, &module.global_expressions) })? } } + let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions); + if self.flags.contains(ValidationFlags::CONSTANTS) { - for (handle, _) in module.const_expressions.iter() { - self.validate_const_expression(handle, module.to_ctx(), &mod_info) - .map_err(|source| { - ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) - })? + for (handle, _) in module.global_expressions.iter() { + self.validate_const_expression( + handle, + module.to_ctx(), + &mod_info, + &global_expr_kind, + ) + .map_err(|source| { + ValidationError::ConstExpression { handle, source } + .with_span_handle(handle, &module.global_expressions) + })? } for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, module.to_ctx(), &mod_info) + self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::Constant { handle, @@ -406,10 +590,22 @@ impl Validator { .with_span_handle(handle, &module.constants) })? } + + for (handle, override_) in module.overrides.iter() { + self.validate_override(handle, module.to_ctx(), &mod_info) + .map_err(|source| { + ValidationError::Override { + handle, + name: override_.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.overrides) + })? + } } for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, module.to_ctx(), &mod_info) + self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::GlobalVariable { handle: var_handle, @@ -421,7 +617,7 @@ impl Validator { } for (handle, fun) in module.functions.iter() { - match self.validate_function(fun, module, &mod_info, false) { + match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { @@ -447,7 +643,7 @@ impl Validator { .with_span()); // TODO: keep some EP span information? } - match self.validate_entry_point(ep, module, &mod_info) { + match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs index b8eb618ed4..f5b9856074 100644 --- a/third_party/rust/naga/src/valid/type.rs +++ b/third_party/rust/naga/src/valid/type.rs @@ -63,6 +63,7 @@ bitflags::bitflags! { } #[derive(Clone, Copy, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum Disalignment { #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")] ArrayStride { stride: u32, alignment: Alignment }, @@ -87,6 +88,7 @@ pub enum Disalignment { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum TypeError { #[error("Capability {0:?} is required")] MissingCapability(Capabilities), @@ -326,7 +328,6 @@ impl super::Validator { TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::COPY - | TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE | shareable, |