summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/back
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/back')
-rw-r--r--third_party/rust/naga/src/back/dot/mod.rs91
-rw-r--r--third_party/rust/naga/src/back/glsl/features.rs23
-rw-r--r--third_party/rust/naga/src/back/glsl/mod.rs148
-rw-r--r--third_party/rust/naga/src/back/hlsl/conv.rs5
-rw-r--r--third_party/rust/naga/src/back/hlsl/help.rs94
-rw-r--r--third_party/rust/naga/src/back/hlsl/mod.rs17
-rw-r--r--third_party/rust/naga/src/back/hlsl/writer.rs315
-rw-r--r--third_party/rust/naga/src/back/mod.rs17
-rw-r--r--third_party/rust/naga/src/back/msl/mod.rs27
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs192
-rw-r--r--third_party/rust/naga/src/back/pipeline_constants.rs957
-rw-r--r--third_party/rust/naga/src/back/spv/block.rs32
-rw-r--r--third_party/rust/naga/src/back/spv/helpers.rs53
-rw-r--r--third_party/rust/naga/src/back/spv/instructions.rs103
-rw-r--r--third_party/rust/naga/src/back/spv/mod.rs47
-rw-r--r--third_party/rust/naga/src/back/spv/subgroup.rs207
-rw-r--r--third_party/rust/naga/src/back/spv/writer.rs68
-rw-r--r--third_party/rust/naga/src/back/wgsl/writer.rs131
18 files changed, 2400 insertions, 127 deletions
diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs
index 1556371df1..9a7702b3f6 100644
--- a/third_party/rust/naga/src/back/dot/mod.rs
+++ b/third_party/rust/naga/src/back/dot/mod.rs
@@ -279,6 +279,94 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
+ S::SubgroupBallot { result, predicate } => {
+ if let Some(predicate) = predicate {
+ self.dependencies.push((id, predicate, "predicate"));
+ }
+ self.emits.push((id, result));
+ "SubgroupBallot"
+ }
+ S::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ self.dependencies.push((id, argument, "arg"));
+ self.emits.push((id, result));
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ "SubgroupAll"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ "SubgroupAny"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ "SubgroupAdd"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ "SubgroupMul"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ "SubgroupMax"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ "SubgroupMin"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ "SubgroupAnd"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ "SubgroupOr"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ "SubgroupXor"
+ }
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => "SubgroupExclusiveAdd",
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => "SubgroupExclusiveMul",
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => "SubgroupInclusiveAdd",
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => "SubgroupInclusiveMul",
+ _ => unimplemented!(),
+ }
+ }
+ S::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ self.dependencies.push((id, index, "index"))
+ }
+ }
+ self.dependencies.push((id, argument, "arg"));
+ self.emits.push((id, result));
+ match mode {
+ crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
+ crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
+ crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
+ crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
+ crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
+ crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
+ }
+ }
};
// Set the last node to the merge node
last_node = merge_id;
@@ -404,6 +492,7 @@ fn write_function_expressions(
let (label, color_id) = match *expression {
E::Literal(_) => ("Literal".into(), 2),
E::Constant(_) => ("Constant".into(), 2),
+ E::Override(_) => ("Override".into(), 2),
E::ZeroValue(_) => ("ZeroValue".into(), 2),
E::Compose { ref components, .. } => {
payload = Some(Payload::Arguments(components));
@@ -586,6 +675,8 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
+ E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
+ E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
};
// give uniform expressions an outline
diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs
index 99c128c6d9..e5a43f3e02 100644
--- a/third_party/rust/naga/src/back/glsl/features.rs
+++ b/third_party/rust/naga/src/back/glsl/features.rs
@@ -50,6 +50,8 @@ bitflags::bitflags! {
const INSTANCE_INDEX = 1 << 22;
/// Sample specific LODs of cube / array shadow textures
const TEXTURE_SHADOW_LOD = 1 << 23;
+ /// Subgroup operations
+ const SUBGROUP_OPERATIONS = 1 << 24;
}
}
@@ -117,6 +119,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
+ check_feature!(SUBGROUP_OPERATIONS, 430, 310);
match version {
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
_ => check_feature!(MULTI_VIEW, 140, 310),
@@ -259,6 +262,22 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?;
}
+ if self.0.contains(Features::SUBGROUP_OPERATIONS) {
+ // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
+ writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
+ writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
+ writeln!(
+ out,
+ "#extension GL_KHR_shader_subgroup_arithmetic : require"
+ )?;
+ writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
+ writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
+ writeln!(
+ out,
+ "#extension GL_KHR_shader_subgroup_shuffle_relative : require"
+ )?;
+ }
+
Ok(())
}
}
@@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> {
}
}
}
+ Expression::SubgroupBallotResult |
+ Expression::SubgroupOperationResult { .. } => {
+ features.request(Features::SUBGROUP_OPERATIONS)
+ }
_ => {}
}
}
diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs
index 9bda594610..c8c7ea557d 100644
--- a/third_party/rust/naga/src/back/glsl/mod.rs
+++ b/third_party/rust/naga/src/back/glsl/mod.rs
@@ -282,7 +282,7 @@ impl Default for Options {
}
/// A subset of options meant to be changed per pipeline.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
@@ -497,6 +497,8 @@ pub enum Error {
ImageMultipleSamplers,
#[error("{0}")]
Custom(String),
+ #[error("overrides should not be present at this stage")]
+ Override,
}
/// Binary operation with a different logic on the GLSL side.
@@ -565,6 +567,10 @@ impl<'a, W: Write> Writer<'a, W> {
pipeline_options: &'a PipelineOptions,
policies: proc::BoundsCheckPolicies,
) -> Result<Self, Error> {
+ if !module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
// Check if the requested version is supported
if !options.version.is_supported() {
log::error!("Version {}", options.version);
@@ -2384,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
+ Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ write!(self.out, "subgroupBallot(")?;
+ match predicate {
+ Some(predicate) => self.write_expr(predicate, ctx)?,
+ None => write!(self.out, "true")?,
+ }
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "subgroupAll(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "subgroupAny(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupAdd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupMul(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "subgroupMax(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "subgroupMin(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "subgroupAnd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "subgroupOr(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "subgroupXor(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupExclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupExclusiveMul(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupInclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupInclusiveMul(")?
+ }
+ _ => unimplemented!(),
+ }
+ self.write_expr(argument, ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ match mode {
+ crate::GatherMode::BroadcastFirst => {
+ write!(self.out, "subgroupBroadcastFirst(")?;
+ }
+ crate::GatherMode::Broadcast(_) => {
+ write!(self.out, "subgroupBroadcast(")?;
+ }
+ crate::GatherMode::Shuffle(_) => {
+ write!(self.out, "subgroupShuffle(")?;
+ }
+ crate::GatherMode::ShuffleDown(_) => {
+ write!(self.out, "subgroupShuffleDown(")?;
+ }
+ crate::GatherMode::ShuffleUp(_) => {
+ write!(self.out, "subgroupShuffleUp(")?;
+ }
+ crate::GatherMode::ShuffleXor(_) => {
+ write!(self.out, "subgroupShuffleXor(")?;
+ }
+ }
+ self.write_expr(argument, ctx)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, ", ")?;
+ self.write_expr(index, ctx)?;
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
Ok(())
@@ -2402,7 +2527,7 @@ impl<'a, W: Write> Writer<'a, W> {
fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult {
self.write_possibly_const_expr(
expr,
- &self.module.const_expressions,
+ &self.module.global_expressions,
|expr| &self.info[expr],
|writer, expr| writer.write_const_expr(expr),
)
@@ -2536,6 +2661,7 @@ impl<'a, W: Write> Writer<'a, W> {
|writer, expr| writer.write_expr(expr, ctx),
)?;
}
+ Expression::Override(_) => return Err(Error::Override),
// `Access` is applied to arrays, vectors and matrices and is written as indexing
Expression::Access { base, index } => {
self.write_expr(base, ctx)?;
@@ -3411,7 +3537,8 @@ impl<'a, W: Write> Writer<'a, W> {
let scalar_bits = ctx
.resolve_type(arg, &self.module.types)
.scalar_width()
- .unwrap();
+ .unwrap()
+ * 8;
write!(self.out, "bitfieldExtract(")?;
self.write_expr(arg, ctx)?;
@@ -3430,7 +3557,8 @@ impl<'a, W: Write> Writer<'a, W> {
let scalar_bits = ctx
.resolve_type(arg, &self.module.types)
.scalar_width()
- .unwrap();
+ .unwrap()
+ * 8;
write!(self.out, "bitfieldInsert(")?;
self.write_expr(arg, ctx)?;
@@ -3649,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
- | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
+ | Expression::WorkGroupUniformLoadResult { .. }
+ | Expression::SubgroupOperationResult { .. }
+ | Expression::SubgroupBallotResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
@@ -4218,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> {
if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}memoryBarrierShared();")?;
}
+ if flags.contains(crate::Barrier::SUB_GROUP) {
+ writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
+ }
writeln!(self.out, "{level}barrier();")?;
Ok(())
}
@@ -4487,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
+ // subgroup
+ Bi::NumSubgroups => "gl_NumSubgroups",
+ Bi::SubgroupId => "gl_SubgroupID",
+ Bi::SubgroupSize => "gl_SubgroupSize",
+ Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
}
}
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs
index 2a6db35db8..7d15f43f6c 100644
--- a/third_party/rust/naga/src/back/hlsl/conv.rs
+++ b/third_party/rust/naga/src/back/hlsl/conv.rs
@@ -179,6 +179,11 @@ impl crate::BuiltIn {
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
+ // These builtins map to functions
+ Self::SubgroupSize
+ | Self::SubgroupInvocationId
+ | Self::NumSubgroups
+ | Self::SubgroupId => unreachable!(),
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {self:?}")))
}
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs
index 4dd9ea5987..d3bb1ce7f5 100644
--- a/third_party/rust/naga/src/back/hlsl/help.rs
+++ b/third_party/rust/naga/src/back/hlsl/help.rs
@@ -70,6 +70,11 @@ pub(super) struct WrappedMath {
pub(super) components: Option<u32>,
}
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedZeroValue {
+ pub(super) ty: Handle<crate::Type>,
+}
+
/// HLSL backend requires its own `ImageQuery` enum.
///
/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
@@ -359,7 +364,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
}
/// Helper function that write wrapped function for `Expression::Compose` for structures.
- pub(super) fn write_wrapped_constructor_function(
+ fn write_wrapped_constructor_function(
&mut self,
module: &crate::Module,
constructor: WrappedConstructor,
@@ -862,6 +867,25 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}
+ // TODO: we could merge this with iteration in write_wrapped_compose_functions...
+ //
+ /// Helper function that writes zero value wrapped functions
+ pub(super) fn write_wrapped_zero_value_functions(
+ &mut self,
+ module: &crate::Module,
+ expressions: &crate::Arena<crate::Expression>,
+ ) -> BackendResult {
+ for (handle, _) in expressions.iter() {
+ if let crate::Expression::ZeroValue(ty) = expressions[handle] {
+ let zero_value = WrappedZeroValue { ty };
+ if self.wrapped.zero_values.insert(zero_value) {
+ self.write_wrapped_zero_value_function(module, zero_value)?;
+ }
+ }
+ }
+ Ok(())
+ }
+
pub(super) fn write_wrapped_math_functions(
&mut self,
module: &crate::Module,
@@ -1006,6 +1030,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
) -> BackendResult {
self.write_wrapped_math_functions(module, func_ctx)?;
self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
+ self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
for (handle, _) in func_ctx.expressions.iter() {
match func_ctx.expressions[handle] {
@@ -1283,4 +1308,71 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}
+
+ pub(super) fn write_wrapped_zero_value_function_name(
+ &mut self,
+ module: &crate::Module,
+ zero_value: WrappedZeroValue,
+ ) -> BackendResult {
+ let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?;
+ write!(self.out, "ZeroValue{name}")?;
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ZeroValue`
+ ///
+ /// This is necessary since we might have a member access after the zero value expression, e.g.
+ /// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc).
+ ///
+ /// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly.
+ ///
+ /// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle
+ /// cases like:
+ ///
+ /// ```ignore
+ /// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet
+ /// t_1.am = (__mat4x2[2])((float4x2[2])0);
+ /// ^
+ /// ```
+ fn write_wrapped_zero_value_function(
+ &mut self,
+ module: &crate::Module,
+ zero_value: WrappedZeroValue,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ if let crate::TypeInner::Array { base, size, .. } = module.types[zero_value.ty].inner {
+ write!(self.out, "typedef ")?;
+ self.write_type(module, zero_value.ty)?;
+ write!(self.out, " ret_")?;
+ self.write_wrapped_zero_value_function_name(module, zero_value)?;
+ self.write_array_size(module, base, size)?;
+ writeln!(self.out, ";")?;
+
+ write!(self.out, "ret_")?;
+ self.write_wrapped_zero_value_function_name(module, zero_value)?;
+ } else {
+ self.write_type(module, zero_value.ty)?;
+ }
+ write!(self.out, " ")?;
+ self.write_wrapped_zero_value_function_name(module, zero_value)?;
+
+ // Write function parameters (none) and start function body
+ writeln!(self.out, "() {{")?;
+
+ // Write `ZeroValue` function.
+ write!(self.out, "{INDENT}return ")?;
+ self.write_default_init(module, zero_value.ty)?;
+ writeln!(self.out, ";")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
}
diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs
index f37a223f47..28edbf70e1 100644
--- a/third_party/rust/naga/src/back/hlsl/mod.rs
+++ b/third_party/rust/naga/src/back/hlsl/mod.rs
@@ -131,6 +131,13 @@ pub enum ShaderModel {
V5_0,
V5_1,
V6_0,
+ V6_1,
+ V6_2,
+ V6_3,
+ V6_4,
+ V6_5,
+ V6_6,
+ V6_7,
}
impl ShaderModel {
@@ -139,6 +146,13 @@ impl ShaderModel {
Self::V5_0 => "5_0",
Self::V5_1 => "5_1",
Self::V6_0 => "6_0",
+ Self::V6_1 => "6_1",
+ Self::V6_2 => "6_2",
+ Self::V6_3 => "6_3",
+ Self::V6_4 => "6_4",
+ Self::V6_5 => "6_5",
+ Self::V6_6 => "6_6",
+ Self::V6_7 => "6_7",
}
}
}
@@ -247,10 +261,13 @@ pub enum Error {
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Custom(String),
+ #[error("overrides should not be present at this stage")]
+ Override,
}
#[derive(Default)]
struct Wrapped {
+ zero_values: crate::FastHashSet<help::WrappedZeroValue>,
array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
image_queries: crate::FastHashSet<help::WrappedImageQuery>,
constructors: crate::FastHashSet<help::WrappedConstructor>,
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
index 4ba856946b..86d8f89035 100644
--- a/third_party/rust/naga/src/back/hlsl/writer.rs
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -1,5 +1,8 @@
use super::{
- help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
+ help::{
+ WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
+ WrappedZeroValue,
+ },
storage::StoreValue,
BackendResult, Error, Options,
};
@@ -77,6 +80,19 @@ enum Io {
Output,
}
+const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
+ let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
+ return false;
+ };
+ matches!(
+ builtin,
+ crate::BuiltIn::SubgroupSize
+ | crate::BuiltIn::SubgroupInvocationId
+ | crate::BuiltIn::NumSubgroups
+ | crate::BuiltIn::SubgroupId
+ )
+}
+
impl<'a, W: fmt::Write> super::Writer<'a, W> {
pub fn new(out: W, options: &'a Options) -> Self {
Self {
@@ -161,6 +177,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
}
+ for statement in func.body.iter() {
+ match *statement {
+ crate::Statement::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: crate::CollectiveOperation::InclusiveScan,
+ argument,
+ result: _,
+ } => {
+ self.need_bake_expressions.insert(argument);
+ }
+ _ => {}
+ }
+ }
}
pub fn write(
@@ -168,6 +197,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
module: &Module,
module_info: &valid::ModuleInfo,
) -> Result<super::ReflectionInfo, Error> {
+ if !module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
self.reset(module);
// Write special constants, if needed
@@ -233,7 +266,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_special_functions(module)?;
- self.write_wrapped_compose_functions(module, &module.const_expressions)?;
+ self.write_wrapped_compose_functions(module, &module.global_expressions)?;
+ self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
// Write all named constants
let mut constants = module
@@ -397,31 +431,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// if they are struct, so that the `stage` argument here could be omitted.
fn write_semantic(
&mut self,
- binding: &crate::Binding,
+ binding: &Option<crate::Binding>,
stage: Option<(ShaderStage, Io)>,
) -> BackendResult {
match *binding {
- crate::Binding::BuiltIn(builtin) => {
+ Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
let builtin_str = builtin.to_hlsl_str()?;
write!(self.out, " : {builtin_str}")?;
}
- crate::Binding::Location {
+ Some(crate::Binding::Location {
second_blend_source: true,
..
- } => {
+ }) => {
write!(self.out, " : SV_Target1")?;
}
- crate::Binding::Location {
+ Some(crate::Binding::Location {
location,
second_blend_source: false,
..
- } => {
+ }) => {
if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
write!(self.out, " : SV_Target{location}")?;
} else {
write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
}
}
+ _ => {}
}
Ok(())
@@ -442,17 +477,30 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, "struct {struct_name}")?;
writeln!(self.out, " {{")?;
for m in members.iter() {
+ if is_subgroup_builtin_binding(&m.binding) {
+ continue;
+ }
write!(self.out, "{}", back::INDENT)?;
if let Some(ref binding) = m.binding {
self.write_modifier(binding)?;
}
self.write_type(module, m.ty)?;
write!(self.out, " {}", &m.name)?;
- if let Some(ref binding) = m.binding {
- self.write_semantic(binding, Some(shader_stage))?;
- }
+ self.write_semantic(&m.binding, Some(shader_stage))?;
writeln!(self.out, ";")?;
}
+ if members.iter().any(|arg| {
+ matches!(
+ arg.binding,
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
+ )
+ }) {
+ writeln!(
+ self.out,
+ "{}uint __local_invocation_index : SV_GroupIndex;",
+ back::INDENT
+ )?;
+ }
writeln!(self.out, "}};")?;
writeln!(self.out)?;
@@ -553,8 +601,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
/// Writes special interface structures for an entry point. The special structures have
- /// all the fields flattened into them and sorted by binding. They are only needed for
- /// VS outputs and FS inputs, so that these interfaces match.
+ /// all the fields flattened into them and sorted by binding. They are needed to emulate
+ /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
fn write_ep_interface(
&mut self,
module: &Module,
@@ -563,7 +611,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ep_name: &str,
) -> Result<EntryPointInterface, Error> {
Ok(EntryPointInterface {
- input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
+ input: if !func.arguments.is_empty()
+ && (stage == ShaderStage::Fragment
+ || func
+ .arguments
+ .iter()
+ .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
+ {
Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
} else {
None
@@ -577,6 +631,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
})
}
+ fn write_ep_argument_initialization(
+ &mut self,
+ ep: &crate::EntryPoint,
+ ep_input: &EntryPointBinding,
+ fake_member: &EpStructMember,
+ ) -> BackendResult {
+ match fake_member.binding {
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
+ write!(self.out, "WaveGetLaneCount()")?
+ }
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
+ write!(self.out, "WaveGetLaneIndex()")?
+ }
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
+ self.out,
+ "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
+ ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
+ )?,
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
+ write!(
+ self.out,
+ "{}.__local_invocation_index / WaveGetLaneCount()",
+ ep_input.arg_name
+ )?;
+ }
+ _ => {
+ write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ }
+ }
+ Ok(())
+ }
+
/// Write an entry point preface that initializes the arguments as specified in IR.
fn write_ep_arguments_initialization(
&mut self,
@@ -584,6 +670,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
func: &crate::Function,
ep_index: u16,
) -> BackendResult {
+ let ep = &module.entry_points[ep_index as usize];
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
Some(ep_input) => ep_input,
None => return Ok(()),
@@ -597,8 +684,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
match module.types[arg.ty].inner {
TypeInner::Array { base, size, .. } => {
self.write_array_size(module, base, size)?;
- let fake_member = fake_iter.next().unwrap();
- writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ write!(self.out, " = ")?;
+ self.write_ep_argument_initialization(
+ ep,
+ &ep_input,
+ fake_iter.next().unwrap(),
+ )?;
+ writeln!(self.out, ";")?;
}
TypeInner::Struct { ref members, .. } => {
write!(self.out, " = {{ ")?;
@@ -606,14 +698,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if index != 0 {
write!(self.out, ", ")?;
}
- let fake_member = fake_iter.next().unwrap();
- write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ self.write_ep_argument_initialization(
+ ep,
+ &ep_input,
+ fake_iter.next().unwrap(),
+ )?;
}
writeln!(self.out, " }};")?;
}
_ => {
- let fake_member = fake_iter.next().unwrap();
- writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ write!(self.out, " = ")?;
+ self.write_ep_argument_initialization(
+ ep,
+ &ep_input,
+ fake_iter.next().unwrap(),
+ )?;
+ writeln!(self.out, ";")?;
}
}
}
@@ -928,9 +1028,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
- if let Some(ref binding) = member.binding {
- self.write_semantic(binding, shader_stage)?;
- };
+ self.write_semantic(&member.binding, shader_stage)?;
writeln!(self.out, ";")?;
}
@@ -1143,7 +1241,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
back::FunctionType::EntryPoint(ep_index) => {
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
- write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
+ write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
} else {
let stage = module.entry_points[ep_index as usize].stage;
for (index, arg) in func.arguments.iter().enumerate() {
@@ -1160,17 +1258,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_array_size(module, base, size)?;
}
- if let Some(ref binding) = arg.binding {
- self.write_semantic(binding, Some((stage, Io::Input)))?;
- }
+ self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
}
-
- if need_workgroup_variables_initialization {
- if !func.arguments.is_empty() {
- write!(self.out, ", ")?;
- }
- write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
+ }
+ if need_workgroup_variables_initialization {
+ if self.entry_point_io[ep_index as usize].input.is_some()
+ || !func.arguments.is_empty()
+ {
+ write!(self.out, ", ")?;
}
+ write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
}
}
}
@@ -1180,11 +1277,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Write semantic if it present
if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
let stage = module.entry_points[index as usize].stage;
- if let Some(crate::FunctionResult {
- binding: Some(ref binding),
- ..
- }) = func.result
- {
+ if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
self.write_semantic(binding, Some((stage, Io::Output)))?;
}
}
@@ -1984,6 +2077,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "{level}}}")?
}
Statement::RayQuery { .. } => unreachable!(),
+ Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ write!(self.out, "const uint4 {name} = ")?;
+ self.named_expressions.insert(result, name);
+
+ write!(self.out, "WaveActiveBallot(")?;
+ match predicate {
+ Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
+ None => write!(self.out, "true")?,
+ }
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(result, name);
+
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "WaveActiveAllTrue(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "WaveActiveAnyTrue(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "WaveActiveSum(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "WaveActiveProduct(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "WaveActiveMax(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "WaveActiveMin(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "WaveActiveBitAnd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "WaveActiveBitOr(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "WaveActiveBitXor(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "WavePrefixSum(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "WavePrefixProduct(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, " + WavePrefixSum(")?;
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, " * WavePrefixProduct(")?;
+ }
+ _ => unimplemented!(),
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(result, name);
+
+ if matches!(mode, crate::GatherMode::BroadcastFirst) {
+ write!(self.out, "WaveReadLaneFirst(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ } else {
+ write!(self.out, "WaveReadLaneAt(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, ", ")?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => unreachable!(),
+ crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ crate::GatherMode::ShuffleDown(index) => {
+ write!(self.out, "WaveGetLaneIndex() + ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ crate::GatherMode::ShuffleUp(index) => {
+ write!(self.out, "WaveGetLaneIndex() - ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, "WaveGetLaneIndex() ^ ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
Ok(())
@@ -1997,7 +2213,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_possibly_const_expression(
module,
expr,
- &module.const_expressions,
+ &module.global_expressions,
|writer, expr| writer.write_const_expression(module, expr),
)
}
@@ -2039,7 +2255,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_const_expression(module, constant.init)?;
}
}
- Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
+ Expression::ZeroValue(ty) => {
+ self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
+ write!(self.out, "()")?;
+ }
Expression::Compose { ty, ref components } => {
match module.types[ty].inner {
TypeInner::Struct { .. } | TypeInner::Array { .. } => {
@@ -2140,6 +2359,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
+ Expression::Override(_) => return Err(Error::Override),
// All of the multiplication can be expressed as `mul`,
// except vector * vector, which needs to use the "*" operator.
Expression::Binary {
@@ -2588,7 +2808,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
true
}
None => {
- if inner.scalar_width() == Some(64) {
+ if inner.scalar_width() == Some(8) {
false
} else {
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
@@ -3129,7 +3349,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::WorkGroupUniformLoadResult { .. }
- | Expression::RayQueryProceedResult => {}
+ | Expression::RayQueryProceedResult
+ | Expression::SubgroupBallotResult
+ | Expression::SubgroupOperationResult { .. } => {}
}
if !closing_bracket.is_empty() {
@@ -3179,7 +3401,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
/// Helper function that write default zero initialization
- fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ pub(super) fn write_default_init(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ ) -> BackendResult {
write!(self.out, "(")?;
self.write_type(module, ty)?;
if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
@@ -3196,6 +3422,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
}
+ if barrier.contains(crate::Barrier::SUB_GROUP) {
+ // Does not exist in DirectX
+ }
Ok(())
}
}
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs
index c8f091decb..0c9c5e4761 100644
--- a/third_party/rust/naga/src/back/mod.rs
+++ b/third_party/rust/naga/src/back/mod.rs
@@ -16,6 +16,14 @@ pub mod spv;
#[cfg(feature = "wgsl-out")]
pub mod wgsl;
+#[cfg(any(
+ feature = "hlsl-out",
+ feature = "msl-out",
+ feature = "spv-out",
+ feature = "glsl-out"
+))]
+pub mod pipeline_constants;
+
/// Names of vector components.
pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
/// Indent for backends.
@@ -26,6 +34,15 @@ pub const BAKE_PREFIX: &str = "_e";
/// Expressions that need baking.
pub type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>;
+/// Specifies the values of pipeline-overridable constants in the shader module.
+///
+/// If an `@id` attribute was specified on the declaration,
+/// the key must be the pipeline constant ID as a decimal ASCII number; if not,
+/// the key must be the constant's identifier name.
+///
+/// The value may represent any of WGSL's concrete scalar types.
+pub type PipelineConstants = std::collections::HashMap<String, f64>;
+
/// Indentation level.
#[derive(Clone, Copy)]
pub struct Level(pub usize);
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
index 68e5b79906..8b03e20376 100644
--- a/third_party/rust/naga/src/back/msl/mod.rs
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -143,6 +143,8 @@ pub enum Error {
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
+ #[error("overrides should not be present at this stage")]
+ Override,
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
@@ -221,7 +223,7 @@ impl Default for Options {
}
/// A subset of options that are meant to be changed per pipeline.
-#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
@@ -434,6 +436,11 @@ impl ResolvedBinding {
Bi::WorkGroupId => "threadgroup_position_in_grid",
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
Bi::NumWorkGroups => "threadgroups_per_grid",
+ // subgroup
+ Bi::NumSubgroups => "simdgroups_per_threadgroup",
+ Bi::SubgroupId => "simdgroup_index_in_threadgroup",
+ Bi::SubgroupSize => "threads_per_simdgroup",
+ Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
Bi::CullDistance | Bi::ViewIndex => {
return Err(Error::UnsupportedBuiltIn(built_in))
}
@@ -536,3 +543,21 @@ fn test_error_size() {
use std::mem::size_of;
assert_eq!(size_of::<Error>(), 32);
}
+
+impl crate::AtomicFunction {
+ fn to_msl(self) -> Result<&'static str, Error> {
+ Ok(match self {
+ Self::Add => "fetch_add",
+ Self::Subtract => "fetch_sub",
+ Self::And => "fetch_and",
+ Self::InclusiveOr => "fetch_or",
+ Self::ExclusiveOr => "fetch_xor",
+ Self::Min => "fetch_min",
+ Self::Max => "fetch_max",
+ Self::Exchange { compare: None } => "exchange",
+ Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
+ "atomic CompareExchange".to_string(),
+ ))?,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
index 5227d8e7db..e250d0b72c 100644
--- a/third_party/rust/naga/src/back/msl/writer.rs
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -1131,21 +1131,10 @@ impl<W: Write> Writer<W> {
Ok(())
}
- fn put_atomic_fetch(
- &mut self,
- pointer: Handle<crate::Expression>,
- key: &str,
- value: Handle<crate::Expression>,
- context: &ExpressionContext,
- ) -> BackendResult {
- self.put_atomic_operation(pointer, "fetch_", key, value, context)
- }
-
fn put_atomic_operation(
&mut self,
pointer: Handle<crate::Expression>,
- key1: &str,
- key2: &str,
+ key: &str,
value: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
@@ -1163,7 +1152,7 @@ impl<W: Write> Writer<W> {
write!(
self.out,
- "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}"
+ "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
@@ -1248,7 +1237,7 @@ impl<W: Write> Writer<W> {
) -> BackendResult {
self.put_possibly_const_expression(
expr_handle,
- &module.const_expressions,
+ &module.global_expressions,
module,
mod_info,
&(module, mod_info),
@@ -1431,6 +1420,7 @@ impl<W: Write> Writer<W> {
|writer, context, expr| writer.put_expression(expr, context, true),
)?;
}
+ crate::Expression::Override(_) => return Err(Error::Override),
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.
@@ -1944,7 +1934,7 @@ impl<W: Write> Writer<W> {
//
// extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
- let scalar_bits = context.resolve_type(arg).scalar_width().unwrap();
+ let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::extract_bits(")?;
self.put_expression(arg, context, true)?;
@@ -1960,7 +1950,7 @@ impl<W: Write> Writer<W> {
//
// insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
- let scalar_bits = context.resolve_type(arg).scalar_width().unwrap();
+ let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::insert_bits(")?;
self.put_expression(arg, context, true)?;
@@ -2041,6 +2031,8 @@ impl<W: Write> Writer<W> {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
+ | crate::Expression::SubgroupBallotResult
+ | crate::Expression::SubgroupOperationResult { .. }
| crate::Expression::RayQueryProceedResult => {
unreachable!()
}
@@ -2994,43 +2986,8 @@ impl<W: Write> Writer<W> {
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_baking_expression(result, &context.expression, &res_name)?;
self.named_expressions.insert(result, res_name);
- match *fun {
- crate::AtomicFunction::Add => {
- self.put_atomic_fetch(pointer, "add", value, &context.expression)?;
- }
- crate::AtomicFunction::Subtract => {
- self.put_atomic_fetch(pointer, "sub", value, &context.expression)?;
- }
- crate::AtomicFunction::And => {
- self.put_atomic_fetch(pointer, "and", value, &context.expression)?;
- }
- crate::AtomicFunction::InclusiveOr => {
- self.put_atomic_fetch(pointer, "or", value, &context.expression)?;
- }
- crate::AtomicFunction::ExclusiveOr => {
- self.put_atomic_fetch(pointer, "xor", value, &context.expression)?;
- }
- crate::AtomicFunction::Min => {
- self.put_atomic_fetch(pointer, "min", value, &context.expression)?;
- }
- crate::AtomicFunction::Max => {
- self.put_atomic_fetch(pointer, "max", value, &context.expression)?;
- }
- crate::AtomicFunction::Exchange { compare: None } => {
- self.put_atomic_operation(
- pointer,
- "exchange",
- "",
- value,
- &context.expression,
- )?;
- }
- crate::AtomicFunction::Exchange { .. } => {
- return Err(Error::FeatureNotImplemented(
- "atomic CompareExchange".to_string(),
- ));
- }
- }
+ let fun_str = fun.to_msl()?;
+ self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
writeln!(self.out, ";")?;
}
@@ -3144,6 +3101,121 @@ impl<W: Write> Writer<W> {
}
}
}
+ crate::Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
+ if let Some(predicate) = predicate {
+ self.put_expression(predicate, &context.expression, true)?;
+ } else {
+ write!(self.out, "true")?;
+ }
+ writeln!(self.out, "), 0, 0, 0);")?;
+ }
+ crate::Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "{NAMESPACE}::simd_all(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "{NAMESPACE}::simd_any(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "{NAMESPACE}::simd_sum(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "{NAMESPACE}::simd_product(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "{NAMESPACE}::simd_max(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "{NAMESPACE}::simd_min(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "{NAMESPACE}::simd_and(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "{NAMESPACE}::simd_or(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "{NAMESPACE}::simd_xor(")?
+ }
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
+ _ => unimplemented!(),
+ }
+ self.put_expression(argument, &context.expression, true)?;
+ writeln!(self.out, ");")?;
+ }
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ match mode {
+ crate::GatherMode::BroadcastFirst => {
+ write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
+ }
+ crate::GatherMode::Broadcast(_) => {
+ write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
+ }
+ crate::GatherMode::Shuffle(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
+ }
+ crate::GatherMode::ShuffleDown(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
+ }
+ crate::GatherMode::ShuffleUp(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
+ }
+ crate::GatherMode::ShuffleXor(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
+ }
+ }
+ self.put_expression(argument, &context.expression, true)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, ", ")?;
+ self.put_expression(index, &context.expression, true)?;
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
}
@@ -3220,6 +3292,10 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
+ if !module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
self.names.clear();
self.namer.reset(
module,
@@ -4487,6 +4563,12 @@ impl<W: Write> Writer<W> {
"{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
)?;
}
+ if flags.contains(crate::Barrier::SUB_GROUP) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
+ )?;
+ }
Ok(())
}
}
@@ -4757,8 +4839,8 @@ fn test_stack_size() {
}
let stack_size = addresses_end - addresses_start;
// check the size (in debug only)
- // last observed macOS value: 19152 (CI)
- if !(9000..=20000).contains(&stack_size) {
+ // last observed macOS value: 22256 (CI)
+ if !(15000..=25000).contains(&stack_size) {
panic!("`put_block` stack size {stack_size} has changed!");
}
}
diff --git a/third_party/rust/naga/src/back/pipeline_constants.rs b/third_party/rust/naga/src/back/pipeline_constants.rs
new file mode 100644
index 0000000000..0dbe9cf4e8
--- /dev/null
+++ b/third_party/rust/naga/src/back/pipeline_constants.rs
@@ -0,0 +1,957 @@
+use super::PipelineConstants;
+use crate::{
+ proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
+ valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
+ Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
+ Span, Statement, TypeInner, WithSpan,
+};
+use std::{borrow::Cow, collections::HashSet, mem};
+use thiserror::Error;
+
+#[derive(Error, Debug, Clone)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum PipelineConstantError {
+ #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
+ MissingValue(String),
+ #[error("Source f64 value needs to be finite (NaNs and Infinities are not allowed) for number destinations")]
+ SrcNeedsToBeFinite,
+ #[error("Source f64 value doesn't fit in destination")]
+ DstRangeTooSmall,
+ #[error(transparent)]
+ ConstantEvaluatorError(#[from] ConstantEvaluatorError),
+ #[error(transparent)]
+ ValidationError(#[from] WithSpan<ValidationError>),
+}
+
+/// Replace all overrides in `module` with constants.
+///
+/// If no changes are needed, this just returns `Cow::Borrowed`
+/// references to `module` and `module_info`. Otherwise, it clones
+/// `module`, edits its [`global_expressions`] arena to contain only
+/// fully-evaluated expressions, and returns `Cow::Owned` values
+/// holding the simplified module and its validation results.
+///
+/// In either case, the module returned has an empty `overrides`
+/// arena, and the `global_expressions` arena contains only
+/// fully-evaluated expressions.
+///
+/// [`global_expressions`]: Module::global_expressions
+pub fn process_overrides<'a>(
+ module: &'a Module,
+ module_info: &'a ModuleInfo,
+ pipeline_constants: &PipelineConstants,
+) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
+ if module.overrides.is_empty() {
+ return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
+ }
+
+ let mut module = module.clone();
+
+ // A map from override handles to the handles of the constants
+ // we've replaced them with.
+ let mut override_map = Vec::with_capacity(module.overrides.len());
+
+ // A map from `module`'s original global expression handles to
+ // handles in the new, simplified global expression arena.
+ let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len());
+
+ // The set of constants whose initializer handles we've already
+ // updated to refer to the newly built global expression arena.
+ //
+ // All constants in `module` must have their `init` handles
+ // updated to point into the new, simplified global expression
+ // arena. Some of these we can most easily handle as a side effect
+ // during the simplification process, but we must handle the rest
+ // in a final fixup pass, guided by `adjusted_global_expressions`. We
+ // add their handles to this set, so that the final fixup step can
+ // leave them alone.
+ let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
+
+ let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
+
+ // An iterator through the original overrides table, consumed in
+ // approximate tandem with the global expressions.
+ let mut override_iter = module.overrides.drain();
+
+ // Do two things in tandem:
+ //
+ // - Rebuild the global expression arena from scratch, fully
+ // evaluating all expressions, and replacing each `Override`
+ // expression in `module.global_expressions` with a `Constant`
+ // expression.
+ //
+ // - Build a new `Constant` in `module.constants` to take the
+ // place of each `Override`.
+ //
+ // Build a map from old global expression handles to their
+ // fully-evaluated counterparts in `adjusted_global_expressions` as we
+ // go.
+ //
+ // Why in tandem? Overrides refer to expressions, and expressions
+ // refer to overrides, so we can't disentangle the two into
+ // separate phases. However, we can take advantage of the fact
+ // that the overrides and expressions must form a DAG, and work
+ // our way from the leaves to the roots, replacing and evaluating
+ // as we go.
+ //
+ // Although the two loops are nested, this is really two
+ // alternating phases: we adjust and evaluate constant expressions
+ // until we hit an `Override` expression, at which point we switch
+ // to building `Constant`s for `Overrides` until we've handled the
+ // one used by the expression. Then we switch back to processing
+ // expressions. Because we know they form a DAG, we know the
+ // `Override` expressions we encounter can only have initializers
+ // referring to global expressions we've already simplified.
+ for (old_h, expr, span) in module.global_expressions.drain() {
+ let mut expr = match expr {
+ Expression::Override(h) => {
+ let c_h = if let Some(new_h) = override_map.get(h.index()) {
+ *new_h
+ } else {
+ let mut new_h = None;
+ for entry in override_iter.by_ref() {
+ let stop = entry.0 == h;
+ new_h = Some(process_override(
+ entry,
+ pipeline_constants,
+ &mut module,
+ &mut override_map,
+ &adjusted_global_expressions,
+ &mut adjusted_constant_initializers,
+ &mut global_expression_kind_tracker,
+ )?);
+ if stop {
+ break;
+ }
+ }
+ new_h.unwrap()
+ };
+ Expression::Constant(c_h)
+ }
+ Expression::Constant(c_h) => {
+ if adjusted_constant_initializers.insert(c_h) {
+ let init = &mut module.constants[c_h].init;
+ *init = adjusted_global_expressions[init.index()];
+ }
+ expr
+ }
+ expr => expr,
+ };
+ let mut evaluator = ConstantEvaluator::for_wgsl_module(
+ &mut module,
+ &mut global_expression_kind_tracker,
+ false,
+ );
+ adjust_expr(&adjusted_global_expressions, &mut expr);
+ let h = evaluator.try_eval_and_append(expr, span)?;
+ debug_assert_eq!(old_h.index(), adjusted_global_expressions.len());
+ adjusted_global_expressions.push(h);
+ }
+
+ // Finish processing any overrides we didn't visit in the loop above.
+ for entry in override_iter {
+ process_override(
+ entry,
+ pipeline_constants,
+ &mut module,
+ &mut override_map,
+ &adjusted_global_expressions,
+ &mut adjusted_constant_initializers,
+ &mut global_expression_kind_tracker,
+ )?;
+ }
+
+ // Update the initialization expression handles of all `Constant`s
+ // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
+ // passant.
+ for (_, c) in module
+ .constants
+ .iter_mut()
+ .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
+ {
+ c.init = adjusted_global_expressions[c.init.index()];
+ }
+
+ for (_, v) in module.global_variables.iter_mut() {
+ if let Some(ref mut init) = v.init {
+ *init = adjusted_global_expressions[init.index()];
+ }
+ }
+
+ let mut functions = mem::take(&mut module.functions);
+ for (_, function) in functions.iter_mut() {
+ process_function(&mut module, &override_map, function)?;
+ }
+ module.functions = functions;
+
+ let mut entry_points = mem::take(&mut module.entry_points);
+ for ep in entry_points.iter_mut() {
+ process_function(&mut module, &override_map, &mut ep.function)?;
+ }
+ module.entry_points = entry_points;
+
+ // Now that we've rewritten all the expressions, we need to
+ // recompute their types and other metadata. For the time being,
+ // do a full re-validation.
+ let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
+ let module_info = validator.validate_no_overrides(&module)?;
+
+ Ok((Cow::Owned(module), Cow::Owned(module_info)))
+}
+
+/// Add a [`Constant`] to `module` for the override `old_h`.
+///
+/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
+fn process_override(
+ (old_h, override_, span): (Handle<Override>, Override, Span),
+ pipeline_constants: &PipelineConstants,
+ module: &mut Module,
+ override_map: &mut Vec<Handle<Constant>>,
+ adjusted_global_expressions: &[Handle<Expression>],
+ adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
+) -> Result<Handle<Constant>, PipelineConstantError> {
+ // Determine which key to use for `override_` in `pipeline_constants`.
+ let key = if let Some(id) = override_.id {
+ Cow::Owned(id.to_string())
+ } else if let Some(ref name) = override_.name {
+ Cow::Borrowed(name)
+ } else {
+ unreachable!();
+ };
+
+ // Generate a global expression for `override_`'s value, either
+ // from the provided `pipeline_constants` table or its initializer
+ // in the module.
+ let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
+ let literal = match module.types[override_.ty].inner {
+ TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
+ _ => unreachable!(),
+ };
+ let expr = module
+ .global_expressions
+ .append(Expression::Literal(literal), Span::UNDEFINED);
+ global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
+ expr
+ } else if let Some(init) = override_.init {
+ adjusted_global_expressions[init.index()]
+ } else {
+ return Err(PipelineConstantError::MissingValue(key.to_string()));
+ };
+
+ // Generate a new `Constant` to represent the override's value.
+ let constant = Constant {
+ name: override_.name,
+ ty: override_.ty,
+ init,
+ };
+ let h = module.constants.append(constant, span);
+ debug_assert_eq!(old_h.index(), override_map.len());
+ override_map.push(h);
+ adjusted_constant_initializers.insert(h);
+ Ok(h)
+}
+
+/// Replace all override expressions in `function` with fully-evaluated constants.
+///
+/// Replace all `Expression::Override`s in `function`'s expression arena with
+/// the corresponding `Expression::Constant`s, as given in `override_map`.
+/// Replace any expressions whose values are now known with their fully
+/// evaluated form.
+///
+/// If `h` is a `Handle<Override>`, then `override_map[h.index()]` is the
+/// `Handle<Constant>` for the override's final value.
+fn process_function(
+ module: &mut Module,
+ override_map: &[Handle<Constant>],
+ function: &mut Function,
+) -> Result<(), ConstantEvaluatorError> {
+ // A map from original local expression handles to
+ // handles in the new, local expression arena.
+ let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len());
+
+ let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
+
+ let mut expressions = mem::take(&mut function.expressions);
+
+ // Dummy `emitter` and `block` for the constant evaluator.
+ // We can ignore the concept of emitting expressions here since
+ // expressions have already been covered by a `Statement::Emit`
+ // in the frontend.
+ // The only thing we might have to do is remove some expressions
+ // that have been covered by a `Statement::Emit`. See the docs of
+ // `filter_emits_in_block` for the reasoning.
+ let mut emitter = Emitter::default();
+ let mut block = Block::new();
+
+ let mut evaluator = ConstantEvaluator::for_wgsl_function(
+ module,
+ &mut function.expressions,
+ &mut local_expression_kind_tracker,
+ &mut emitter,
+ &mut block,
+ );
+
+ for (old_h, mut expr, span) in expressions.drain() {
+ if let Expression::Override(h) = expr {
+ expr = Expression::Constant(override_map[h.index()]);
+ }
+ adjust_expr(&adjusted_local_expressions, &mut expr);
+ let h = evaluator.try_eval_and_append(expr, span)?;
+ debug_assert_eq!(old_h.index(), adjusted_local_expressions.len());
+ adjusted_local_expressions.push(h);
+ }
+
+ adjust_block(&adjusted_local_expressions, &mut function.body);
+
+ filter_emits_in_block(&mut function.body, &function.expressions);
+
+ // Update local expression initializers.
+ for (_, local) in function.local_variables.iter_mut() {
+ if let &mut Some(ref mut init) = &mut local.init {
+ *init = adjusted_local_expressions[init.index()];
+ }
+ }
+
+ // We've changed the keys of `function.named_expressions`, so we have to
+ // rebuild it from scratch.
+ let named_expressions = mem::take(&mut function.named_expressions);
+ for (expr_h, name) in named_expressions {
+ function
+ .named_expressions
+ .insert(adjusted_local_expressions[expr_h.index()], name);
+ }
+
+ Ok(())
+}
+
+/// Replace every expression handle in `expr` with its counterpart
+/// given by `new_pos`.
+fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
+ let adjust = |expr: &mut Handle<Expression>| {
+ *expr = new_pos[expr.index()];
+ };
+ match *expr {
+ Expression::Compose {
+ ref mut components,
+ ty: _,
+ } => {
+ for c in components.iter_mut() {
+ adjust(c);
+ }
+ }
+ Expression::Access {
+ ref mut base,
+ ref mut index,
+ } => {
+ adjust(base);
+ adjust(index);
+ }
+ Expression::AccessIndex {
+ ref mut base,
+ index: _,
+ } => {
+ adjust(base);
+ }
+ Expression::Splat {
+ ref mut value,
+ size: _,
+ } => {
+ adjust(value);
+ }
+ Expression::Swizzle {
+ ref mut vector,
+ size: _,
+ pattern: _,
+ } => {
+ adjust(vector);
+ }
+ Expression::Load { ref mut pointer } => {
+ adjust(pointer);
+ }
+ Expression::ImageSample {
+ ref mut image,
+ ref mut sampler,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut offset,
+ ref mut level,
+ ref mut depth_ref,
+ gather: _,
+ } => {
+ adjust(image);
+ adjust(sampler);
+ adjust(coordinate);
+ if let Some(e) = array_index.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = offset.as_mut() {
+ adjust(e);
+ }
+ match *level {
+ crate::SampleLevel::Exact(ref mut expr)
+ | crate::SampleLevel::Bias(ref mut expr) => {
+ adjust(expr);
+ }
+ crate::SampleLevel::Gradient {
+ ref mut x,
+ ref mut y,
+ } => {
+ adjust(x);
+ adjust(y);
+ }
+ _ => {}
+ }
+ if let Some(e) = depth_ref.as_mut() {
+ adjust(e);
+ }
+ }
+ Expression::ImageLoad {
+ ref mut image,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut sample,
+ ref mut level,
+ } => {
+ adjust(image);
+ adjust(coordinate);
+ if let Some(e) = array_index.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = sample.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = level.as_mut() {
+ adjust(e);
+ }
+ }
+ Expression::ImageQuery {
+ ref mut image,
+ ref mut query,
+ } => {
+ adjust(image);
+ match *query {
+ crate::ImageQuery::Size { ref mut level } => {
+ if let Some(e) = level.as_mut() {
+ adjust(e);
+ }
+ }
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => {}
+ }
+ }
+ Expression::Unary {
+ ref mut expr,
+ op: _,
+ } => {
+ adjust(expr);
+ }
+ Expression::Binary {
+ ref mut left,
+ ref mut right,
+ op: _,
+ } => {
+ adjust(left);
+ adjust(right);
+ }
+ Expression::Select {
+ ref mut condition,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ adjust(condition);
+ adjust(accept);
+ adjust(reject);
+ }
+ Expression::Derivative {
+ ref mut expr,
+ axis: _,
+ ctrl: _,
+ } => {
+ adjust(expr);
+ }
+ Expression::Relational {
+ ref mut argument,
+ fun: _,
+ } => {
+ adjust(argument);
+ }
+ Expression::Math {
+ ref mut arg,
+ ref mut arg1,
+ ref mut arg2,
+ ref mut arg3,
+ fun: _,
+ } => {
+ adjust(arg);
+ if let Some(e) = arg1.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = arg2.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = arg3.as_mut() {
+ adjust(e);
+ }
+ }
+ Expression::As {
+ ref mut expr,
+ kind: _,
+ convert: _,
+ } => {
+ adjust(expr);
+ }
+ Expression::ArrayLength(ref mut expr) => {
+ adjust(expr);
+ }
+ Expression::RayQueryGetIntersection {
+ ref mut query,
+ committed: _,
+ } => {
+ adjust(query);
+ }
+ Expression::Literal(_)
+ | Expression::FunctionArgument(_)
+ | Expression::GlobalVariable(_)
+ | Expression::LocalVariable(_)
+ | Expression::CallResult(_)
+ | Expression::RayQueryProceedResult
+ | Expression::Constant(_)
+ | Expression::Override(_)
+ | Expression::ZeroValue(_)
+ | Expression::AtomicResult {
+ ty: _,
+ comparison: _,
+ }
+ | Expression::WorkGroupUniformLoadResult { ty: _ }
+ | Expression::SubgroupBallotResult
+ | Expression::SubgroupOperationResult { .. } => {}
+ }
+}
+
+/// Replace every expression handle in `block` with its counterpart
+/// given by `new_pos`.
+fn adjust_block(new_pos: &[Handle<Expression>], block: &mut Block) {
+ for stmt in block.iter_mut() {
+ adjust_stmt(new_pos, stmt);
+ }
+}
+
+/// Replace every expression handle in `stmt` with its counterpart
+/// given by `new_pos`.
+fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
+ let adjust = |expr: &mut Handle<Expression>| {
+ *expr = new_pos[expr.index()];
+ };
+ match *stmt {
+ Statement::Emit(ref mut range) => {
+ if let Some((mut first, mut last)) = range.first_and_last() {
+ adjust(&mut first);
+ adjust(&mut last);
+ *range = Range::new_from_bounds(first, last);
+ }
+ }
+ Statement::Block(ref mut block) => {
+ adjust_block(new_pos, block);
+ }
+ Statement::If {
+ ref mut condition,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ adjust(condition);
+ adjust_block(new_pos, accept);
+ adjust_block(new_pos, reject);
+ }
+ Statement::Switch {
+ ref mut selector,
+ ref mut cases,
+ } => {
+ adjust(selector);
+ for case in cases.iter_mut() {
+ adjust_block(new_pos, &mut case.body);
+ }
+ }
+ Statement::Loop {
+ ref mut body,
+ ref mut continuing,
+ ref mut break_if,
+ } => {
+ adjust_block(new_pos, body);
+ adjust_block(new_pos, continuing);
+ if let Some(e) = break_if.as_mut() {
+ adjust(e);
+ }
+ }
+ Statement::Return { ref mut value } => {
+ if let Some(e) = value.as_mut() {
+ adjust(e);
+ }
+ }
+ Statement::Store {
+ ref mut pointer,
+ ref mut value,
+ } => {
+ adjust(pointer);
+ adjust(value);
+ }
+ Statement::ImageStore {
+ ref mut image,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut value,
+ } => {
+ adjust(image);
+ adjust(coordinate);
+ if let Some(e) = array_index.as_mut() {
+ adjust(e);
+ }
+ adjust(value);
+ }
+ crate::Statement::Atomic {
+ ref mut pointer,
+ ref mut value,
+ ref mut result,
+ ref mut fun,
+ } => {
+ adjust(pointer);
+ adjust(value);
+ adjust(result);
+ match *fun {
+ crate::AtomicFunction::Exchange {
+ compare: Some(ref mut compare),
+ } => {
+ adjust(compare);
+ }
+ crate::AtomicFunction::Add
+ | crate::AtomicFunction::Subtract
+ | crate::AtomicFunction::And
+ | crate::AtomicFunction::ExclusiveOr
+ | crate::AtomicFunction::InclusiveOr
+ | crate::AtomicFunction::Min
+ | crate::AtomicFunction::Max
+ | crate::AtomicFunction::Exchange { compare: None } => {}
+ }
+ }
+ Statement::WorkGroupUniformLoad {
+ ref mut pointer,
+ ref mut result,
+ } => {
+ adjust(pointer);
+ adjust(result);
+ }
+ Statement::SubgroupBallot {
+ ref mut result,
+ ref mut predicate,
+ } => {
+ if let Some(ref mut predicate) = *predicate {
+ adjust(predicate);
+ }
+ adjust(result);
+ }
+ Statement::SubgroupCollectiveOperation {
+ ref mut argument,
+ ref mut result,
+ ..
+ } => {
+ adjust(argument);
+ adjust(result);
+ }
+ Statement::SubgroupGather {
+ ref mut mode,
+ ref mut argument,
+ ref mut result,
+ } => {
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(ref mut index)
+ | crate::GatherMode::Shuffle(ref mut index)
+ | crate::GatherMode::ShuffleDown(ref mut index)
+ | crate::GatherMode::ShuffleUp(ref mut index)
+ | crate::GatherMode::ShuffleXor(ref mut index) => {
+ adjust(index);
+ }
+ }
+ adjust(argument);
+ adjust(result)
+ }
+ Statement::Call {
+ ref mut arguments,
+ ref mut result,
+ function: _,
+ } => {
+ for argument in arguments.iter_mut() {
+ adjust(argument);
+ }
+ if let Some(e) = result.as_mut() {
+ adjust(e);
+ }
+ }
+ Statement::RayQuery {
+ ref mut query,
+ ref mut fun,
+ } => {
+ adjust(query);
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ ref mut acceleration_structure,
+ ref mut descriptor,
+ } => {
+ adjust(acceleration_structure);
+ adjust(descriptor);
+ }
+ crate::RayQueryFunction::Proceed { ref mut result } => {
+ adjust(result);
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+ Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
+ }
+}
+
+/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
+///
+/// According to validation, [`Emit`] statements must not cover any expressions
+/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
+/// by successful constant evaluation fall into that category, meaning that
+/// `process_function` will usually rewrite [`Override`] expressions and those
+/// that use their values into pre-emitted expressions, leaving any [`Emit`]
+/// statements that cover them invalid.
+///
+/// This function rewrites all [`Emit`] statements into zero or more new
+/// [`Emit`] statements covering only those expressions in the original range
+/// that are not pre-emitted.
+///
+/// [`Emit`]: Statement::Emit
+/// [`needs_pre_emit`]: Expression::needs_pre_emit
+/// [`Override`]: Expression::Override
+fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
+ let original = std::mem::replace(block, Block::with_capacity(block.len()));
+ for (stmt, span) in original.span_into_iter() {
+ match stmt {
+ Statement::Emit(range) => {
+ let mut current = None;
+ for expr_h in range {
+ if expressions[expr_h].needs_pre_emit() {
+ if let Some((first, last)) = current {
+ block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
+ }
+
+ current = None;
+ } else if let Some((_, ref mut last)) = current {
+ *last = expr_h;
+ } else {
+ current = Some((expr_h, expr_h));
+ }
+ }
+ if let Some((first, last)) = current {
+ block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
+ }
+ }
+ Statement::Block(mut child) => {
+ filter_emits_in_block(&mut child, expressions);
+ block.push(Statement::Block(child), span);
+ }
+ Statement::If {
+ condition,
+ mut accept,
+ mut reject,
+ } => {
+ filter_emits_in_block(&mut accept, expressions);
+ filter_emits_in_block(&mut reject, expressions);
+ block.push(
+ Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ span,
+ );
+ }
+ Statement::Switch {
+ selector,
+ mut cases,
+ } => {
+ for case in &mut cases {
+ filter_emits_in_block(&mut case.body, expressions);
+ }
+ block.push(Statement::Switch { selector, cases }, span);
+ }
+ Statement::Loop {
+ mut body,
+ mut continuing,
+ break_if,
+ } => {
+ filter_emits_in_block(&mut body, expressions);
+ filter_emits_in_block(&mut continuing, expressions);
+ block.push(
+ Statement::Loop {
+ body,
+ continuing,
+ break_if,
+ },
+ span,
+ );
+ }
+ stmt => block.push(stmt.clone(), span),
+ }
+ }
+}
+
+fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
+ // note that in rust 0.0 == -0.0
+ match scalar {
+ Scalar::BOOL => {
+ // https://webidl.spec.whatwg.org/#js-boolean
+ let value = value != 0.0 && !value.is_nan();
+ Ok(Literal::Bool(value))
+ }
+ Scalar::I32 => {
+ // https://webidl.spec.whatwg.org/#js-long
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ let value = value.trunc();
+ if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
+ return Err(PipelineConstantError::DstRangeTooSmall);
+ }
+
+ let value = value as i32;
+ Ok(Literal::I32(value))
+ }
+ Scalar::U32 => {
+ // https://webidl.spec.whatwg.org/#js-unsigned-long
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ let value = value.trunc();
+ if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
+ return Err(PipelineConstantError::DstRangeTooSmall);
+ }
+
+ let value = value as u32;
+ Ok(Literal::U32(value))
+ }
+ Scalar::F32 => {
+ // https://webidl.spec.whatwg.org/#js-float
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ let value = value as f32;
+ if !value.is_finite() {
+ return Err(PipelineConstantError::DstRangeTooSmall);
+ }
+
+ Ok(Literal::F32(value))
+ }
+ Scalar::F64 => {
+ // https://webidl.spec.whatwg.org/#js-double
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ Ok(Literal::F64(value))
+ }
+ _ => unreachable!(),
+ }
+}
+
+#[test]
+fn test_map_value_to_literal() {
+ let bool_test_cases = [
+ (0.0, false),
+ (-0.0, false),
+ (f64::NAN, false),
+ (1.0, true),
+ (f64::INFINITY, true),
+ (f64::NEG_INFINITY, true),
+ ];
+ for (value, out) in bool_test_cases {
+ let res = Ok(Literal::Bool(out));
+ assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
+ }
+
+ for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
+ for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
+ let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
+ assert_eq!(map_value_to_literal(value, scalar), res);
+ }
+ }
+
+ // i32
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
+ Ok(Literal::I32(i32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
+ Ok(Literal::I32(i32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+
+ // u32
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
+ Ok(Literal::U32(u32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
+ Ok(Literal::U32(u32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+
+ // f32
+ assert_eq!(
+ map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
+ Ok(Literal::F32(f32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
+ Ok(Literal::F32(f32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
+ Ok(Literal::F32(f32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
+ Ok(Literal::F32(f32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+
+ // f64
+ assert_eq!(
+ map_value_to_literal(f64::MIN, Scalar::F64),
+ Ok(Literal::F64(f64::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::MAX, Scalar::F64),
+ Ok(Literal::F64(f64::MAX))
+ );
+}
diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs
index 81f2fc10e0..120d60fc40 100644
--- a/third_party/rust/naga/src/back/spv/block.rs
+++ b/third_party/rust/naga/src/back/spv/block.rs
@@ -239,6 +239,7 @@ impl<'w> BlockContext<'w> {
let init = self.ir_module.constants[handle].init;
self.writer.constant_ids[init.index()]
}
+ crate::Expression::Override(_) => return Err(Error::Override),
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();
@@ -1072,7 +1073,7 @@ impl<'w> BlockContext<'w> {
//
// bitfieldExtract(x, o, c)
- let bit_width = arg_ty.scalar_width().unwrap();
+ let bit_width = arg_ty.scalar_width().unwrap() * 8;
let width_constant = self
.writer
.get_constant_scalar(crate::Literal::U32(bit_width as u32));
@@ -1128,7 +1129,7 @@ impl<'w> BlockContext<'w> {
Mf::InsertBits => {
// The behavior of InsertBits has the same undefined behavior as ExtractBits.
- let bit_width = arg_ty.scalar_width().unwrap();
+ let bit_width = arg_ty.scalar_width().unwrap() * 8;
let width_constant = self
.writer
.get_constant_scalar(crate::Literal::U32(bit_width as u32));
@@ -1184,7 +1185,7 @@ impl<'w> BlockContext<'w> {
}
Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
Mf::FindMsb => {
- if arg_ty.scalar_width() == Some(32) {
+ if arg_ty.scalar_width() == Some(4) {
let thing = match arg_scalar_kind {
Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
@@ -1278,7 +1279,9 @@ impl<'w> BlockContext<'w> {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
- | crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
+ | crate::Expression::RayQueryProceedResult
+ | crate::Expression::SubgroupBallotResult
+ | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
crate::Expression::As {
expr,
kind,
@@ -2489,6 +2492,27 @@ impl<'w> BlockContext<'w> {
crate::Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);
}
+ crate::Statement::SubgroupBallot {
+ result,
+ ref predicate,
+ } => {
+ self.write_subgroup_ballot(predicate, result, &mut block)?;
+ }
+ crate::Statement::SubgroupCollectiveOperation {
+ ref op,
+ ref collective_op,
+ argument,
+ result,
+ } => {
+ self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
+ }
+ crate::Statement::SubgroupGather {
+ ref mode,
+ argument,
+ result,
+ } => {
+ self.write_subgroup_gather(mode, argument, result, &mut block)?;
+ }
}
}
diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs
index 5b6226db85..1fb447e384 100644
--- a/third_party/rust/naga/src/back/spv/helpers.rs
+++ b/third_party/rust/naga/src/back/spv/helpers.rs
@@ -10,8 +10,12 @@ pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> {
pub(super) fn string_to_words(input: &str) -> Vec<Word> {
let bytes = input.as_bytes();
- let mut words = bytes_to_words(bytes);
+ str_bytes_to_words(bytes)
+}
+
+pub(super) fn str_bytes_to_words(bytes: &[u8]) -> Vec<Word> {
+ let mut words = bytes_to_words(bytes);
if bytes.len() % 4 == 0 {
// nul-termination
words.push(0x0u32);
@@ -20,6 +24,21 @@ pub(super) fn string_to_words(input: &str) -> Vec<Word> {
words
}
+/// Split a string into chunks of at most `limit` bytes, keeping each chunk valid UTF-8
+#[allow(unstable_name_collisions)]
+pub(super) fn string_to_byte_chunks(input: &str, limit: usize) -> Vec<&[u8]> {
+ let mut offset: usize = 0;
+ let mut start: usize = 0;
+ let mut words = vec![];
+ while offset < input.len() {
+ offset = input.floor_char_boundary(offset + limit);
+ words.push(input[start..offset].as_bytes());
+ start = offset;
+ }
+
+ words
+}
+
pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass {
match space {
crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant,
@@ -107,3 +126,35 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab
_ => true,
}
}
+
+/// HACK: this is copied from std's unstable API; remove it once `str::floor_char_boundary` is stabilized
+trait U8Internal {
+ fn is_utf8_char_boundary(&self) -> bool;
+}
+
+impl U8Internal for u8 {
+ fn is_utf8_char_boundary(&self) -> bool {
+ // This is bit magic equivalent to: b < 128 || b >= 192
+ (*self as i8) >= -0x40
+ }
+}
+
+trait StrUnstable {
+ fn floor_char_boundary(&self, index: usize) -> usize;
+}
+
+impl StrUnstable for str {
+ fn floor_char_boundary(&self, index: usize) -> usize {
+ if index >= self.len() {
+ self.len()
+ } else {
+ let lower_bound = index.saturating_sub(3);
+ let new_index = self.as_bytes()[lower_bound..=index]
+ .iter()
+ .rposition(|b| b.is_utf8_char_boundary());
+
+ // SAFETY: we know that the character boundary will be within four bytes
+ unsafe { lower_bound + new_index.unwrap_unchecked() }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs
index b963793ad3..df2774ab9c 100644
--- a/third_party/rust/naga/src/back/spv/instructions.rs
+++ b/third_party/rust/naga/src/back/spv/instructions.rs
@@ -43,6 +43,42 @@ impl super::Instruction {
instruction
}
+ pub(super) fn source_continued(source: &[u8]) -> Self {
+ let mut instruction = Self::new(Op::SourceContinued);
+ instruction.add_operands(helpers::str_bytes_to_words(source));
+ instruction
+ }
+
+ pub(super) fn source_auto_continued(
+ source_language: spirv::SourceLanguage,
+ version: u32,
+ source: &Option<DebugInfoInner>,
+ ) -> Vec<Self> {
+ let mut instructions = vec![];
+
+ let with_continue = source.as_ref().and_then(|debug_info| {
+ (debug_info.source_code.len() > u16::MAX as usize).then_some(debug_info)
+ });
+ if let Some(debug_info) = with_continue {
+ let mut instruction = Self::new(Op::Source);
+ instruction.add_operand(source_language as u32);
+ instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
+
+ let words = helpers::string_to_byte_chunks(debug_info.source_code, u16::MAX as usize);
+ instruction.add_operand(debug_info.source_file_id);
+ instruction.add_operands(helpers::str_bytes_to_words(words[0]));
+ instructions.push(instruction);
+ for word_bytes in words[1..].iter() {
+ let instruction_continue = Self::source_continued(word_bytes);
+ instructions.push(instruction_continue);
+ }
+ } else {
+ let instruction = Self::source(source_language, version, source);
+ instructions.push(instruction);
+ }
+ instructions
+ }
+
pub(super) fn name(target_id: Word, name: &str) -> Self {
let mut instruction = Self::new(Op::Name);
instruction.add_operand(target_id);
@@ -1037,6 +1073,73 @@ impl super::Instruction {
instruction.add_operand(semantics_id);
instruction
}
+
+ // Group Instructions
+
+ pub(super) fn group_non_uniform_ballot(
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ predicate: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::GroupNonUniformBallot);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(predicate);
+
+ instruction
+ }
+ pub(super) fn group_non_uniform_broadcast_first(
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(value);
+
+ instruction
+ }
+ pub(super) fn group_non_uniform_gather(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ value: Word,
+ index: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(value);
+ instruction.add_operand(index);
+
+ instruction
+ }
+ pub(super) fn group_non_uniform_arithmetic(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ group_op: Option<spirv::GroupOperation>,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ if let Some(group_op) = group_op {
+ instruction.add_operand(group_op as u32);
+ }
+ instruction.add_operand(value);
+
+ instruction
+ }
}
impl From<crate::StorageFormat> for spirv::ImageFormat {
diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs
index eb29e3cd8b..38a87049e6 100644
--- a/third_party/rust/naga/src/back/spv/mod.rs
+++ b/third_party/rust/naga/src/back/spv/mod.rs
@@ -13,6 +13,7 @@ mod layout;
mod ray;
mod recyclable;
mod selection;
+mod subgroup;
mod writer;
pub use spirv::Capability;
@@ -70,6 +71,8 @@ pub enum Error {
FeatureNotImplemented(&'static str),
#[error("module is not validated properly: {0}")]
Validation(&'static str),
+ #[error("overrides should not be present at this stage")]
+ Override,
}
#[derive(Default)]
@@ -245,7 +248,7 @@ impl LocalImageType {
/// this, by converting everything possible to a `LocalType` before inspecting
/// it.
///
-/// ## `Localtype` equality and SPIR-V `OpType` uniqueness
+/// ## `LocalType` equality and SPIR-V `OpType` uniqueness
///
/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow
/// certain SPIR-V rules. SPIR-V ยง2.8 requires some classes of `OpType...`
@@ -454,7 +457,7 @@ impl recyclable::Recyclable for CachedExpressions {
#[derive(Eq, Hash, PartialEq)]
enum CachedConstant {
- Literal(crate::Literal),
+ Literal(crate::proc::HashableLiteral),
Composite {
ty: LookupType,
constituent_ids: Vec<Word>,
@@ -527,6 +530,42 @@ struct FunctionArgument {
handle_id: Word,
}
+/// Tracks the expressions for which the backend emits the following instructions:
+/// - OpConstantTrue
+/// - OpConstantFalse
+/// - OpConstant
+/// - OpConstantComposite
+/// - OpConstantNull
+struct ExpressionConstnessTracker {
+ inner: bit_set::BitSet,
+}
+
+impl ExpressionConstnessTracker {
+ fn from_arena(arena: &crate::Arena<crate::Expression>) -> Self {
+ let mut inner = bit_set::BitSet::new();
+ for (handle, expr) in arena.iter() {
+ let insert = match *expr {
+ crate::Expression::Literal(_)
+ | crate::Expression::ZeroValue(_)
+ | crate::Expression::Constant(_) => true,
+ crate::Expression::Compose { ref components, .. } => {
+ components.iter().all(|h| inner.contains(h.index()))
+ }
+ crate::Expression::Splat { value, .. } => inner.contains(value.index()),
+ _ => false,
+ };
+ if insert {
+ inner.insert(handle.index());
+ }
+ }
+ Self { inner }
+ }
+
+ fn is_const(&self, value: Handle<crate::Expression>) -> bool {
+ self.inner.contains(value.index())
+ }
+}
+
/// General information needed to emit SPIR-V for Naga statements.
struct BlockContext<'w> {
/// The writer handling the module to which this code belongs.
@@ -552,7 +591,7 @@ struct BlockContext<'w> {
temp_list: Vec<Word>,
/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
- expression_constness: crate::proc::ExpressionConstnessTracker,
+ expression_constness: ExpressionConstnessTracker,
}
impl BlockContext<'_> {
@@ -725,7 +764,7 @@ impl<'a> Default for Options<'a> {
}
// A subset of options meant to be changed per pipeline.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
diff --git a/third_party/rust/naga/src/back/spv/subgroup.rs b/third_party/rust/naga/src/back/spv/subgroup.rs
new file mode 100644
index 0000000000..c952cb11a7
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/subgroup.rs
@@ -0,0 +1,207 @@
+use super::{Block, BlockContext, Error, Instruction};
+use crate::{
+ arena::Handle,
+ back::spv::{LocalType, LookupType},
+ TypeInner,
+};
+
+impl<'w> BlockContext<'w> {
+ pub(super) fn write_subgroup_ballot(
+ &mut self,
+ predicate: &Option<Handle<crate::Expression>>,
+ result: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ self.writer.require_any(
+ "GroupNonUniformBallot",
+ &[spirv::Capability::GroupNonUniformBallot],
+ )?;
+ let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+ let predicate = if let Some(predicate) = *predicate {
+ self.cached[predicate]
+ } else {
+ self.writer.get_constant_scalar(crate::Literal::Bool(true))
+ };
+ let id = self.gen_id();
+ block.body.push(Instruction::group_non_uniform_ballot(
+ vec4_u32_type_id,
+ id,
+ exec_scope_id,
+ predicate,
+ ));
+ self.cached[result] = id;
+ Ok(())
+ }
+ pub(super) fn write_subgroup_operation(
+ &mut self,
+ op: &crate::SubgroupOperation,
+ collective_op: &crate::CollectiveOperation,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ use crate::SubgroupOperation as sg;
+ match *op {
+ sg::All | sg::Any => {
+ self.writer.require_any(
+ "GroupNonUniformVote",
+ &[spirv::Capability::GroupNonUniformVote],
+ )?;
+ }
+ _ => {
+ self.writer.require_any(
+ "GroupNonUniformArithmetic",
+ &[spirv::Capability::GroupNonUniformArithmetic],
+ )?;
+ }
+ }
+
+ let id = self.gen_id();
+ let result_ty = &self.fun_info[result].ty;
+ let result_type_id = self.get_expression_type_id(result_ty);
+ let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
+
+ let (is_scalar, scalar) = match *result_ty_inner {
+ TypeInner::Scalar(kind) => (true, kind),
+ TypeInner::Vector { scalar: kind, .. } => (false, kind),
+ _ => unimplemented!(),
+ };
+
+ use crate::ScalarKind as sk;
+ let spirv_op = match (scalar.kind, *op) {
+ (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
+ (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
+ (_, sg::All | sg::Any) => unimplemented!(),
+
+ (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
+ (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
+ (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
+ (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
+ (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
+ (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
+ (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
+ (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
+ (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
+ (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
+ (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),
+
+ (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
+ (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
+ (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
+ (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
+ (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
+ (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
+ (_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
+ };
+
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+
+ use crate::CollectiveOperation as c;
+ let group_op = match *op {
+ sg::All | sg::Any => None,
+ _ => Some(match *collective_op {
+ c::Reduce => spirv::GroupOperation::Reduce,
+ c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
+ c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
+ }),
+ };
+
+ let arg_id = self.cached[argument];
+ block.body.push(Instruction::group_non_uniform_arithmetic(
+ spirv_op,
+ result_type_id,
+ id,
+ exec_scope_id,
+ group_op,
+ arg_id,
+ ));
+ self.cached[result] = id;
+ Ok(())
+ }
+ pub(super) fn write_subgroup_gather(
+ &mut self,
+ mode: &crate::GatherMode,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ self.writer.require_any(
+ "GroupNonUniformBallot",
+ &[spirv::Capability::GroupNonUniformBallot],
+ )?;
+ match *mode {
+ crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => {
+ self.writer.require_any(
+ "GroupNonUniformBallot",
+ &[spirv::Capability::GroupNonUniformBallot],
+ )?;
+ }
+ crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => {
+ self.writer.require_any(
+ "GroupNonUniformShuffle",
+ &[spirv::Capability::GroupNonUniformShuffle],
+ )?;
+ }
+ crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
+ self.writer.require_any(
+ "GroupNonUniformShuffleRelative",
+ &[spirv::Capability::GroupNonUniformShuffleRelative],
+ )?;
+ }
+ }
+
+ let id = self.gen_id();
+ let result_ty = &self.fun_info[result].ty;
+ let result_type_id = self.get_expression_type_id(result_ty);
+
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+
+ let arg_id = self.cached[argument];
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {
+ block
+ .body
+ .push(Instruction::group_non_uniform_broadcast_first(
+ result_type_id,
+ id,
+ exec_scope_id,
+ arg_id,
+ ));
+ }
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ let index_id = self.cached[index];
+ let op = match *mode {
+ crate::GatherMode::BroadcastFirst => unreachable!(),
+ // Use shuffle to emit broadcast to allow the index to
+ // be dynamically uniform on Vulkan 1.1. The argument to
+ // OpGroupNonUniformBroadcast must be a constant pre SPIR-V
+ // 1.5 (vulkan 1.2)
+ crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
+ crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
+ crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
+ crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
+ crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
+ };
+ block.body.push(Instruction::group_non_uniform_gather(
+ op,
+ result_type_id,
+ id,
+ exec_scope_id,
+ arg_id,
+ index_id,
+ ));
+ }
+ }
+ self.cached[result] = id;
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs
index a5065e0623..73a16c273e 100644
--- a/third_party/rust/naga/src/back/spv/writer.rs
+++ b/third_party/rust/naga/src/back/spv/writer.rs
@@ -615,7 +615,7 @@ impl Writer {
// Steal the Writer's temp list for a bit.
temp_list: std::mem::take(&mut self.temp_list),
writer: self,
- expression_constness: crate::proc::ExpressionConstnessTracker::from_arena(
+ expression_constness: super::ExpressionConstnessTracker::from_arena(
&ir_function.expressions,
),
};
@@ -970,6 +970,11 @@ impl Writer {
handle: Handle<crate::Type>,
) -> Result<Word, Error> {
let ty = &arena[handle];
+ // If it's a type that needs SPIR-V capabilities, request them now.
+ // This needs to happen regardless of the LocalType lookup succeeding,
+ // because some types which map to the same LocalType have different
+ // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
+ self.request_type_capabilities(&ty.inner)?;
let id = if let Some(local) = make_local(&ty.inner) {
// This type can be represented as a `LocalType`, so check if we've
// already written an instruction for it. If not, do so now, with
@@ -985,10 +990,6 @@ impl Writer {
self.write_type_declaration_local(id, local);
- // If it's a type that needs SPIR-V capabilities, request them now,
- // so write_type_declaration_local can stay infallible.
- self.request_type_capabilities(&ty.inner)?;
-
id
}
}
@@ -1150,7 +1151,7 @@ impl Writer {
}
pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
- let scalar = CachedConstant::Literal(value);
+ let scalar = CachedConstant::Literal(value.into());
if let Some(&id) = self.cached_constants.get(&scalar) {
return id;
}
@@ -1258,7 +1259,7 @@ impl Writer {
ir_module: &crate::Module,
mod_info: &ModuleInfo,
) -> Result<Word, Error> {
- let id = match ir_module.const_expressions[handle] {
+ let id = match ir_module.global_expressions[handle] {
crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
crate::Expression::Constant(constant) => {
let constant = &ir_module.constants[constant];
@@ -1272,7 +1273,7 @@ impl Writer {
let component_ids: Vec<_> = crate::proc::flatten_compose(
ty,
components,
- &ir_module.const_expressions,
+ &ir_module.global_expressions,
&ir_module.types,
)
.map(|component| self.constant_ids[component.index()])
@@ -1310,7 +1311,11 @@ impl Writer {
spirv::MemorySemantics::WORKGROUP_MEMORY,
flags.contains(crate::Barrier::WORK_GROUP),
);
- let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
+ let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
+ self.get_index_constant(spirv::Scope::Subgroup as u32)
+ } else {
+ self.get_index_constant(spirv::Scope::Workgroup as u32)
+ };
let mem_scope_id = self.get_index_constant(memory_scope as u32);
let semantics_id = self.get_index_constant(semantics.bits());
block.body.push(Instruction::control_barrier(
@@ -1585,6 +1590,41 @@ impl Writer {
Bi::WorkGroupId => BuiltIn::WorkgroupId,
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
+ // Subgroup
+ Bi::NumSubgroups => {
+ self.require_any(
+ "`num_subgroups` built-in",
+ &[spirv::Capability::GroupNonUniform],
+ )?;
+ BuiltIn::NumSubgroups
+ }
+ Bi::SubgroupId => {
+ self.require_any(
+ "`subgroup_id` built-in",
+ &[spirv::Capability::GroupNonUniform],
+ )?;
+ BuiltIn::SubgroupId
+ }
+ Bi::SubgroupSize => {
+ self.require_any(
+ "`subgroup_size` built-in",
+ &[
+ spirv::Capability::GroupNonUniform,
+ spirv::Capability::SubgroupBallotKHR,
+ ],
+ )?;
+ BuiltIn::SubgroupSize
+ }
+ Bi::SubgroupInvocationId => {
+ self.require_any(
+ "`subgroup_invocation_id` built-in",
+ &[
+ spirv::Capability::GroupNonUniform,
+ spirv::Capability::SubgroupBallotKHR,
+ ],
+ )?;
+ BuiltIn::SubgroupLocalInvocationId
+ }
};
self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
@@ -1899,7 +1939,7 @@ impl Writer {
source_code: debug_info.source_code,
source_file_id,
});
- self.debugs.push(Instruction::source(
+ self.debugs.append(&mut Instruction::source_auto_continued(
spirv::SourceLanguage::Unknown,
0,
&debug_info_inner,
@@ -1914,8 +1954,8 @@ impl Writer {
// write all const-expressions as constants
self.constant_ids
- .resize(ir_module.const_expressions.len(), 0);
- for (handle, _) in ir_module.const_expressions.iter() {
+ .resize(ir_module.global_expressions.len(), 0);
+ for (handle, _) in ir_module.global_expressions.iter() {
self.write_constant_expr(handle, ir_module, mod_info)?;
}
debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
@@ -2029,6 +2069,10 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
+ if !ir_module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
self.reset();
// Try to find the entry point and corresponding index
diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs
index 3039cbbbe4..789f6f62bf 100644
--- a/third_party/rust/naga/src/back/wgsl/writer.rs
+++ b/third_party/rust/naga/src/back/wgsl/writer.rs
@@ -106,6 +106,12 @@ impl<W: Write> Writer<W> {
}
pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
+ if !module.overrides.is_empty() {
+ return Err(Error::Unimplemented(
+ "Pipeline constants are not yet supported for this back-end".to_string(),
+ ));
+ }
+
self.reset(module);
// Save all ep result types
@@ -918,8 +924,124 @@ impl<W: Write> Writer<W> {
if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}workgroupBarrier();")?;
}
+
+ if barrier.contains(crate::Barrier::SUB_GROUP) {
+ writeln!(self.out, "{level}subgroupBarrier();")?;
+ }
}
Statement::RayQuery { .. } => unreachable!(),
+ Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ write!(self.out, "subgroupBallot(")?;
+ if let Some(predicate) = predicate {
+ self.write_expr(module, predicate, func_ctx)?;
+ }
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "subgroupAll(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "subgroupAny(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupAdd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupMul(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "subgroupMax(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "subgroupMin(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "subgroupAnd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "subgroupOr(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "subgroupXor(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupExclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupExclusiveMul(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupInclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupInclusiveMul(")?
+ }
+ _ => unimplemented!(),
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ match mode {
+ crate::GatherMode::BroadcastFirst => {
+ write!(self.out, "subgroupBroadcastFirst(")?;
+ }
+ crate::GatherMode::Broadcast(_) => {
+ write!(self.out, "subgroupBroadcast(")?;
+ }
+ crate::GatherMode::Shuffle(_) => {
+ write!(self.out, "subgroupShuffle(")?;
+ }
+ crate::GatherMode::ShuffleDown(_) => {
+ write!(self.out, "subgroupShuffleDown(")?;
+ }
+ crate::GatherMode::ShuffleUp(_) => {
+ write!(self.out, "subgroupShuffleUp(")?;
+ }
+ crate::GatherMode::ShuffleXor(_) => {
+ write!(self.out, "subgroupShuffleXor(")?;
+ }
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
Ok(())
@@ -1070,7 +1192,7 @@ impl<W: Write> Writer<W> {
self.write_possibly_const_expression(
module,
expr,
- &module.const_expressions,
+ &module.global_expressions,
|writer, expr| writer.write_const_expression(module, expr),
)
}
@@ -1199,6 +1321,7 @@ impl<W: Write> Writer<W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
+ Expression::Override(_) => unreachable!(),
Expression::FunctionArgument(pos) => {
let name_key = func_ctx.argument_key(pos);
let name = &self.names[&name_key];
@@ -1691,6 +1814,8 @@ impl<W: Write> Writer<W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
+ | Expression::SubgroupBallotResult
+ | Expression::SubgroupOperationResult { .. }
| Expression::WorkGroupUniformLoadResult { .. } => {}
}
@@ -1792,6 +1917,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
Bi::SampleMask => "sample_mask",
Bi::PrimitiveIndex => "primitive_index",
Bi::ViewIndex => "view_index",
+ Bi::NumSubgroups => "num_subgroups",
+ Bi::SubgroupId => "subgroup_id",
+ Bi::SubgroupSize => "subgroup_size",
+ Bi::SubgroupInvocationId => "subgroup_invocation_id",
Bi::BaseInstance
| Bi::BaseVertex
| Bi::ClipDistance