From 8dd16259287f58f9273002717ec4d27e97127719 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 12 Jun 2024 07:43:14 +0200 Subject: Merging upstream version 127.0. Signed-off-by: Daniel Baumann --- third_party/rust/naga/src/back/dot/mod.rs | 91 ++ third_party/rust/naga/src/back/glsl/features.rs | 23 + third_party/rust/naga/src/back/glsl/mod.rs | 148 +++- third_party/rust/naga/src/back/hlsl/conv.rs | 5 + third_party/rust/naga/src/back/hlsl/help.rs | 94 +- third_party/rust/naga/src/back/hlsl/mod.rs | 17 + third_party/rust/naga/src/back/hlsl/writer.rs | 315 ++++++- third_party/rust/naga/src/back/mod.rs | 17 + third_party/rust/naga/src/back/msl/mod.rs | 27 +- third_party/rust/naga/src/back/msl/writer.rs | 192 +++-- .../rust/naga/src/back/pipeline_constants.rs | 957 +++++++++++++++++++++ third_party/rust/naga/src/back/spv/block.rs | 32 +- third_party/rust/naga/src/back/spv/helpers.rs | 53 +- third_party/rust/naga/src/back/spv/instructions.rs | 103 +++ third_party/rust/naga/src/back/spv/mod.rs | 47 +- third_party/rust/naga/src/back/spv/subgroup.rs | 207 +++++ third_party/rust/naga/src/back/spv/writer.rs | 68 +- third_party/rust/naga/src/back/wgsl/writer.rs | 131 ++- 18 files changed, 2400 insertions(+), 127 deletions(-) create mode 100644 third_party/rust/naga/src/back/pipeline_constants.rs create mode 100644 third_party/rust/naga/src/back/spv/subgroup.rs (limited to 'third_party/rust/naga/src/back') diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs index 1556371df1..9a7702b3f6 100644 --- a/third_party/rust/naga/src/back/dot/mod.rs +++ b/third_party/rust/naga/src/back/dot/mod.rs @@ -279,6 +279,94 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.dependencies.push((id, predicate, "predicate")); + } + self.emits.push((id, result)); + "SubgroupBallot" + } + S::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupInclusiveMul", + _ => unimplemented!(), + } + } + S::SubgroupGather { + mode, + argument, + result, + } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -404,6 +492,7 @@ fn write_function_expressions( let (label, color_id) = match *expression { E::Literal(_) => ("Literal".into(), 2), E::Constant(_) => ("Constant".into(), 2), + E::Override(_) => ("Override".into(), 2), E::ZeroValue(_) => ("ZeroValue".into(), 2), E::Compose { ref components, .. } => { payload = Some(Payload::Arguments(components)); @@ -586,6 +675,8 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{}Intersection", ty).into(), 4) } + E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), + E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4), }; // give uniform expressions an outline diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs index 99c128c6d9..e5a43f3e02 100644 --- a/third_party/rust/naga/src/back/glsl/features.rs +++ b/third_party/rust/naga/src/back/glsl/features.rs @@ -50,6 +50,8 @@ bitflags::bitflags! { const INSTANCE_INDEX = 1 << 22; /// Sample specific LODs of cube / array shadow textures const TEXTURE_SHADOW_LOD = 1 << 23; + /// Subgroup operations + const SUBGROUP_OPERATIONS = 1 << 24; } } @@ -117,6 +119,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + check_feature!(SUBGROUP_OPERATIONS, 430, 310); match version { Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), @@ -259,6 +262,22 @@ impl FeaturesManager { writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?; } + if self.0.contains(Features::SUBGROUP_OPERATIONS) { + // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt + writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_arithmetic : require" + )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_shuffle_relative : require" + )?; + } + Ok(()) } } @@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::SubgroupBallotResult | + Expression::SubgroupOperationResult { .. } => { + features.request(Features::SUBGROUP_OPERATIONS) + } _ => {} } } diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs index 9bda594610..c8c7ea557d 100644 --- a/third_party/rust/naga/src/back/glsl/mod.rs +++ b/third_party/rust/naga/src/back/glsl/mod.rs @@ -282,7 +282,7 @@ impl Default for Options { } /// A subset of options meant to be changed per pipeline. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { @@ -497,6 +497,8 @@ pub enum Error { ImageMultipleSamplers, #[error("{0}")] Custom(String), + #[error("overrides should not be present at this stage")] + Override, } /// Binary operation with a different logic on the GLSL side. @@ -565,6 +567,10 @@ impl<'a, W: Write> Writer<'a, W> { pipeline_options: &'a PipelineOptions, policies: proc::BoundsCheckPolicies, ) -> Result { + if !module.overrides.is_empty() { + return Err(Error::Override); + } + // Check if the requested version is supported if !options.version.is_supported() { log::error!("Version {}", options.version); @@ -2384,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + write!(self.out, "subgroupBallot(")?; + match predicate { + Some(predicate) => self.write_expr(predicate, ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -2402,7 +2527,7 @@ impl<'a, W: Write> Writer<'a, W> { fn write_const_expr(&mut self, expr: Handle) -> BackendResult { self.write_possibly_const_expr( expr, - &self.module.const_expressions, + &self.module.global_expressions, |expr| &self.info[expr], |writer, expr| writer.write_const_expr(expr), ) @@ -2536,6 +2661,7 @@ impl<'a, W: Write> Writer<'a, W> { |writer, expr| writer.write_expr(expr, ctx), )?; } + Expression::Override(_) => return Err(Error::Override), // `Access` is applied to arrays, vectors and matrices and is written as indexing Expression::Access { base, index } => { self.write_expr(base, ctx)?; @@ -3411,7 +3537,8 @@ impl<'a, W: Write> Writer<'a, W> { let scalar_bits = ctx .resolve_type(arg, &self.module.types) .scalar_width() - .unwrap(); + .unwrap() + * 8; write!(self.out, "bitfieldExtract(")?; self.write_expr(arg, ctx)?; @@ -3430,7 +3557,8 @@ impl<'a, W: Write> Writer<'a, W> { let scalar_bits = ctx .resolve_type(arg, &self.module.types) .scalar_width() - .unwrap(); + .unwrap() + * 8; write!(self.out, "bitfieldInsert(")?; self.write_expr(arg, ctx)?; @@ -3649,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult - | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupOperationResult { .. } + | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; @@ -4218,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> { if flags.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}memoryBarrierShared();")?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; + } writeln!(self.out, "{level}barrier();")?; Ok(()) } @@ -4487,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s Bi::WorkGroupId => "gl_WorkGroupID", Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", + // subgroup + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", + Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs index 2a6db35db8..7d15f43f6c 100644 --- a/third_party/rust/naga/src/back/hlsl/conv.rs +++ b/third_party/rust/naga/src/back/hlsl/conv.rs @@ -179,6 +179,11 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", + // These builtins map to functions + Self::SubgroupSize + | Self::SubgroupInvocationId + | Self::NumSubgroups + | Self::SubgroupId => unreachable!(), Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { return Err(Error::Unimplemented(format!("builtin {self:?}"))) } diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs index 4dd9ea5987..d3bb1ce7f5 100644 --- a/third_party/rust/naga/src/back/hlsl/help.rs +++ b/third_party/rust/naga/src/back/hlsl/help.rs @@ -70,6 +70,11 @@ pub(super) struct WrappedMath { pub(super) components: Option, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +pub(super) struct WrappedZeroValue { + pub(super) ty: Handle, +} + /// HLSL backend requires its own `ImageQuery` enum. /// /// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function. @@ -359,7 +364,7 @@ impl<'a, W: Write> super::Writer<'a, W> { } /// Helper function that write wrapped function for `Expression::Compose` for structures. - pub(super) fn write_wrapped_constructor_function( + fn write_wrapped_constructor_function( &mut self, module: &crate::Module, constructor: WrappedConstructor, @@ -862,6 +867,25 @@ impl<'a, W: Write> super::Writer<'a, W> { Ok(()) } + // TODO: we could merge this with iteration in write_wrapped_compose_functions... + // + /// Helper function that writes zero value wrapped functions + pub(super) fn write_wrapped_zero_value_functions( + &mut self, + module: &crate::Module, + expressions: &crate::Arena, + ) -> BackendResult { + for (handle, _) in expressions.iter() { + if let crate::Expression::ZeroValue(ty) = expressions[handle] { + let zero_value = WrappedZeroValue { ty }; + if self.wrapped.zero_values.insert(zero_value) { + self.write_wrapped_zero_value_function(module, zero_value)?; + } + } + } + Ok(()) + } + pub(super) fn write_wrapped_math_functions( &mut self, module: &crate::Module, @@ -1006,6 +1030,7 @@ impl<'a, W: Write> super::Writer<'a, W> { ) -> BackendResult { self.write_wrapped_math_functions(module, func_ctx)?; self.write_wrapped_compose_functions(module, func_ctx.expressions)?; + self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?; for (handle, _) in func_ctx.expressions.iter() { match func_ctx.expressions[handle] { @@ -1283,4 +1308,71 @@ impl<'a, W: Write> super::Writer<'a, W> { Ok(()) } + + pub(super) fn write_wrapped_zero_value_function_name( + &mut self, + module: &crate::Module, + zero_value: WrappedZeroValue, + ) -> BackendResult { + let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?; + write!(self.out, "ZeroValue{name}")?; + Ok(()) + } + + /// Helper function that write wrapped function for `Expression::ZeroValue` + /// + /// This is necessary since we might have a member access after the zero value expression, e.g. + /// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc). + /// + /// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly. + /// + /// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle + /// cases like: + /// + /// ```ignore + /// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet + /// t_1.am = (__mat4x2[2])((float4x2[2])0); + /// ^ + /// ``` + fn write_wrapped_zero_value_function( + &mut self, + module: &crate::Module, + zero_value: WrappedZeroValue, + ) -> BackendResult { + use crate::back::INDENT; + + const RETURN_VARIABLE_NAME: &str = "ret"; + + // Write function return type and name + if let crate::TypeInner::Array { base, size, .. } = module.types[zero_value.ty].inner { + write!(self.out, "typedef ")?; + self.write_type(module, zero_value.ty)?; + write!(self.out, " ret_")?; + self.write_wrapped_zero_value_function_name(module, zero_value)?; + self.write_array_size(module, base, size)?; + writeln!(self.out, ";")?; + + write!(self.out, "ret_")?; + self.write_wrapped_zero_value_function_name(module, zero_value)?; + } else { + self.write_type(module, zero_value.ty)?; + } + write!(self.out, " ")?; + self.write_wrapped_zero_value_function_name(module, zero_value)?; + + // Write function parameters (none) and start function body + writeln!(self.out, "() {{")?; + + // Write `ZeroValue` function. + write!(self.out, "{INDENT}return ")?; + self.write_default_init(module, zero_value.ty)?; + writeln!(self.out, ";")?; + + // End of function body + writeln!(self.out, "}}")?; + // Write extra new line + writeln!(self.out)?; + + Ok(()) + } } diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs index f37a223f47..28edbf70e1 100644 --- a/third_party/rust/naga/src/back/hlsl/mod.rs +++ b/third_party/rust/naga/src/back/hlsl/mod.rs @@ -131,6 +131,13 @@ pub enum ShaderModel { V5_0, V5_1, V6_0, + V6_1, + V6_2, + V6_3, + V6_4, + V6_5, + V6_6, + V6_7, } impl ShaderModel { @@ -139,6 +146,13 @@ impl ShaderModel { Self::V5_0 => "5_0", Self::V5_1 => "5_1", Self::V6_0 => "6_0", + Self::V6_1 => "6_1", + Self::V6_2 => "6_2", + Self::V6_3 => "6_3", + Self::V6_4 => "6_4", + Self::V6_5 => "6_5", + Self::V6_6 => "6_6", + Self::V6_7 => "6_7", } } } @@ -247,10 +261,13 @@ pub enum Error { Unimplemented(String), // TODO: Error used only during development #[error("{0}")] Custom(String), + #[error("overrides should not be present at this stage")] + Override, } #[derive(Default)] struct Wrapped { + zero_values: crate::FastHashSet, array_lengths: crate::FastHashSet, image_queries: crate::FastHashSet, constructors: crate::FastHashSet, diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs index 4ba856946b..86d8f89035 100644 --- a/third_party/rust/naga/src/back/hlsl/writer.rs +++ b/third_party/rust/naga/src/back/hlsl/writer.rs @@ -1,5 +1,8 @@ use super::{ - help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess}, + help::{ + WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess, + WrappedZeroValue, + }, storage::StoreValue, BackendResult, Error, Options, }; @@ -77,6 +80,19 @@ enum Io { Output, } +const fn is_subgroup_builtin_binding(binding: &Option) -> bool { + let &Some(crate::Binding::BuiltIn(builtin)) = binding else { + return false; + }; + matches!( + builtin, + crate::BuiltIn::SubgroupSize + | crate::BuiltIn::SubgroupInvocationId + | crate::BuiltIn::NumSubgroups + | crate::BuiltIn::SubgroupId + ) +} + impl<'a, W: fmt::Write> super::Writer<'a, W> { pub fn new(out: W, options: &'a Options) -> Self { Self { @@ -161,6 +177,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } } + for statement in func.body.iter() { + match *statement { + crate::Statement::SubgroupCollectiveOperation { + op: _, + collective_op: crate::CollectiveOperation::InclusiveScan, + argument, + result: _, + } => { + self.need_bake_expressions.insert(argument); + } + _ => {} + } + } } pub fn write( @@ -168,6 +197,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { module: &Module, module_info: &valid::ModuleInfo, ) -> Result { + if !module.overrides.is_empty() { + return Err(Error::Override); + } + self.reset(module); // Write special constants, if needed @@ -233,7 +266,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_special_functions(module)?; - self.write_wrapped_compose_functions(module, &module.const_expressions)?; + self.write_wrapped_compose_functions(module, &module.global_expressions)?; + self.write_wrapped_zero_value_functions(module, &module.global_expressions)?; // Write all named constants let mut constants = module @@ -397,31 +431,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // if they are struct, so that the `stage` argument here could be omitted. fn write_semantic( &mut self, - binding: &crate::Binding, + binding: &Option, stage: Option<(ShaderStage, Io)>, ) -> BackendResult { match *binding { - crate::Binding::BuiltIn(builtin) => { + Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => { let builtin_str = builtin.to_hlsl_str()?; write!(self.out, " : {builtin_str}")?; } - crate::Binding::Location { + Some(crate::Binding::Location { second_blend_source: true, .. - } => { + }) => { write!(self.out, " : SV_Target1")?; } - crate::Binding::Location { + Some(crate::Binding::Location { location, second_blend_source: false, .. - } => { + }) => { if stage == Some((crate::ShaderStage::Fragment, Io::Output)) { write!(self.out, " : SV_Target{location}")?; } else { write!(self.out, " : {LOCATION_SEMANTIC}{location}")?; } } + _ => {} } Ok(()) @@ -442,17 +477,30 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "struct {struct_name}")?; writeln!(self.out, " {{")?; for m in members.iter() { + if is_subgroup_builtin_binding(&m.binding) { + continue; + } write!(self.out, "{}", back::INDENT)?; if let Some(ref binding) = m.binding { self.write_modifier(binding)?; } self.write_type(module, m.ty)?; write!(self.out, " {}", &m.name)?; - if let Some(ref binding) = m.binding { - self.write_semantic(binding, Some(shader_stage))?; - } + self.write_semantic(&m.binding, Some(shader_stage))?; writeln!(self.out, ";")?; } + if members.iter().any(|arg| { + matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) + ) + }) { + writeln!( + self.out, + "{}uint __local_invocation_index : SV_GroupIndex;", + back::INDENT + )?; + } writeln!(self.out, "}};")?; writeln!(self.out)?; @@ -553,8 +601,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } /// Writes special interface structures for an entry point. The special structures have - /// all the fields flattened into them and sorted by binding. They are only needed for - /// VS outputs and FS inputs, so that these interfaces match. + /// all the fields flattened into them and sorted by binding. They are needed to emulate + /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match. fn write_ep_interface( &mut self, module: &Module, @@ -563,7 +611,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { ep_name: &str, ) -> Result { Ok(EntryPointInterface { - input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment { + input: if !func.arguments.is_empty() + && (stage == ShaderStage::Fragment + || func + .arguments + .iter() + .any(|arg| is_subgroup_builtin_binding(&arg.binding))) + { Some(self.write_ep_input_struct(module, func, stage, ep_name)?) } else { None @@ -577,6 +631,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { }) } + fn write_ep_argument_initialization( + &mut self, + ep: &crate::EntryPoint, + ep_input: &EntryPointBinding, + fake_member: &EpStructMember, + ) -> BackendResult { + match fake_member.binding { + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => { + write!(self.out, "WaveGetLaneCount()")? + } + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => { + write!(self.out, "WaveGetLaneIndex()")? + } + Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!( + self.out, + "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()", + ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2] + )?, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => { + write!( + self.out, + "{}.__local_invocation_index / WaveGetLaneCount()", + ep_input.arg_name + )?; + } + _ => { + write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?; + } + } + Ok(()) + } + /// Write an entry point preface that initializes the arguments as specified in IR. fn write_ep_arguments_initialization( &mut self, @@ -584,6 +670,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { func: &crate::Function, ep_index: u16, ) -> BackendResult { + let ep = &module.entry_points[ep_index as usize]; let ep_input = match self.entry_point_io[ep_index as usize].input.take() { Some(ep_input) => ep_input, None => return Ok(()), @@ -597,8 +684,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { match module.types[arg.ty].inner { TypeInner::Array { base, size, .. } => { self.write_array_size(module, base, size)?; - let fake_member = fake_iter.next().unwrap(); - writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + write!(self.out, " = ")?; + self.write_ep_argument_initialization( + ep, + &ep_input, + fake_iter.next().unwrap(), + )?; + writeln!(self.out, ";")?; } TypeInner::Struct { ref members, .. } => { write!(self.out, " = {{ ")?; @@ -606,14 +698,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if index != 0 { write!(self.out, ", ")?; } - let fake_member = fake_iter.next().unwrap(); - write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?; + self.write_ep_argument_initialization( + ep, + &ep_input, + fake_iter.next().unwrap(), + )?; } writeln!(self.out, " }};")?; } _ => { - let fake_member = fake_iter.next().unwrap(); - writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?; + write!(self.out, " = ")?; + self.write_ep_argument_initialization( + ep, + &ep_input, + fake_iter.next().unwrap(), + )?; + writeln!(self.out, ";")?; } } } @@ -928,9 +1028,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } - if let Some(ref binding) = member.binding { - self.write_semantic(binding, shader_stage)?; - }; + self.write_semantic(&member.binding, shader_stage)?; writeln!(self.out, ";")?; } @@ -1143,7 +1241,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } back::FunctionType::EntryPoint(ep_index) => { if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input { - write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; + write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?; } else { let stage = module.entry_points[ep_index as usize].stage; for (index, arg) in func.arguments.iter().enumerate() { @@ -1160,17 +1258,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_array_size(module, base, size)?; } - if let Some(ref binding) = arg.binding { - self.write_semantic(binding, Some((stage, Io::Input)))?; - } + self.write_semantic(&arg.binding, Some((stage, Io::Input)))?; } - - if need_workgroup_variables_initialization { - if !func.arguments.is_empty() { - write!(self.out, ", ")?; - } - write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; + } + if need_workgroup_variables_initialization { + if self.entry_point_io[ep_index as usize].input.is_some() + || !func.arguments.is_empty() + { + write!(self.out, ", ")?; } + write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; } } } @@ -1180,11 +1277,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { // Write semantic if it present if let back::FunctionType::EntryPoint(index) = func_ctx.ty { let stage = module.entry_points[index as usize].stage; - if let Some(crate::FunctionResult { - binding: Some(ref binding), - .. - }) = func.result - { + if let Some(crate::FunctionResult { ref binding, .. }) = func.result { self.write_semantic(binding, Some((stage, Io::Output)))?; } } @@ -1984,6 +2077,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "const uint4 {name} = ")?; + self.named_expressions.insert(result, name); + + write!(self.out, "WaveActiveBallot(")?; + match predicate { + Some(predicate) => self.write_expr(module, predicate, func_ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -1997,7 +2213,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_possibly_const_expression( module, expr, - &module.const_expressions, + &module.global_expressions, |writer, expr| writer.write_const_expression(module, expr), ) } @@ -2039,7 +2255,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_const_expression(module, constant.init)?; } } - Expression::ZeroValue(ty) => self.write_default_init(module, ty)?, + Expression::ZeroValue(ty) => { + self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?; + write!(self.out, "()")?; + } Expression::Compose { ty, ref components } => { match module.types[ty].inner { TypeInner::Struct { .. } | TypeInner::Array { .. } => { @@ -2140,6 +2359,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => return Err(Error::Override), // All of the multiplication can be expressed as `mul`, // except vector * vector, which needs to use the "*" operator. Expression::Binary { @@ -2588,7 +2808,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { true } None => { - if inner.scalar_width() == Some(64) { + if inner.scalar_width() == Some(8) { false } else { write!(self.out, "{}(", kind.to_hlsl_cast(),)?; @@ -3129,7 +3349,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } - | Expression::RayQueryProceedResult => {} + | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } => {} } if !closing_bracket.is_empty() { @@ -3179,7 +3401,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } /// Helper function that write default zero initialization - fn write_default_init(&mut self, module: &Module, ty: Handle) -> BackendResult { + pub(super) fn write_default_init( + &mut self, + module: &Module, + ty: Handle, + ) -> BackendResult { write!(self.out, "(")?; self.write_type(module, ty)?; if let TypeInner::Array { base, size, .. } = module.types[ty].inner { @@ -3196,6 +3422,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } + if barrier.contains(crate::Barrier::SUB_GROUP) { + // Does not exist in DirectX + } Ok(()) } } diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs index c8f091decb..0c9c5e4761 100644 --- a/third_party/rust/naga/src/back/mod.rs +++ b/third_party/rust/naga/src/back/mod.rs @@ -16,6 +16,14 @@ pub mod spv; #[cfg(feature = "wgsl-out")] pub mod wgsl; +#[cfg(any( + feature = "hlsl-out", + feature = "msl-out", + feature = "spv-out", + feature = "glsl-out" +))] +pub mod pipeline_constants; + /// Names of vector components. pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; /// Indent for backends. @@ -26,6 +34,15 @@ pub const BAKE_PREFIX: &str = "_e"; /// Expressions that need baking. pub type NeedBakeExpressions = crate::FastHashSet>; +/// Specifies the values of pipeline-overridable constants in the shader module. +/// +/// If an `@id` attribute was specified on the declaration, +/// the key must be the pipeline constant ID as a decimal ASCII number; if not, +/// the key must be the constant's identifier name. +/// +/// The value may represent any of WGSL's concrete scalar types. +pub type PipelineConstants = std::collections::HashMap; + /// Indentation level. #[derive(Clone, Copy)] pub struct Level(pub usize); diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs index 68e5b79906..8b03e20376 100644 --- a/third_party/rust/naga/src/back/msl/mod.rs +++ b/third_party/rust/naga/src/back/msl/mod.rs @@ -143,6 +143,8 @@ pub enum Error { UnsupportedArrayOfType(Handle), #[error("ray tracing is not supported prior to MSL 2.3")] UnsupportedRayTracing, + #[error("overrides should not be present at this stage")] + Override, } #[derive(Clone, Debug, PartialEq, thiserror::Error)] @@ -221,7 +223,7 @@ impl Default for Options { } /// A subset of options that are meant to be changed per pipeline. -#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { @@ -434,6 +436,11 @@ impl ResolvedBinding { Bi::WorkGroupId => "threadgroup_position_in_grid", Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", + // subgroup + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } @@ -536,3 +543,21 @@ fn test_error_size() { use std::mem::size_of; assert_eq!(size_of::(), 32); } + +impl crate::AtomicFunction { + fn to_msl(self) -> Result<&'static str, Error> { + Ok(match self { + Self::Add => "fetch_add", + Self::Subtract => "fetch_sub", + Self::And => "fetch_and", + Self::InclusiveOr => "fetch_or", + Self::ExclusiveOr => "fetch_xor", + Self::Min => "fetch_min", + Self::Max => "fetch_max", + Self::Exchange { compare: None } => "exchange", + Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented( + "atomic CompareExchange".to_string(), + ))?, + }) + } +} diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs index 5227d8e7db..e250d0b72c 100644 --- a/third_party/rust/naga/src/back/msl/writer.rs +++ b/third_party/rust/naga/src/back/msl/writer.rs @@ -1131,21 +1131,10 @@ impl Writer { Ok(()) } - fn put_atomic_fetch( - &mut self, - pointer: Handle, - key: &str, - value: Handle, - context: &ExpressionContext, - ) -> BackendResult { - self.put_atomic_operation(pointer, "fetch_", key, value, context) - } - fn put_atomic_operation( &mut self, pointer: Handle, - key1: &str, - key2: &str, + key: &str, value: Handle, context: &ExpressionContext, ) -> BackendResult { @@ -1163,7 +1152,7 @@ impl Writer { write!( self.out, - "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}" + "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}" )?; self.put_access_chain(pointer, policy, context)?; write!(self.out, ", ")?; @@ -1248,7 +1237,7 @@ impl Writer { ) -> BackendResult { self.put_possibly_const_expression( expr_handle, - &module.const_expressions, + &module.global_expressions, module, mod_info, &(module, mod_info), @@ -1431,6 +1420,7 @@ impl Writer { |writer, context, expr| writer.put_expression(expr, context, true), )?; } + crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { // This is an acceptable place to generate a `ReadZeroSkipWrite` check. @@ -1944,7 +1934,7 @@ impl Writer { // // extract_bits(e, min(offset, w), min(count, w - min(offset, w)))) - let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; write!(self.out, "{NAMESPACE}::extract_bits(")?; self.put_expression(arg, context, true)?; @@ -1960,7 +1950,7 @@ impl Writer { // // insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w)))) - let scalar_bits = context.resolve_type(arg).scalar_width().unwrap(); + let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8; write!(self.out, "{NAMESPACE}::insert_bits(")?; self.put_expression(arg, context, true)?; @@ -2041,6 +2031,8 @@ impl Writer { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -2994,43 +2986,8 @@ impl Writer { let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); self.start_baking_expression(result, &context.expression, &res_name)?; self.named_expressions.insert(result, res_name); - match *fun { - crate::AtomicFunction::Add => { - self.put_atomic_fetch(pointer, "add", value, &context.expression)?; - } - crate::AtomicFunction::Subtract => { - self.put_atomic_fetch(pointer, "sub", value, &context.expression)?; - } - crate::AtomicFunction::And => { - self.put_atomic_fetch(pointer, "and", value, &context.expression)?; - } - crate::AtomicFunction::InclusiveOr => { - self.put_atomic_fetch(pointer, "or", value, &context.expression)?; - } - crate::AtomicFunction::ExclusiveOr => { - self.put_atomic_fetch(pointer, "xor", value, &context.expression)?; - } - crate::AtomicFunction::Min => { - self.put_atomic_fetch(pointer, "min", value, &context.expression)?; - } - crate::AtomicFunction::Max => { - self.put_atomic_fetch(pointer, "max", value, &context.expression)?; - } - crate::AtomicFunction::Exchange { compare: None } => { - self.put_atomic_operation( - pointer, - "exchange", - "", - value, - &context.expression, - )?; - } - crate::AtomicFunction::Exchange { .. } => { - return Err(Error::FeatureNotImplemented( - "atomic CompareExchange".to_string(), - )); - } - } + let fun_str = fun.to_msl()?; + self.put_atomic_operation(pointer, fun_str, value, &context.expression)?; // done writeln!(self.out, ";")?; } @@ -3144,6 +3101,121 @@ impl Writer { } } } + crate::Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; + } + writeln!(self.out, "), 0, 0, 0);")?; + } + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; + } } } @@ -3220,6 +3292,10 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { + if !module.overrides.is_empty() { + return Err(Error::Override); + } + self.names.clear(); self.namer.reset( module, @@ -4487,6 +4563,12 @@ impl Writer { "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; + } Ok(()) } } @@ -4757,8 +4839,8 @@ fn test_stack_size() { } let stack_size = addresses_end - addresses_start; // check the size (in debug only) - // last observed macOS value: 19152 (CI) - if !(9000..=20000).contains(&stack_size) { + // last observed macOS value: 22256 (CI) + if !(15000..=25000).contains(&stack_size) { panic!("`put_block` stack size {stack_size} has changed!"); } } diff --git a/third_party/rust/naga/src/back/pipeline_constants.rs b/third_party/rust/naga/src/back/pipeline_constants.rs new file mode 100644 index 0000000000..0dbe9cf4e8 --- /dev/null +++ b/third_party/rust/naga/src/back/pipeline_constants.rs @@ -0,0 +1,957 @@ +use super::PipelineConstants; +use crate::{ + proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter}, + valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, + Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar, + Span, Statement, TypeInner, WithSpan, +}; +use std::{borrow::Cow, collections::HashSet, mem}; +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum PipelineConstantError { + #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")] + MissingValue(String), + #[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")] + SrcNeedsToBeFinite, + #[error("Source f64 value doesn't fit in destination")] + DstRangeTooSmall, + #[error(transparent)] + ConstantEvaluatorError(#[from] ConstantEvaluatorError), + #[error(transparent)] + ValidationError(#[from] WithSpan), +} + +/// Replace all overrides in `module` with constants. +/// +/// If no changes are needed, this just returns `Cow::Borrowed` +/// references to `module` and `module_info`. Otherwise, it clones +/// `module`, edits its [`global_expressions`] arena to contain only +/// fully-evaluated expressions, and returns `Cow::Owned` values +/// holding the simplified module and its validation results. +/// +/// In either case, the module returned has an empty `overrides` +/// arena, and the `global_expressions` arena contains only +/// fully-evaluated expressions. +/// +/// [`global_expressions`]: Module::global_expressions +pub fn process_overrides<'a>( + module: &'a Module, + module_info: &'a ModuleInfo, + pipeline_constants: &PipelineConstants, +) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> { + if module.overrides.is_empty() { + return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info))); + } + + let mut module = module.clone(); + + // A map from override handles to the handles of the constants + // we've replaced them with. + let mut override_map = Vec::with_capacity(module.overrides.len()); + + // A map from `module`'s original global expression handles to + // handles in the new, simplified global expression arena. + let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len()); + + // The set of constants whose initializer handles we've already + // updated to refer to the newly built global expression arena. + // + // All constants in `module` must have their `init` handles + // updated to point into the new, simplified global expression + // arena. Some of these we can most easily handle as a side effect + // during the simplification process, but we must handle the rest + // in a final fixup pass, guided by `adjusted_global_expressions`. We + // add their handles to this set, so that the final fixup step can + // leave them alone. + let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); + + let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + + // An iterator through the original overrides table, consumed in + // approximate tandem with the global expressions. + let mut override_iter = module.overrides.drain(); + + // Do two things in tandem: + // + // - Rebuild the global expression arena from scratch, fully + // evaluating all expressions, and replacing each `Override` + // expression in `module.global_expressions` with a `Constant` + // expression. + // + // - Build a new `Constant` in `module.constants` to take the + // place of each `Override`. + // + // Build a map from old global expression handles to their + // fully-evaluated counterparts in `adjusted_global_expressions` as we + // go. + // + // Why in tandem? Overrides refer to expressions, and expressions + // refer to overrides, so we can't disentangle the two into + // separate phases. However, we can take advantage of the fact + // that the overrides and expressions must form a DAG, and work + // our way from the leaves to the roots, replacing and evaluating + // as we go. + // + // Although the two loops are nested, this is really two + // alternating phases: we adjust and evaluate constant expressions + // until we hit an `Override` expression, at which point we switch + // to building `Constant`s for `Overrides` until we've handled the + // one used by the expression. Then we switch back to processing + // expressions. Because we know they form a DAG, we know the + // `Override` expressions we encounter can only have initializers + // referring to global expressions we've already simplified. + for (old_h, expr, span) in module.global_expressions.drain() { + let mut expr = match expr { + Expression::Override(h) => { + let c_h = if let Some(new_h) = override_map.get(h.index()) { + *new_h + } else { + let mut new_h = None; + for entry in override_iter.by_ref() { + let stop = entry.0 == h; + new_h = Some(process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_global_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?); + if stop { + break; + } + } + new_h.unwrap() + }; + Expression::Constant(c_h) + } + Expression::Constant(c_h) => { + if adjusted_constant_initializers.insert(c_h) { + let init = &mut module.constants[c_h].init; + *init = adjusted_global_expressions[init.index()]; + } + expr + } + expr => expr, + }; + let mut evaluator = ConstantEvaluator::for_wgsl_module( + &mut module, + &mut global_expression_kind_tracker, + false, + ); + adjust_expr(&adjusted_global_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_global_expressions.len()); + adjusted_global_expressions.push(h); + } + + // Finish processing any overrides we didn't visit in the loop above. + for entry in override_iter { + process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_global_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?; + } + + // Update the initialization expression handles of all `Constant`s + // and `GlobalVariable`s. Skip `Constant`s we'd already updated en + // passant. + for (_, c) in module + .constants + .iter_mut() + .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h)) + { + c.init = adjusted_global_expressions[c.init.index()]; + } + + for (_, v) in module.global_variables.iter_mut() { + if let Some(ref mut init) = v.init { + *init = adjusted_global_expressions[init.index()]; + } + } + + let mut functions = mem::take(&mut module.functions); + for (_, function) in functions.iter_mut() { + process_function(&mut module, &override_map, function)?; + } + module.functions = functions; + + let mut entry_points = mem::take(&mut module.entry_points); + for ep in entry_points.iter_mut() { + process_function(&mut module, &override_map, &mut ep.function)?; + } + module.entry_points = entry_points; + + // Now that we've rewritten all the expressions, we need to + // recompute their types and other metadata. For the time being, + // do a full re-validation. + let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); + let module_info = validator.validate_no_overrides(&module)?; + + Ok((Cow::Owned(module), Cow::Owned(module_info))) +} + +/// Add a [`Constant`] to `module` for the override `old_h`. +/// +/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. +fn process_override( + (old_h, override_, span): (Handle, Override, Span), + pipeline_constants: &PipelineConstants, + module: &mut Module, + override_map: &mut Vec>, + adjusted_global_expressions: &[Handle], + adjusted_constant_initializers: &mut HashSet>, + global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker, +) -> Result, PipelineConstantError> { + // Determine which key to use for `override_` in `pipeline_constants`. + let key = if let Some(id) = override_.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = override_.name { + Cow::Borrowed(name) + } else { + unreachable!(); + }; + + // Generate a global expression for `override_`'s value, either + // from the provided `pipeline_constants` table or its initializer + // in the module. + let init = if let Some(value) = pipeline_constants.get::(&key) { + let literal = match module.types[override_.ty].inner { + TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, + _ => unreachable!(), + }; + let expr = module + .global_expressions + .append(Expression::Literal(literal), Span::UNDEFINED); + global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); + expr + } else if let Some(init) = override_.init { + adjusted_global_expressions[init.index()] + } else { + return Err(PipelineConstantError::MissingValue(key.to_string())); + }; + + // Generate a new `Constant` to represent the override's value. + let constant = Constant { + name: override_.name, + ty: override_.ty, + init, + }; + let h = module.constants.append(constant, span); + debug_assert_eq!(old_h.index(), override_map.len()); + override_map.push(h); + adjusted_constant_initializers.insert(h); + Ok(h) +} + +/// Replace all override expressions in `function` with fully-evaluated constants. +/// +/// Replace all `Expression::Override`s in `function`'s expression arena with +/// the corresponding `Expression::Constant`s, as given in `override_map`. +/// Replace any expressions whose values are now known with their fully +/// evaluated form. +/// +/// If `h` is a `Handle`, then `override_map[h.index()]` is the +/// `Handle` for the override's final value. +fn process_function( + module: &mut Module, + override_map: &[Handle], + function: &mut Function, +) -> Result<(), ConstantEvaluatorError> { + // A map from original local expression handles to + // handles in the new, local expression arena. + let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len()); + + let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new(); + + let mut expressions = mem::take(&mut function.expressions); + + // Dummy `emitter` and `block` for the constant evaluator. + // We can ignore the concept of emitting expressions here since + // expressions have already been covered by a `Statement::Emit` + // in the frontend. + // The only thing we might have to do is remove some expressions + // that have been covered by a `Statement::Emit`. See the docs of + // `filter_emits_in_block` for the reasoning. + let mut emitter = Emitter::default(); + let mut block = Block::new(); + + let mut evaluator = ConstantEvaluator::for_wgsl_function( + module, + &mut function.expressions, + &mut local_expression_kind_tracker, + &mut emitter, + &mut block, + ); + + for (old_h, mut expr, span) in expressions.drain() { + if let Expression::Override(h) = expr { + expr = Expression::Constant(override_map[h.index()]); + } + adjust_expr(&adjusted_local_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_local_expressions.len()); + adjusted_local_expressions.push(h); + } + + adjust_block(&adjusted_local_expressions, &mut function.body); + + filter_emits_in_block(&mut function.body, &function.expressions); + + // Update local expression initializers. + for (_, local) in function.local_variables.iter_mut() { + if let &mut Some(ref mut init) = &mut local.init { + *init = adjusted_local_expressions[init.index()]; + } + } + + // We've changed the keys of `function.named_expression`, so we have to + // rebuild it from scratch. + let named_expressions = mem::take(&mut function.named_expressions); + for (expr_h, name) in named_expressions { + function + .named_expressions + .insert(adjusted_local_expressions[expr_h.index()], name); + } + + Ok(()) +} + +/// Replace every expression handle in `expr` with its counterpart +/// given by `new_pos`. +fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { + let adjust = |expr: &mut Handle| { + *expr = new_pos[expr.index()]; + }; + match *expr { + Expression::Compose { + ref mut components, + ty: _, + } => { + for c in components.iter_mut() { + adjust(c); + } + } + Expression::Access { + ref mut base, + ref mut index, + } => { + adjust(base); + adjust(index); + } + Expression::AccessIndex { + ref mut base, + index: _, + } => { + adjust(base); + } + Expression::Splat { + ref mut value, + size: _, + } => { + adjust(value); + } + Expression::Swizzle { + ref mut vector, + size: _, + pattern: _, + } => { + adjust(vector); + } + Expression::Load { ref mut pointer } => { + adjust(pointer); + } + Expression::ImageSample { + ref mut image, + ref mut sampler, + ref mut coordinate, + ref mut array_index, + ref mut offset, + ref mut level, + ref mut depth_ref, + gather: _, + } => { + adjust(image); + adjust(sampler); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = offset.as_mut() { + adjust(e); + } + match *level { + crate::SampleLevel::Exact(ref mut expr) + | crate::SampleLevel::Bias(ref mut expr) => { + adjust(expr); + } + crate::SampleLevel::Gradient { + ref mut x, + ref mut y, + } => { + adjust(x); + adjust(y); + } + _ => {} + } + if let Some(e) = depth_ref.as_mut() { + adjust(e); + } + } + Expression::ImageLoad { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut sample, + ref mut level, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = sample.as_mut() { + adjust(e); + } + if let Some(e) = level.as_mut() { + adjust(e); + } + } + Expression::ImageQuery { + ref mut image, + ref mut query, + } => { + adjust(image); + match *query { + crate::ImageQuery::Size { ref mut level } => { + if let Some(e) = level.as_mut() { + adjust(e); + } + } + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => {} + } + } + Expression::Unary { + ref mut expr, + op: _, + } => { + adjust(expr); + } + Expression::Binary { + ref mut left, + ref mut right, + op: _, + } => { + adjust(left); + adjust(right); + } + Expression::Select { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust(accept); + adjust(reject); + } + Expression::Derivative { + ref mut expr, + axis: _, + ctrl: _, + } => { + adjust(expr); + } + Expression::Relational { + ref mut argument, + fun: _, + } => { + adjust(argument); + } + Expression::Math { + ref mut arg, + ref mut arg1, + ref mut arg2, + ref mut arg3, + fun: _, + } => { + adjust(arg); + if let Some(e) = arg1.as_mut() { + adjust(e); + } + if let Some(e) = arg2.as_mut() { + adjust(e); + } + if let Some(e) = arg3.as_mut() { + adjust(e); + } + } + Expression::As { + ref mut expr, + kind: _, + convert: _, + } => { + adjust(expr); + } + Expression::ArrayLength(ref mut expr) => { + adjust(expr); + } + Expression::RayQueryGetIntersection { + ref mut query, + committed: _, + } => { + adjust(query); + } + Expression::Literal(_) + | Expression::FunctionArgument(_) + | Expression::GlobalVariable(_) + | Expression::LocalVariable(_) + | Expression::CallResult(_) + | Expression::RayQueryProceedResult + | Expression::Constant(_) + | Expression::Override(_) + | Expression::ZeroValue(_) + | Expression::AtomicResult { + ty: _, + comparison: _, + } + | Expression::WorkGroupUniformLoadResult { ty: _ } + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } => {} + } +} + +/// Replace every expression handle in `block` with its counterpart +/// given by `new_pos`. +fn adjust_block(new_pos: &[Handle], block: &mut Block) { + for stmt in block.iter_mut() { + adjust_stmt(new_pos, stmt); + } +} + +/// Replace every expression handle in `stmt` with its counterpart +/// given by `new_pos`. +fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { + let adjust = |expr: &mut Handle| { + *expr = new_pos[expr.index()]; + }; + match *stmt { + Statement::Emit(ref mut range) => { + if let Some((mut first, mut last)) = range.first_and_last() { + adjust(&mut first); + adjust(&mut last); + *range = Range::new_from_bounds(first, last); + } + } + Statement::Block(ref mut block) => { + adjust_block(new_pos, block); + } + Statement::If { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust_block(new_pos, accept); + adjust_block(new_pos, reject); + } + Statement::Switch { + ref mut selector, + ref mut cases, + } => { + adjust(selector); + for case in cases.iter_mut() { + adjust_block(new_pos, &mut case.body); + } + } + Statement::Loop { + ref mut body, + ref mut continuing, + ref mut break_if, + } => { + adjust_block(new_pos, body); + adjust_block(new_pos, continuing); + if let Some(e) = break_if.as_mut() { + adjust(e); + } + } + Statement::Return { ref mut value } => { + if let Some(e) = value.as_mut() { + adjust(e); + } + } + Statement::Store { + ref mut pointer, + ref mut value, + } => { + adjust(pointer); + adjust(value); + } + Statement::ImageStore { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut value, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + adjust(value); + } + crate::Statement::Atomic { + ref mut pointer, + ref mut value, + ref mut result, + ref mut fun, + } => { + adjust(pointer); + adjust(value); + adjust(result); + match *fun { + crate::AtomicFunction::Exchange { + compare: Some(ref mut compare), + } => { + adjust(compare); + } + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max + | crate::AtomicFunction::Exchange { compare: None } => {} + } + } + Statement::WorkGroupUniformLoad { + ref mut pointer, + ref mut result, + } => { + adjust(pointer); + adjust(result); + } + Statement::SubgroupBallot { + ref mut result, + ref mut predicate, + } => { + if let Some(ref mut predicate) = *predicate { + adjust(predicate); + } + adjust(result); + } + Statement::SubgroupCollectiveOperation { + ref mut argument, + ref mut result, + .. + } => { + adjust(argument); + adjust(result); + } + Statement::SubgroupGather { + ref mut mode, + ref mut argument, + ref mut result, + } => { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) + | crate::GatherMode::Shuffle(ref mut index) + | crate::GatherMode::ShuffleDown(ref mut index) + | crate::GatherMode::ShuffleUp(ref mut index) + | crate::GatherMode::ShuffleXor(ref mut index) => { + adjust(index); + } + } + adjust(argument); + adjust(result) + } + Statement::Call { + ref mut arguments, + ref mut result, + function: _, + } => { + for argument in arguments.iter_mut() { + adjust(argument); + } + if let Some(e) = result.as_mut() { + adjust(e); + } + } + Statement::RayQuery { + ref mut query, + ref mut fun, + } => { + adjust(query); + match *fun { + crate::RayQueryFunction::Initialize { + ref mut acceleration_structure, + ref mut descriptor, + } => { + adjust(acceleration_structure); + adjust(descriptor); + } + crate::RayQueryFunction::Proceed { ref mut result } => { + adjust(result); + } + crate::RayQueryFunction::Terminate => {} + } + } + Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {} + } +} + +/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced. +/// +/// According to validation, [`Emit`] statements must not cover any expressions +/// for which [`Expression::needs_pre_emit`] returns true. All expressions built +/// by successful constant evaluation fall into that category, meaning that +/// `process_function` will usually rewrite [`Override`] expressions and those +/// that use their values into pre-emitted expressions, leaving any [`Emit`] +/// statements that cover them invalid. +/// +/// This function rewrites all [`Emit`] statements into zero or more new +/// [`Emit`] statements covering only those expressions in the original range +/// that are not pre-emitted. +/// +/// [`Emit`]: Statement::Emit +/// [`needs_pre_emit`]: Expression::needs_pre_emit +/// [`Override`]: Expression::Override +fn filter_emits_in_block(block: &mut Block, expressions: &Arena) { + let original = std::mem::replace(block, Block::with_capacity(block.len())); + for (stmt, span) in original.span_into_iter() { + match stmt { + Statement::Emit(range) => { + let mut current = None; + for expr_h in range { + if expressions[expr_h].needs_pre_emit() { + if let Some((first, last)) = current { + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); + } + + current = None; + } else if let Some((_, ref mut last)) = current { + *last = expr_h; + } else { + current = Some((expr_h, expr_h)); + } + } + if let Some((first, last)) = current { + block.push(Statement::Emit(Range::new_from_bounds(first, last)), span); + } + } + Statement::Block(mut child) => { + filter_emits_in_block(&mut child, expressions); + block.push(Statement::Block(child), span); + } + Statement::If { + condition, + mut accept, + mut reject, + } => { + filter_emits_in_block(&mut accept, expressions); + filter_emits_in_block(&mut reject, expressions); + block.push( + Statement::If { + condition, + accept, + reject, + }, + span, + ); + } + Statement::Switch { + selector, + mut cases, + } => { + for case in &mut cases { + filter_emits_in_block(&mut case.body, expressions); + } + block.push(Statement::Switch { selector, cases }, span); + } + Statement::Loop { + mut body, + mut continuing, + break_if, + } => { + filter_emits_in_block(&mut body, expressions); + filter_emits_in_block(&mut continuing, expressions); + block.push( + Statement::Loop { + body, + continuing, + break_if, + }, + span, + ); + } + stmt => block.push(stmt.clone(), span), + } + } +} + +fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { + // note that in rust 0.0 == -0.0 + match scalar { + Scalar::BOOL => { + // https://webidl.spec.whatwg.org/#js-boolean + let value = value != 0.0 && !value.is_nan(); + Ok(Literal::Bool(value)) + } + Scalar::I32 => { + // https://webidl.spec.whatwg.org/#js-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as i32; + Ok(Literal::I32(value)) + } + Scalar::U32 => { + // https://webidl.spec.whatwg.org/#js-unsigned-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as u32; + Ok(Literal::U32(value)) + } + Scalar::F32 => { + // https://webidl.spec.whatwg.org/#js-float + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value as f32; + if !value.is_finite() { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + Ok(Literal::F32(value)) + } + Scalar::F64 => { + // https://webidl.spec.whatwg.org/#js-double + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + Ok(Literal::F64(value)) + } + _ => unreachable!(), + } +} + +#[test] +fn test_map_value_to_literal() { + let bool_test_cases = [ + (0.0, false), + (-0.0, false), + (f64::NAN, false), + (1.0, true), + (f64::INFINITY, true), + (f64::NEG_INFINITY, true), + ]; + for (value, out) in bool_test_cases { + let res = Ok(Literal::Bool(out)); + assert_eq!(map_value_to_literal(value, Scalar::BOOL), res); + } + + for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] { + for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + let res = Err(PipelineConstantError::SrcNeedsToBeFinite); + assert_eq!(map_value_to_literal(value, scalar), res); + } + } + + // i32 + assert_eq!( + map_value_to_literal(f64::from(i32::MIN), Scalar::I32), + Ok(Literal::I32(i32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX), Scalar::I32), + Ok(Literal::I32(i32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // u32 + assert_eq!( + map_value_to_literal(f64::from(u32::MIN), Scalar::U32), + Ok(Literal::U32(u32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX), Scalar::U32), + Ok(Literal::U32(u32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f32 + assert_eq!( + map_value_to_literal(f64::from(f32::MIN), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(f32::MAX), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f64 + assert_eq!( + map_value_to_literal(f64::MIN, Scalar::F64), + Ok(Literal::F64(f64::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::MAX, Scalar::F64), + Ok(Literal::F64(f64::MAX)) + ); +} diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs index 81f2fc10e0..120d60fc40 100644 --- a/third_party/rust/naga/src/back/spv/block.rs +++ b/third_party/rust/naga/src/back/spv/block.rs @@ -239,6 +239,7 @@ impl<'w> BlockContext<'w> { let init = self.ir_module.constants[handle].init; self.writer.constant_ids[init.index()] } + crate::Expression::Override(_) => return Err(Error::Override), crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id), crate::Expression::Compose { ty, ref components } => { self.temp_list.clear(); @@ -1072,7 +1073,7 @@ impl<'w> BlockContext<'w> { // // bitfieldExtract(x, o, c) - let bit_width = arg_ty.scalar_width().unwrap(); + let bit_width = arg_ty.scalar_width().unwrap() * 8; let width_constant = self .writer .get_constant_scalar(crate::Literal::U32(bit_width as u32)); @@ -1128,7 +1129,7 @@ impl<'w> BlockContext<'w> { Mf::InsertBits => { // The behavior of InsertBits has the same undefined behavior as ExtractBits. - let bit_width = arg_ty.scalar_width().unwrap(); + let bit_width = arg_ty.scalar_width().unwrap() * 8; let width_constant = self .writer .get_constant_scalar(crate::Literal::U32(bit_width as u32)); @@ -1184,7 +1185,7 @@ impl<'w> BlockContext<'w> { } Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb), Mf::FindMsb => { - if arg_ty.scalar_width() == Some(32) { + if arg_ty.scalar_width() == Some(4) { let thing = match arg_scalar_kind { Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, @@ -1278,7 +1279,9 @@ impl<'w> BlockContext<'w> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } - | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2489,6 +2492,27 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; + } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; + } + crate::Statement::SubgroupGather { + ref mode, + argument, + result, + } => { + self.write_subgroup_gather(mode, argument, result, &mut block)?; + } } } diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs index 5b6226db85..1fb447e384 100644 --- a/third_party/rust/naga/src/back/spv/helpers.rs +++ b/third_party/rust/naga/src/back/spv/helpers.rs @@ -10,8 +10,12 @@ pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec { pub(super) fn string_to_words(input: &str) -> Vec { let bytes = input.as_bytes(); - let mut words = bytes_to_words(bytes); + str_bytes_to_words(bytes) +} + +pub(super) fn str_bytes_to_words(bytes: &[u8]) -> Vec { + let mut words = bytes_to_words(bytes); if bytes.len() % 4 == 0 { // nul-termination words.push(0x0u32); @@ -20,6 +24,21 @@ pub(super) fn string_to_words(input: &str) -> Vec { words } +/// split a string into chunks and keep utf8 valid +#[allow(unstable_name_collisions)] +pub(super) fn string_to_byte_chunks(input: &str, limit: usize) -> Vec<&[u8]> { + let mut offset: usize = 0; + let mut start: usize = 0; + let mut words = vec![]; + while offset < input.len() { + offset = input.floor_char_boundary(offset + limit); + words.push(input[start..offset].as_bytes()); + start = offset; + } + + words +} + pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass { match space { crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant, @@ -107,3 +126,35 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab _ => true, } } + +///HACK: this is taken from std unstable, remove it when std's floor_char_boundary is stable +trait U8Internal { + fn is_utf8_char_boundary(&self) -> bool; +} + +impl U8Internal for u8 { + fn is_utf8_char_boundary(&self) -> bool { + // This is bit magic equivalent to: b < 128 || b >= 192 + (*self as i8) >= -0x40 + } +} + +trait StrUnstable { + fn floor_char_boundary(&self, index: usize) -> usize; +} + +impl StrUnstable for str { + fn floor_char_boundary(&self, index: usize) -> usize { + if index >= self.len() { + self.len() + } else { + let lower_bound = index.saturating_sub(3); + let new_index = self.as_bytes()[lower_bound..=index] + .iter() + .rposition(|b| b.is_utf8_char_boundary()); + + // SAFETY: we know that the character boundary will be within four bytes + unsafe { lower_bound + new_index.unwrap_unchecked() } + } + } +} diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs index b963793ad3..df2774ab9c 100644 --- a/third_party/rust/naga/src/back/spv/instructions.rs +++ b/third_party/rust/naga/src/back/spv/instructions.rs @@ -43,6 +43,42 @@ impl super::Instruction { instruction } + pub(super) fn source_continued(source: &[u8]) -> Self { + let mut instruction = Self::new(Op::SourceContinued); + instruction.add_operands(helpers::str_bytes_to_words(source)); + instruction + } + + pub(super) fn source_auto_continued( + source_language: spirv::SourceLanguage, + version: u32, + source: &Option, + ) -> Vec { + let mut instructions = vec![]; + + let with_continue = source.as_ref().and_then(|debug_info| { + (debug_info.source_code.len() > u16::MAX as usize).then_some(debug_info) + }); + if let Some(debug_info) = with_continue { + let mut instruction = Self::new(Op::Source); + instruction.add_operand(source_language as u32); + instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes())); + + let words = helpers::string_to_byte_chunks(debug_info.source_code, u16::MAX as usize); + instruction.add_operand(debug_info.source_file_id); + instruction.add_operands(helpers::str_bytes_to_words(words[0])); + instructions.push(instruction); + for word_bytes in words[1..].iter() { + let instruction_continue = Self::source_continued(word_bytes); + instructions.push(instruction_continue); + } + } else { + let instruction = Self::source(source_language, version, source); + instructions.push(instruction); + } + instructions + } + pub(super) fn name(target_id: Word, name: &str) -> Self { let mut instruction = Self::new(Op::Name); instruction.add_operand(target_id); @@ -1037,6 +1073,73 @@ impl super::Instruction { instruction.add_operand(semantics_id); instruction } + + // Group Instructions + + pub(super) fn group_non_uniform_ballot( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + predicate: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBallot); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(predicate); + + instruction + } + pub(super) fn group_non_uniform_broadcast_first( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + + instruction + } + pub(super) fn group_non_uniform_gather( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_arithmetic( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + group_op: Option, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + if let Some(group_op) = group_op { + instruction.add_operand(group_op as u32); + } + instruction.add_operand(value); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs index eb29e3cd8b..38a87049e6 100644 --- a/third_party/rust/naga/src/back/spv/mod.rs +++ b/third_party/rust/naga/src/back/spv/mod.rs @@ -13,6 +13,7 @@ mod layout; mod ray; mod recyclable; mod selection; +mod subgroup; mod writer; pub use spirv::Capability; @@ -70,6 +71,8 @@ pub enum Error { FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), + #[error("overrides should not be present at this stage")] + Override, } #[derive(Default)] @@ -245,7 +248,7 @@ impl LocalImageType { /// this, by converting everything possible to a `LocalType` before inspecting /// it. /// -/// ## `Localtype` equality and SPIR-V `OpType` uniqueness +/// ## `LocalType` equality and SPIR-V `OpType` uniqueness /// /// The definition of `Eq` on `LocalType` is carefully chosen to help us follow /// certain SPIR-V rules. SPIR-V ยง2.8 requires some classes of `OpType...` @@ -454,7 +457,7 @@ impl recyclable::Recyclable for CachedExpressions { #[derive(Eq, Hash, PartialEq)] enum CachedConstant { - Literal(crate::Literal), + Literal(crate::proc::HashableLiteral), Composite { ty: LookupType, constituent_ids: Vec, @@ -527,6 +530,42 @@ struct FunctionArgument { handle_id: Word, } +/// Tracks the expressions for which the backend emits the following instructions: +/// - OpConstantTrue +/// - OpConstantFalse +/// - OpConstant +/// - OpConstantComposite +/// - OpConstantNull +struct ExpressionConstnessTracker { + inner: bit_set::BitSet, +} + +impl ExpressionConstnessTracker { + fn from_arena(arena: &crate::Arena) -> Self { + let mut inner = bit_set::BitSet::new(); + for (handle, expr) in arena.iter() { + let insert = match *expr { + crate::Expression::Literal(_) + | crate::Expression::ZeroValue(_) + | crate::Expression::Constant(_) => true, + crate::Expression::Compose { ref components, .. } => { + components.iter().all(|h| inner.contains(h.index())) + } + crate::Expression::Splat { value, .. } => inner.contains(value.index()), + _ => false, + }; + if insert { + inner.insert(handle.index()); + } + } + Self { inner } + } + + fn is_const(&self, value: Handle) -> bool { + self.inner.contains(value.index()) + } +} + /// General information needed to emit SPIR-V for Naga statements. struct BlockContext<'w> { /// The writer handling the module to which this code belongs. @@ -552,7 +591,7 @@ struct BlockContext<'w> { temp_list: Vec, /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions` - expression_constness: crate::proc::ExpressionConstnessTracker, + expression_constness: ExpressionConstnessTracker, } impl BlockContext<'_> { @@ -725,7 +764,7 @@ impl<'a> Default for Options<'a> { } // A subset of options meant to be changed per pipeline. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct PipelineOptions { diff --git a/third_party/rust/naga/src/back/spv/subgroup.rs b/third_party/rust/naga/src/back/spv/subgroup.rs new file mode 100644 index 0000000000..c952cb11a7 --- /dev/null +++ b/third_party/rust/naga/src/back/spv/subgroup.rs @@ -0,0 +1,207 @@ +use super::{Block, BlockContext, Error, Instruction}; +use crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; + +impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option>, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + use crate::SubgroupOperation as sg; + match *op { + sg::All | sg::Any => { + self.writer.require_any( + "GroupNonUniformVote", + &[spirv::Capability::GroupNonUniformVote], + )?; + } + _ => { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[spirv::Capability::GroupNonUniformArithmetic], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + let result_ty_inner = result_ty.inner_with(&self.ir_module.types); + + let (is_scalar, scalar) = match *result_ty_inner { + TypeInner::Scalar(kind) => (true, kind), + TypeInner::Vector { scalar: kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + let spirv_op = match (scalar.kind, *op) { + (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, + (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, + (_, sg::All | sg::Any) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd, + (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd, + (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul, + (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul, + (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax, + (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax, + (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax, + (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin, + (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin, + (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin, + (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd, + (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr, + (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor, + (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd, + (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr, + (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor, + (_, sg::And | sg::Or | sg::Xor) => unimplemented!(), + }; + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + use crate::CollectiveOperation as c; + let group_op = match *op { + sg::All | sg::Any => None, + _ => Some(match *collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }), + }; + + let arg_id = self.cached[argument]; + block.body.push(Instruction::group_non_uniform_arithmetic( + spirv_op, + result_type_id, + id, + exec_scope_id, + group_op, + arg_id, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_gather( + &mut self, + mode: &crate::GatherMode, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + match *mode { + crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + } + crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => { + self.writer.require_any( + "GroupNonUniformShuffle", + &[spirv::Capability::GroupNonUniformShuffle], + )?; + } + crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { + self.writer.require_any( + "GroupNonUniformShuffleRelative", + &[spirv::Capability::GroupNonUniformShuffleRelative], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + match *mode { + crate::GatherMode::BroadcastFirst => { + block + .body + .push(Instruction::group_non_uniform_broadcast_first( + result_type_id, + id, + exec_scope_id, + arg_id, + )); + } + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + // Use shuffle to emit broadcast to allow the index to + // be dynamically uniform on Vulkan 1.1. The argument to + // OpGroupNonUniformBroadcast must be a constant pre SPIR-V + // 1.5 (vulkan 1.2) + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } + } + self.cached[result] = id; + Ok(()) + } +} diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs index a5065e0623..73a16c273e 100644 --- a/third_party/rust/naga/src/back/spv/writer.rs +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -615,7 +615,7 @@ impl Writer { // Steal the Writer's temp list for a bit. temp_list: std::mem::take(&mut self.temp_list), writer: self, - expression_constness: crate::proc::ExpressionConstnessTracker::from_arena( + expression_constness: super::ExpressionConstnessTracker::from_arena( &ir_function.expressions, ), }; @@ -970,6 +970,11 @@ impl Writer { handle: Handle, ) -> Result { let ty = &arena[handle]; + // If it's a type that needs SPIR-V capabilities, request them now. + // This needs to happen regardless of the LocalType lookup succeeding, + // because some types which map to the same LocalType have different + // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569 + self.request_type_capabilities(&ty.inner)?; let id = if let Some(local) = make_local(&ty.inner) { // This type can be represented as a `LocalType`, so check if we've // already written an instruction for it. If not, do so now, with @@ -985,10 +990,6 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's a type that needs SPIR-V capabilities, request them now, - // so write_type_declaration_local can stay infallible. - self.request_type_capabilities(&ty.inner)?; - id } } @@ -1150,7 +1151,7 @@ impl Writer { } pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word { - let scalar = CachedConstant::Literal(value); + let scalar = CachedConstant::Literal(value.into()); if let Some(&id) = self.cached_constants.get(&scalar) { return id; } @@ -1258,7 +1259,7 @@ impl Writer { ir_module: &crate::Module, mod_info: &ModuleInfo, ) -> Result { - let id = match ir_module.const_expressions[handle] { + let id = match ir_module.global_expressions[handle] { crate::Expression::Literal(literal) => self.get_constant_scalar(literal), crate::Expression::Constant(constant) => { let constant = &ir_module.constants[constant]; @@ -1272,7 +1273,7 @@ impl Writer { let component_ids: Vec<_> = crate::proc::flatten_compose( ty, components, - &ir_module.const_expressions, + &ir_module.global_expressions, &ir_module.types, ) .map(|component| self.constant_ids[component.index()]) @@ -1310,7 +1311,11 @@ impl Writer { spirv::MemorySemantics::WORKGROUP_MEMORY, flags.contains(crate::Barrier::WORK_GROUP), ); - let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) { + self.get_index_constant(spirv::Scope::Subgroup as u32) + } else { + self.get_index_constant(spirv::Scope::Workgroup as u32) + }; let mem_scope_id = self.get_index_constant(memory_scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); block.body.push(Instruction::control_barrier( @@ -1585,6 +1590,41 @@ impl Writer { Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + // Subgroup + Bi::NumSubgroups => { + self.require_any( + "`num_subgroups` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId + } + Bi::SubgroupSize => { + self.require_any( + "`subgroup_size` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupSize + } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); @@ -1899,7 +1939,7 @@ impl Writer { source_code: debug_info.source_code, source_file_id, }); - self.debugs.push(Instruction::source( + self.debugs.append(&mut Instruction::source_auto_continued( spirv::SourceLanguage::Unknown, 0, &debug_info_inner, @@ -1914,8 +1954,8 @@ impl Writer { // write all const-expressions as constants self.constant_ids - .resize(ir_module.const_expressions.len(), 0); - for (handle, _) in ir_module.const_expressions.iter() { + .resize(ir_module.global_expressions.len(), 0); + for (handle, _) in ir_module.global_expressions.iter() { self.write_constant_expr(handle, ir_module, mod_info)?; } debug_assert!(self.constant_ids.iter().all(|&id| id != 0)); @@ -2029,6 +2069,10 @@ impl Writer { debug_info: &Option, words: &mut Vec, ) -> Result<(), Error> { + if !ir_module.overrides.is_empty() { + return Err(Error::Override); + } + self.reset(); // Try to find the entry point and corresponding index diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs index 3039cbbbe4..789f6f62bf 100644 --- a/third_party/rust/naga/src/back/wgsl/writer.rs +++ b/third_party/rust/naga/src/back/wgsl/writer.rs @@ -106,6 +106,12 @@ impl Writer { } pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { + if !module.overrides.is_empty() { + return Err(Error::Unimplemented( + "Pipeline constants are not yet supported for this back-end".to_string(), + )); + } + self.reset(module); // Save all ep result types @@ -918,8 +924,124 @@ impl Writer { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + write!(self.out, "subgroupBallot(")?; + if let Some(predicate) = predicate { + self.write_expr(module, predicate, func_ctx)?; + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -1070,7 +1192,7 @@ impl Writer { self.write_possibly_const_expression( module, expr, - &module.const_expressions, + &module.global_expressions, |writer, expr| writer.write_const_expression(module, expr), ) } @@ -1199,6 +1321,7 @@ impl Writer { |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } + Expression::Override(_) => unreachable!(), Expression::FunctionArgument(pos) => { let name_key = func_ctx.argument_key(pos); let name = &self.names[&name_key]; @@ -1691,6 +1814,8 @@ impl Writer { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} } @@ -1792,6 +1917,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", + Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance -- cgit v1.2.3