summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src')
-rw-r--r--third_party/rust/naga/src/arena.rs17
-rw-r--r--third_party/rust/naga/src/back/dot/mod.rs91
-rw-r--r--third_party/rust/naga/src/back/glsl/features.rs23
-rw-r--r--third_party/rust/naga/src/back/glsl/mod.rs148
-rw-r--r--third_party/rust/naga/src/back/hlsl/conv.rs5
-rw-r--r--third_party/rust/naga/src/back/hlsl/help.rs94
-rw-r--r--third_party/rust/naga/src/back/hlsl/mod.rs17
-rw-r--r--third_party/rust/naga/src/back/hlsl/writer.rs315
-rw-r--r--third_party/rust/naga/src/back/mod.rs17
-rw-r--r--third_party/rust/naga/src/back/msl/mod.rs27
-rw-r--r--third_party/rust/naga/src/back/msl/writer.rs192
-rw-r--r--third_party/rust/naga/src/back/pipeline_constants.rs957
-rw-r--r--third_party/rust/naga/src/back/spv/block.rs32
-rw-r--r--third_party/rust/naga/src/back/spv/helpers.rs53
-rw-r--r--third_party/rust/naga/src/back/spv/instructions.rs103
-rw-r--r--third_party/rust/naga/src/back/spv/mod.rs47
-rw-r--r--third_party/rust/naga/src/back/spv/subgroup.rs207
-rw-r--r--third_party/rust/naga/src/back/spv/writer.rs68
-rw-r--r--third_party/rust/naga/src/back/wgsl/writer.rs131
-rw-r--r--third_party/rust/naga/src/block.rs6
-rw-r--r--third_party/rust/naga/src/compact/expressions.rs27
-rw-r--r--third_party/rust/naga/src/compact/functions.rs6
-rw-r--r--third_party/rust/naga/src/compact/mod.rs49
-rw-r--r--third_party/rust/naga/src/compact/statements.rs67
-rw-r--r--third_party/rust/naga/src/error.rs74
-rw-r--r--third_party/rust/naga/src/front/glsl/context.rs48
-rw-r--r--third_party/rust/naga/src/front/glsl/error.rs18
-rw-r--r--third_party/rust/naga/src/front/glsl/functions.rs10
-rw-r--r--third_party/rust/naga/src/front/glsl/mod.rs4
-rw-r--r--third_party/rust/naga/src/front/glsl/parser.rs17
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/declarations.rs9
-rw-r--r--third_party/rust/naga/src/front/glsl/parser/functions.rs9
-rw-r--r--third_party/rust/naga/src/front/glsl/parser_tests.rs22
-rw-r--r--third_party/rust/naga/src/front/glsl/types.rs17
-rw-r--r--third_party/rust/naga/src/front/glsl/variables.rs1
-rw-r--r--third_party/rust/naga/src/front/spv/convert.rs5
-rw-r--r--third_party/rust/naga/src/front/spv/error.rs10
-rw-r--r--third_party/rust/naga/src/front/spv/function.rs13
-rw-r--r--third_party/rust/naga/src/front/spv/image.rs13
-rw-r--r--third_party/rust/naga/src/front/spv/mod.rs455
-rw-r--r--third_party/rust/naga/src/front/spv/null.rs8
-rw-r--r--third_party/rust/naga/src/front/wgsl/error.rs29
-rw-r--r--third_party/rust/naga/src/front/wgsl/index.rs1
-rw-r--r--third_party/rust/naga/src/front/wgsl/lower/mod.rs414
-rw-r--r--third_party/rust/naga/src/front/wgsl/mod.rs11
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/ast.rs9
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/conv.rs28
-rw-r--r--third_party/rust/naga/src/front/wgsl/parse/mod.rs106
-rw-r--r--third_party/rust/naga/src/front/wgsl/to_wgsl.rs3
-rw-r--r--third_party/rust/naga/src/lib.rs174
-rw-r--r--third_party/rust/naga/src/proc/constant_evaluator.rs589
-rw-r--r--third_party/rust/naga/src/proc/index.rs4
-rw-r--r--third_party/rust/naga/src/proc/mod.rs105
-rw-r--r--third_party/rust/naga/src/proc/terminator.rs3
-rw-r--r--third_party/rust/naga/src/proc/typifier.rs8
-rw-r--r--third_party/rust/naga/src/span.rs12
-rw-r--r--third_party/rust/naga/src/valid/analyzer.rs53
-rw-r--r--third_party/rust/naga/src/valid/expression.rs39
-rw-r--r--third_party/rust/naga/src/valid/function.rs268
-rw-r--r--third_party/rust/naga/src/valid/handles.rs91
-rw-r--r--third_party/rust/naga/src/valid/interface.rs48
-rw-r--r--third_party/rust/naga/src/valid/mod.rs228
-rw-r--r--third_party/rust/naga/src/valid/type.rs3
63 files changed, 4815 insertions, 843 deletions
diff --git a/third_party/rust/naga/src/arena.rs b/third_party/rust/naga/src/arena.rs
index c37538667f..740df85b86 100644
--- a/third_party/rust/naga/src/arena.rs
+++ b/third_party/rust/naga/src/arena.rs
@@ -122,6 +122,7 @@ impl<T> Handle<T> {
serde(transparent)
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(test, derive(PartialEq))]
pub struct Range<T> {
inner: ops::Range<u32>,
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
@@ -140,6 +141,7 @@ impl<T> Range<T> {
// NOTE: Keep this diagnostic in sync with that of [`BadHandle`].
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")]
pub struct BadRangeError {
// This error is used for many `Handle` types, but there's no point in making this generic, so
@@ -239,7 +241,7 @@ impl<T> Range<T> {
/// Adding new items to the arena produces a strongly-typed [`Handle`].
/// The arena can be indexed using the given handle to obtain
/// a reference to the stored item.
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "serialize", serde(transparent))]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
@@ -297,6 +299,17 @@ impl<T> Arena<T> {
.map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
}
+ /// Drains the arena, returning an iterator over the items stored.
+ pub fn drain(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, T, Span)> {
+ let arena = std::mem::take(self);
+ arena
+ .data
+ .into_iter()
+ .zip(arena.span_info)
+ .enumerate()
+ .map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) })
+ }
+
/// Returns a iterator over the items stored in this arena,
/// returning both the item's handle and a mutable reference to it.
pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {
@@ -531,7 +544,7 @@ mod tests {
///
/// `UniqueArena` is similar to [`Arena`]: If `Arena` is vector-like,
/// `UniqueArena` is `HashSet`-like.
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Clone)]
pub struct UniqueArena<T> {
set: FastIndexSet<T>,
diff --git a/third_party/rust/naga/src/back/dot/mod.rs b/third_party/rust/naga/src/back/dot/mod.rs
index 1556371df1..9a7702b3f6 100644
--- a/third_party/rust/naga/src/back/dot/mod.rs
+++ b/third_party/rust/naga/src/back/dot/mod.rs
@@ -279,6 +279,94 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
+ S::SubgroupBallot { result, predicate } => {
+ if let Some(predicate) = predicate {
+ self.dependencies.push((id, predicate, "predicate"));
+ }
+ self.emits.push((id, result));
+ "SubgroupBallot"
+ }
+ S::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ self.dependencies.push((id, argument, "arg"));
+ self.emits.push((id, result));
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ "SubgroupAll"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ "SubgroupAny"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ "SubgroupAdd"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ "SubgroupMul"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ "SubgroupMax"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ "SubgroupMin"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ "SubgroupAnd"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ "SubgroupOr"
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ "SubgroupXor"
+ }
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => "SubgroupExclusiveAdd",
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => "SubgroupExclusiveMul",
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => "SubgroupInclusiveAdd",
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => "SubgroupInclusiveMul",
+ _ => unimplemented!(),
+ }
+ }
+ S::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ self.dependencies.push((id, index, "index"))
+ }
+ }
+ self.dependencies.push((id, argument, "arg"));
+ self.emits.push((id, result));
+ match mode {
+ crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
+ crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
+ crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
+ crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
+ crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
+ crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
+ }
+ }
};
// Set the last node to the merge node
last_node = merge_id;
@@ -404,6 +492,7 @@ fn write_function_expressions(
let (label, color_id) = match *expression {
E::Literal(_) => ("Literal".into(), 2),
E::Constant(_) => ("Constant".into(), 2),
+ E::Override(_) => ("Override".into(), 2),
E::ZeroValue(_) => ("ZeroValue".into(), 2),
E::Compose { ref components, .. } => {
payload = Some(Payload::Arguments(components));
@@ -586,6 +675,8 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
+ E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
+ E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
};
// give uniform expressions an outline
diff --git a/third_party/rust/naga/src/back/glsl/features.rs b/third_party/rust/naga/src/back/glsl/features.rs
index 99c128c6d9..e5a43f3e02 100644
--- a/third_party/rust/naga/src/back/glsl/features.rs
+++ b/third_party/rust/naga/src/back/glsl/features.rs
@@ -50,6 +50,8 @@ bitflags::bitflags! {
const INSTANCE_INDEX = 1 << 22;
/// Sample specific LODs of cube / array shadow textures
const TEXTURE_SHADOW_LOD = 1 << 23;
+ /// Subgroup operations
+ const SUBGROUP_OPERATIONS = 1 << 24;
}
}
@@ -117,6 +119,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
+ check_feature!(SUBGROUP_OPERATIONS, 430, 310);
match version {
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
_ => check_feature!(MULTI_VIEW, 140, 310),
@@ -259,6 +262,22 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?;
}
+ if self.0.contains(Features::SUBGROUP_OPERATIONS) {
+ // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
+ writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
+ writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
+ writeln!(
+ out,
+ "#extension GL_KHR_shader_subgroup_arithmetic : require"
+ )?;
+ writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
+ writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
+ writeln!(
+ out,
+ "#extension GL_KHR_shader_subgroup_shuffle_relative : require"
+ )?;
+ }
+
Ok(())
}
}
@@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> {
}
}
}
+ Expression::SubgroupBallotResult |
+ Expression::SubgroupOperationResult { .. } => {
+ features.request(Features::SUBGROUP_OPERATIONS)
+ }
_ => {}
}
}
diff --git a/third_party/rust/naga/src/back/glsl/mod.rs b/third_party/rust/naga/src/back/glsl/mod.rs
index 9bda594610..c8c7ea557d 100644
--- a/third_party/rust/naga/src/back/glsl/mod.rs
+++ b/third_party/rust/naga/src/back/glsl/mod.rs
@@ -282,7 +282,7 @@ impl Default for Options {
}
/// A subset of options meant to be changed per pipeline.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
@@ -497,6 +497,8 @@ pub enum Error {
ImageMultipleSamplers,
#[error("{0}")]
Custom(String),
+ #[error("overrides should not be present at this stage")]
+ Override,
}
/// Binary operation with a different logic on the GLSL side.
@@ -565,6 +567,10 @@ impl<'a, W: Write> Writer<'a, W> {
pipeline_options: &'a PipelineOptions,
policies: proc::BoundsCheckPolicies,
) -> Result<Self, Error> {
+ if !module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
// Check if the requested version is supported
if !options.version.is_supported() {
log::error!("Version {}", options.version);
@@ -2384,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
+ Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ write!(self.out, "subgroupBallot(")?;
+ match predicate {
+ Some(predicate) => self.write_expr(predicate, ctx)?,
+ None => write!(self.out, "true")?,
+ }
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "subgroupAll(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "subgroupAny(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupAdd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupMul(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "subgroupMax(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "subgroupMin(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "subgroupAnd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "subgroupOr(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "subgroupXor(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupExclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupExclusiveMul(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupInclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupInclusiveMul(")?
+ }
+ _ => unimplemented!(),
+ }
+ self.write_expr(argument, ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
+ self.write_value_type(res_ty)?;
+ write!(self.out, " {res_name} = ")?;
+ self.named_expressions.insert(result, res_name);
+
+ match mode {
+ crate::GatherMode::BroadcastFirst => {
+ write!(self.out, "subgroupBroadcastFirst(")?;
+ }
+ crate::GatherMode::Broadcast(_) => {
+ write!(self.out, "subgroupBroadcast(")?;
+ }
+ crate::GatherMode::Shuffle(_) => {
+ write!(self.out, "subgroupShuffle(")?;
+ }
+ crate::GatherMode::ShuffleDown(_) => {
+ write!(self.out, "subgroupShuffleDown(")?;
+ }
+ crate::GatherMode::ShuffleUp(_) => {
+ write!(self.out, "subgroupShuffleUp(")?;
+ }
+ crate::GatherMode::ShuffleXor(_) => {
+ write!(self.out, "subgroupShuffleXor(")?;
+ }
+ }
+ self.write_expr(argument, ctx)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, ", ")?;
+ self.write_expr(index, ctx)?;
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
Ok(())
@@ -2402,7 +2527,7 @@ impl<'a, W: Write> Writer<'a, W> {
fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult {
self.write_possibly_const_expr(
expr,
- &self.module.const_expressions,
+ &self.module.global_expressions,
|expr| &self.info[expr],
|writer, expr| writer.write_const_expr(expr),
)
@@ -2536,6 +2661,7 @@ impl<'a, W: Write> Writer<'a, W> {
|writer, expr| writer.write_expr(expr, ctx),
)?;
}
+ Expression::Override(_) => return Err(Error::Override),
// `Access` is applied to arrays, vectors and matrices and is written as indexing
Expression::Access { base, index } => {
self.write_expr(base, ctx)?;
@@ -3411,7 +3537,8 @@ impl<'a, W: Write> Writer<'a, W> {
let scalar_bits = ctx
.resolve_type(arg, &self.module.types)
.scalar_width()
- .unwrap();
+ .unwrap()
+ * 8;
write!(self.out, "bitfieldExtract(")?;
self.write_expr(arg, ctx)?;
@@ -3430,7 +3557,8 @@ impl<'a, W: Write> Writer<'a, W> {
let scalar_bits = ctx
.resolve_type(arg, &self.module.types)
.scalar_width()
- .unwrap();
+ .unwrap()
+ * 8;
write!(self.out, "bitfieldInsert(")?;
self.write_expr(arg, ctx)?;
@@ -3649,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
- | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
+ | Expression::WorkGroupUniformLoadResult { .. }
+ | Expression::SubgroupOperationResult { .. }
+ | Expression::SubgroupBallotResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
@@ -4218,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> {
if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}memoryBarrierShared();")?;
}
+ if flags.contains(crate::Barrier::SUB_GROUP) {
+ writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
+ }
writeln!(self.out, "{level}barrier();")?;
Ok(())
}
@@ -4487,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
+ // subgroup
+ Bi::NumSubgroups => "gl_NumSubgroups",
+ Bi::SubgroupId => "gl_SubgroupID",
+ Bi::SubgroupSize => "gl_SubgroupSize",
+ Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
}
}
diff --git a/third_party/rust/naga/src/back/hlsl/conv.rs b/third_party/rust/naga/src/back/hlsl/conv.rs
index 2a6db35db8..7d15f43f6c 100644
--- a/third_party/rust/naga/src/back/hlsl/conv.rs
+++ b/third_party/rust/naga/src/back/hlsl/conv.rs
@@ -179,6 +179,11 @@ impl crate::BuiltIn {
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
+ // These builtins map to functions
+ Self::SubgroupSize
+ | Self::SubgroupInvocationId
+ | Self::NumSubgroups
+ | Self::SubgroupId => unreachable!(),
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {self:?}")))
}
diff --git a/third_party/rust/naga/src/back/hlsl/help.rs b/third_party/rust/naga/src/back/hlsl/help.rs
index 4dd9ea5987..d3bb1ce7f5 100644
--- a/third_party/rust/naga/src/back/hlsl/help.rs
+++ b/third_party/rust/naga/src/back/hlsl/help.rs
@@ -70,6 +70,11 @@ pub(super) struct WrappedMath {
pub(super) components: Option<u32>,
}
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+pub(super) struct WrappedZeroValue {
+ pub(super) ty: Handle<crate::Type>,
+}
+
/// HLSL backend requires its own `ImageQuery` enum.
///
/// It is used inside `WrappedImageQuery` and should be unique per ImageQuery function.
@@ -359,7 +364,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
}
/// Helper function that write wrapped function for `Expression::Compose` for structures.
- pub(super) fn write_wrapped_constructor_function(
+ fn write_wrapped_constructor_function(
&mut self,
module: &crate::Module,
constructor: WrappedConstructor,
@@ -862,6 +867,25 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}
+ // TODO: we could merge this with iteration in write_wrapped_compose_functions...
+ //
+ /// Helper function that writes zero value wrapped functions
+ pub(super) fn write_wrapped_zero_value_functions(
+ &mut self,
+ module: &crate::Module,
+ expressions: &crate::Arena<crate::Expression>,
+ ) -> BackendResult {
+ for (handle, _) in expressions.iter() {
+ if let crate::Expression::ZeroValue(ty) = expressions[handle] {
+ let zero_value = WrappedZeroValue { ty };
+ if self.wrapped.zero_values.insert(zero_value) {
+ self.write_wrapped_zero_value_function(module, zero_value)?;
+ }
+ }
+ }
+ Ok(())
+ }
+
pub(super) fn write_wrapped_math_functions(
&mut self,
module: &crate::Module,
@@ -1006,6 +1030,7 @@ impl<'a, W: Write> super::Writer<'a, W> {
) -> BackendResult {
self.write_wrapped_math_functions(module, func_ctx)?;
self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
+ self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;
for (handle, _) in func_ctx.expressions.iter() {
match func_ctx.expressions[handle] {
@@ -1283,4 +1308,71 @@ impl<'a, W: Write> super::Writer<'a, W> {
Ok(())
}
+
+ pub(super) fn write_wrapped_zero_value_function_name(
+ &mut self,
+ module: &crate::Module,
+ zero_value: WrappedZeroValue,
+ ) -> BackendResult {
+ let name = crate::TypeInner::hlsl_type_id(zero_value.ty, module.to_ctx(), &self.names)?;
+ write!(self.out, "ZeroValue{name}")?;
+ Ok(())
+ }
+
+ /// Helper function that write wrapped function for `Expression::ZeroValue`
+ ///
+ /// This is necessary since we might have a member access after the zero value expression, e.g.
+ /// `.y` (in practice this can come up when consuming SPIRV that's been produced by glslc).
+ ///
+ /// So we can't just write `(float4)0` since `(float4)0.y` won't parse correctly.
+ ///
+ /// Parenthesizing the expression like `((float4)0).y` would work... except DXC can't handle
+ /// cases like:
+ ///
+ /// ```ignore
+ /// tests\out\hlsl\access.hlsl:183:41: error: cannot compile this l-value expression yet
+ /// t_1.am = (__mat4x2[2])((float4x2[2])0);
+ /// ^
+ /// ```
+ fn write_wrapped_zero_value_function(
+ &mut self,
+ module: &crate::Module,
+ zero_value: WrappedZeroValue,
+ ) -> BackendResult {
+ use crate::back::INDENT;
+
+ const RETURN_VARIABLE_NAME: &str = "ret";
+
+ // Write function return type and name
+ if let crate::TypeInner::Array { base, size, .. } = module.types[zero_value.ty].inner {
+ write!(self.out, "typedef ")?;
+ self.write_type(module, zero_value.ty)?;
+ write!(self.out, " ret_")?;
+ self.write_wrapped_zero_value_function_name(module, zero_value)?;
+ self.write_array_size(module, base, size)?;
+ writeln!(self.out, ";")?;
+
+ write!(self.out, "ret_")?;
+ self.write_wrapped_zero_value_function_name(module, zero_value)?;
+ } else {
+ self.write_type(module, zero_value.ty)?;
+ }
+ write!(self.out, " ")?;
+ self.write_wrapped_zero_value_function_name(module, zero_value)?;
+
+ // Write function parameters (none) and start function body
+ writeln!(self.out, "() {{")?;
+
+ // Write `ZeroValue` function.
+ write!(self.out, "{INDENT}return ")?;
+ self.write_default_init(module, zero_value.ty)?;
+ writeln!(self.out, ";")?;
+
+ // End of function body
+ writeln!(self.out, "}}")?;
+ // Write extra new line
+ writeln!(self.out)?;
+
+ Ok(())
+ }
}
diff --git a/third_party/rust/naga/src/back/hlsl/mod.rs b/third_party/rust/naga/src/back/hlsl/mod.rs
index f37a223f47..28edbf70e1 100644
--- a/third_party/rust/naga/src/back/hlsl/mod.rs
+++ b/third_party/rust/naga/src/back/hlsl/mod.rs
@@ -131,6 +131,13 @@ pub enum ShaderModel {
V5_0,
V5_1,
V6_0,
+ V6_1,
+ V6_2,
+ V6_3,
+ V6_4,
+ V6_5,
+ V6_6,
+ V6_7,
}
impl ShaderModel {
@@ -139,6 +146,13 @@ impl ShaderModel {
Self::V5_0 => "5_0",
Self::V5_1 => "5_1",
Self::V6_0 => "6_0",
+ Self::V6_1 => "6_1",
+ Self::V6_2 => "6_2",
+ Self::V6_3 => "6_3",
+ Self::V6_4 => "6_4",
+ Self::V6_5 => "6_5",
+ Self::V6_6 => "6_6",
+ Self::V6_7 => "6_7",
}
}
}
@@ -247,10 +261,13 @@ pub enum Error {
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Custom(String),
+ #[error("overrides should not be present at this stage")]
+ Override,
}
#[derive(Default)]
struct Wrapped {
+ zero_values: crate::FastHashSet<help::WrappedZeroValue>,
array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
image_queries: crate::FastHashSet<help::WrappedImageQuery>,
constructors: crate::FastHashSet<help::WrappedConstructor>,
diff --git a/third_party/rust/naga/src/back/hlsl/writer.rs b/third_party/rust/naga/src/back/hlsl/writer.rs
index 4ba856946b..86d8f89035 100644
--- a/third_party/rust/naga/src/back/hlsl/writer.rs
+++ b/third_party/rust/naga/src/back/hlsl/writer.rs
@@ -1,5 +1,8 @@
use super::{
- help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
+ help::{
+ WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
+ WrappedZeroValue,
+ },
storage::StoreValue,
BackendResult, Error, Options,
};
@@ -77,6 +80,19 @@ enum Io {
Output,
}
+const fn is_subgroup_builtin_binding(binding: &Option<crate::Binding>) -> bool {
+ let &Some(crate::Binding::BuiltIn(builtin)) = binding else {
+ return false;
+ };
+ matches!(
+ builtin,
+ crate::BuiltIn::SubgroupSize
+ | crate::BuiltIn::SubgroupInvocationId
+ | crate::BuiltIn::NumSubgroups
+ | crate::BuiltIn::SubgroupId
+ )
+}
+
impl<'a, W: fmt::Write> super::Writer<'a, W> {
pub fn new(out: W, options: &'a Options) -> Self {
Self {
@@ -161,6 +177,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
}
+ for statement in func.body.iter() {
+ match *statement {
+ crate::Statement::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: crate::CollectiveOperation::InclusiveScan,
+ argument,
+ result: _,
+ } => {
+ self.need_bake_expressions.insert(argument);
+ }
+ _ => {}
+ }
+ }
}
pub fn write(
@@ -168,6 +197,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
module: &Module,
module_info: &valid::ModuleInfo,
) -> Result<super::ReflectionInfo, Error> {
+ if !module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
self.reset(module);
// Write special constants, if needed
@@ -233,7 +266,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_special_functions(module)?;
- self.write_wrapped_compose_functions(module, &module.const_expressions)?;
+ self.write_wrapped_compose_functions(module, &module.global_expressions)?;
+ self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;
// Write all named constants
let mut constants = module
@@ -397,31 +431,32 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// if they are struct, so that the `stage` argument here could be omitted.
fn write_semantic(
&mut self,
- binding: &crate::Binding,
+ binding: &Option<crate::Binding>,
stage: Option<(ShaderStage, Io)>,
) -> BackendResult {
match *binding {
- crate::Binding::BuiltIn(builtin) => {
+ Some(crate::Binding::BuiltIn(builtin)) if !is_subgroup_builtin_binding(binding) => {
let builtin_str = builtin.to_hlsl_str()?;
write!(self.out, " : {builtin_str}")?;
}
- crate::Binding::Location {
+ Some(crate::Binding::Location {
second_blend_source: true,
..
- } => {
+ }) => {
write!(self.out, " : SV_Target1")?;
}
- crate::Binding::Location {
+ Some(crate::Binding::Location {
location,
second_blend_source: false,
..
- } => {
+ }) => {
if stage == Some((crate::ShaderStage::Fragment, Io::Output)) {
write!(self.out, " : SV_Target{location}")?;
} else {
write!(self.out, " : {LOCATION_SEMANTIC}{location}")?;
}
}
+ _ => {}
}
Ok(())
@@ -442,17 +477,30 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, "struct {struct_name}")?;
writeln!(self.out, " {{")?;
for m in members.iter() {
+ if is_subgroup_builtin_binding(&m.binding) {
+ continue;
+ }
write!(self.out, "{}", back::INDENT)?;
if let Some(ref binding) = m.binding {
self.write_modifier(binding)?;
}
self.write_type(module, m.ty)?;
write!(self.out, " {}", &m.name)?;
- if let Some(ref binding) = m.binding {
- self.write_semantic(binding, Some(shader_stage))?;
- }
+ self.write_semantic(&m.binding, Some(shader_stage))?;
writeln!(self.out, ";")?;
}
+ if members.iter().any(|arg| {
+ matches!(
+ arg.binding,
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId))
+ )
+ }) {
+ writeln!(
+ self.out,
+ "{}uint __local_invocation_index : SV_GroupIndex;",
+ back::INDENT
+ )?;
+ }
writeln!(self.out, "}};")?;
writeln!(self.out)?;
@@ -553,8 +601,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
/// Writes special interface structures for an entry point. The special structures have
- /// all the fields flattened into them and sorted by binding. They are only needed for
- /// VS outputs and FS inputs, so that these interfaces match.
+ /// all the fields flattened into them and sorted by binding. They are needed to emulate
+ /// subgroup built-ins and to make the interfaces between VS outputs and FS inputs match.
fn write_ep_interface(
&mut self,
module: &Module,
@@ -563,7 +611,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
ep_name: &str,
) -> Result<EntryPointInterface, Error> {
Ok(EntryPointInterface {
- input: if !func.arguments.is_empty() && stage == ShaderStage::Fragment {
+ input: if !func.arguments.is_empty()
+ && (stage == ShaderStage::Fragment
+ || func
+ .arguments
+ .iter()
+ .any(|arg| is_subgroup_builtin_binding(&arg.binding)))
+ {
Some(self.write_ep_input_struct(module, func, stage, ep_name)?)
} else {
None
@@ -577,6 +631,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
})
}
+ fn write_ep_argument_initialization(
+ &mut self,
+ ep: &crate::EntryPoint,
+ ep_input: &EntryPointBinding,
+ fake_member: &EpStructMember,
+ ) -> BackendResult {
+ match fake_member.binding {
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) => {
+ write!(self.out, "WaveGetLaneCount()")?
+ }
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupInvocationId)) => {
+ write!(self.out, "WaveGetLaneIndex()")?
+ }
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) => write!(
+ self.out,
+ "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount()",
+ ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2]
+ )?,
+ Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) => {
+ write!(
+ self.out,
+ "{}.__local_invocation_index / WaveGetLaneCount()",
+ ep_input.arg_name
+ )?;
+ }
+ _ => {
+ write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ }
+ }
+ Ok(())
+ }
+
/// Write an entry point preface that initializes the arguments as specified in IR.
fn write_ep_arguments_initialization(
&mut self,
@@ -584,6 +670,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
func: &crate::Function,
ep_index: u16,
) -> BackendResult {
+ let ep = &module.entry_points[ep_index as usize];
let ep_input = match self.entry_point_io[ep_index as usize].input.take() {
Some(ep_input) => ep_input,
None => return Ok(()),
@@ -597,8 +684,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
match module.types[arg.ty].inner {
TypeInner::Array { base, size, .. } => {
self.write_array_size(module, base, size)?;
- let fake_member = fake_iter.next().unwrap();
- writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ write!(self.out, " = ")?;
+ self.write_ep_argument_initialization(
+ ep,
+ &ep_input,
+ fake_iter.next().unwrap(),
+ )?;
+ writeln!(self.out, ";")?;
}
TypeInner::Struct { ref members, .. } => {
write!(self.out, " = {{ ")?;
@@ -606,14 +698,22 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if index != 0 {
write!(self.out, ", ")?;
}
- let fake_member = fake_iter.next().unwrap();
- write!(self.out, "{}.{}", ep_input.arg_name, fake_member.name)?;
+ self.write_ep_argument_initialization(
+ ep,
+ &ep_input,
+ fake_iter.next().unwrap(),
+ )?;
}
writeln!(self.out, " }};")?;
}
_ => {
- let fake_member = fake_iter.next().unwrap();
- writeln!(self.out, " = {}.{};", ep_input.arg_name, fake_member.name)?;
+ write!(self.out, " = ")?;
+ self.write_ep_argument_initialization(
+ ep,
+ &ep_input,
+ fake_iter.next().unwrap(),
+ )?;
+ writeln!(self.out, ";")?;
}
}
}
@@ -928,9 +1028,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}
- if let Some(ref binding) = member.binding {
- self.write_semantic(binding, shader_stage)?;
- };
+ self.write_semantic(&member.binding, shader_stage)?;
writeln!(self.out, ";")?;
}
@@ -1143,7 +1241,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
back::FunctionType::EntryPoint(ep_index) => {
if let Some(ref ep_input) = self.entry_point_io[ep_index as usize].input {
- write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?;
+ write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name)?;
} else {
let stage = module.entry_points[ep_index as usize].stage;
for (index, arg) in func.arguments.iter().enumerate() {
@@ -1160,17 +1258,16 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_array_size(module, base, size)?;
}
- if let Some(ref binding) = arg.binding {
- self.write_semantic(binding, Some((stage, Io::Input)))?;
- }
+ self.write_semantic(&arg.binding, Some((stage, Io::Input)))?;
}
-
- if need_workgroup_variables_initialization {
- if !func.arguments.is_empty() {
- write!(self.out, ", ")?;
- }
- write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
+ }
+ if need_workgroup_variables_initialization {
+ if self.entry_point_io[ep_index as usize].input.is_some()
+ || !func.arguments.is_empty()
+ {
+ write!(self.out, ", ")?;
}
+ write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
}
}
}
@@ -1180,11 +1277,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Write semantic if it present
if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
let stage = module.entry_points[index as usize].stage;
- if let Some(crate::FunctionResult {
- binding: Some(ref binding),
- ..
- }) = func.result
- {
+ if let Some(crate::FunctionResult { ref binding, .. }) = func.result {
self.write_semantic(binding, Some((stage, Io::Output)))?;
}
}
@@ -1984,6 +2077,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "{level}}}")?
}
Statement::RayQuery { .. } => unreachable!(),
+ Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ write!(self.out, "const uint4 {name} = ")?;
+ self.named_expressions.insert(result, name);
+
+ write!(self.out, "WaveActiveBallot(")?;
+ match predicate {
+ Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
+ None => write!(self.out, "true")?,
+ }
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(result, name);
+
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "WaveActiveAllTrue(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "WaveActiveAnyTrue(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "WaveActiveSum(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "WaveActiveProduct(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "WaveActiveMax(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "WaveActiveMin(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "WaveActiveBitAnd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "WaveActiveBitOr(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "WaveActiveBitXor(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "WavePrefixSum(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "WavePrefixProduct(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, " + WavePrefixSum(")?;
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, " * WavePrefixProduct(")?;
+ }
+ _ => unimplemented!(),
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ write!(self.out, "const ")?;
+ let name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ match func_ctx.info[result].ty {
+ proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?,
+ proc::TypeResolution::Value(ref value) => {
+ self.write_value_type(module, value)?
+ }
+ };
+ write!(self.out, " {name} = ")?;
+ self.named_expressions.insert(result, name);
+
+ if matches!(mode, crate::GatherMode::BroadcastFirst) {
+ write!(self.out, "WaveReadLaneFirst(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ } else {
+ write!(self.out, "WaveReadLaneAt(")?;
+ self.write_expr(module, argument, func_ctx)?;
+ write!(self.out, ", ")?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => unreachable!(),
+ crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
+ self.write_expr(module, index, func_ctx)?;
+ }
+ crate::GatherMode::ShuffleDown(index) => {
+ write!(self.out, "WaveGetLaneIndex() + ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ crate::GatherMode::ShuffleUp(index) => {
+ write!(self.out, "WaveGetLaneIndex() - ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, "WaveGetLaneIndex() ^ ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
Ok(())
@@ -1997,7 +2213,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_possibly_const_expression(
module,
expr,
- &module.const_expressions,
+ &module.global_expressions,
|writer, expr| writer.write_const_expression(module, expr),
)
}
@@ -2039,7 +2255,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_const_expression(module, constant.init)?;
}
}
- Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
+ Expression::ZeroValue(ty) => {
+ self.write_wrapped_zero_value_function_name(module, WrappedZeroValue { ty })?;
+ write!(self.out, "()")?;
+ }
Expression::Compose { ty, ref components } => {
match module.types[ty].inner {
TypeInner::Struct { .. } | TypeInner::Array { .. } => {
@@ -2140,6 +2359,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
+ Expression::Override(_) => return Err(Error::Override),
// All of the multiplication can be expressed as `mul`,
// except vector * vector, which needs to use the "*" operator.
Expression::Binary {
@@ -2588,7 +2808,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
true
}
None => {
- if inner.scalar_width() == Some(64) {
+ if inner.scalar_width() == Some(8) {
false
} else {
write!(self.out, "{}(", kind.to_hlsl_cast(),)?;
@@ -3129,7 +3349,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::WorkGroupUniformLoadResult { .. }
- | Expression::RayQueryProceedResult => {}
+ | Expression::RayQueryProceedResult
+ | Expression::SubgroupBallotResult
+ | Expression::SubgroupOperationResult { .. } => {}
}
if !closing_bracket.is_empty() {
@@ -3179,7 +3401,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
/// Helper function that write default zero initialization
- fn write_default_init(&mut self, module: &Module, ty: Handle<crate::Type>) -> BackendResult {
+ pub(super) fn write_default_init(
+ &mut self,
+ module: &Module,
+ ty: Handle<crate::Type>,
+ ) -> BackendResult {
write!(self.out, "(")?;
self.write_type(module, ty)?;
if let TypeInner::Array { base, size, .. } = module.types[ty].inner {
@@ -3196,6 +3422,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?;
}
+ if barrier.contains(crate::Barrier::SUB_GROUP) {
+ // Does not exist in DirectX
+ }
Ok(())
}
}
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs
index c8f091decb..0c9c5e4761 100644
--- a/third_party/rust/naga/src/back/mod.rs
+++ b/third_party/rust/naga/src/back/mod.rs
@@ -16,6 +16,14 @@ pub mod spv;
#[cfg(feature = "wgsl-out")]
pub mod wgsl;
+#[cfg(any(
+ feature = "hlsl-out",
+ feature = "msl-out",
+ feature = "spv-out",
+ feature = "glsl-out"
+))]
+pub mod pipeline_constants;
+
/// Names of vector components.
pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
/// Indent for backends.
@@ -26,6 +34,15 @@ pub const BAKE_PREFIX: &str = "_e";
/// Expressions that need baking.
pub type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>;
+/// Specifies the values of pipeline-overridable constants in the shader module.
+///
+/// If an `@id` attribute was specified on the declaration,
+/// the key must be the pipeline constant ID as a decimal ASCII number; if not,
+/// the key must be the constant's identifier name.
+///
+/// The value may represent any of WGSL's concrete scalar types.
+pub type PipelineConstants = std::collections::HashMap<String, f64>;
+
/// Indentation level.
#[derive(Clone, Copy)]
pub struct Level(pub usize);
diff --git a/third_party/rust/naga/src/back/msl/mod.rs b/third_party/rust/naga/src/back/msl/mod.rs
index 68e5b79906..8b03e20376 100644
--- a/third_party/rust/naga/src/back/msl/mod.rs
+++ b/third_party/rust/naga/src/back/msl/mod.rs
@@ -143,6 +143,8 @@ pub enum Error {
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
+ #[error("overrides should not be present at this stage")]
+ Override,
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
@@ -221,7 +223,7 @@ impl Default for Options {
}
/// A subset of options that are meant to be changed per pipeline.
-#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
@@ -434,6 +436,11 @@ impl ResolvedBinding {
Bi::WorkGroupId => "threadgroup_position_in_grid",
Bi::WorkGroupSize => "dispatch_threads_per_threadgroup",
Bi::NumWorkGroups => "threadgroups_per_grid",
+ // subgroup
+ Bi::NumSubgroups => "simdgroups_per_threadgroup",
+ Bi::SubgroupId => "simdgroup_index_in_threadgroup",
+ Bi::SubgroupSize => "threads_per_simdgroup",
+ Bi::SubgroupInvocationId => "thread_index_in_simdgroup",
Bi::CullDistance | Bi::ViewIndex => {
return Err(Error::UnsupportedBuiltIn(built_in))
}
@@ -536,3 +543,21 @@ fn test_error_size() {
use std::mem::size_of;
assert_eq!(size_of::<Error>(), 32);
}
+
+impl crate::AtomicFunction {
+ fn to_msl(self) -> Result<&'static str, Error> {
+ Ok(match self {
+ Self::Add => "fetch_add",
+ Self::Subtract => "fetch_sub",
+ Self::And => "fetch_and",
+ Self::InclusiveOr => "fetch_or",
+ Self::ExclusiveOr => "fetch_xor",
+ Self::Min => "fetch_min",
+ Self::Max => "fetch_max",
+ Self::Exchange { compare: None } => "exchange",
+ Self::Exchange { compare: Some(_) } => Err(Error::FeatureNotImplemented(
+ "atomic CompareExchange".to_string(),
+ ))?,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/back/msl/writer.rs b/third_party/rust/naga/src/back/msl/writer.rs
index 5227d8e7db..e250d0b72c 100644
--- a/third_party/rust/naga/src/back/msl/writer.rs
+++ b/third_party/rust/naga/src/back/msl/writer.rs
@@ -1131,21 +1131,10 @@ impl<W: Write> Writer<W> {
Ok(())
}
- fn put_atomic_fetch(
- &mut self,
- pointer: Handle<crate::Expression>,
- key: &str,
- value: Handle<crate::Expression>,
- context: &ExpressionContext,
- ) -> BackendResult {
- self.put_atomic_operation(pointer, "fetch_", key, value, context)
- }
-
fn put_atomic_operation(
&mut self,
pointer: Handle<crate::Expression>,
- key1: &str,
- key2: &str,
+ key: &str,
value: Handle<crate::Expression>,
context: &ExpressionContext,
) -> BackendResult {
@@ -1163,7 +1152,7 @@ impl<W: Write> Writer<W> {
write!(
self.out,
- "{NAMESPACE}::atomic_{key1}{key2}_explicit({ATOMIC_REFERENCE}"
+ "{NAMESPACE}::atomic_{key}_explicit({ATOMIC_REFERENCE}"
)?;
self.put_access_chain(pointer, policy, context)?;
write!(self.out, ", ")?;
@@ -1248,7 +1237,7 @@ impl<W: Write> Writer<W> {
) -> BackendResult {
self.put_possibly_const_expression(
expr_handle,
- &module.const_expressions,
+ &module.global_expressions,
module,
mod_info,
&(module, mod_info),
@@ -1431,6 +1420,7 @@ impl<W: Write> Writer<W> {
|writer, context, expr| writer.put_expression(expr, context, true),
)?;
}
+ crate::Expression::Override(_) => return Err(Error::Override),
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.
@@ -1944,7 +1934,7 @@ impl<W: Write> Writer<W> {
//
// extract_bits(e, min(offset, w), min(count, w - min(offset, w))))
- let scalar_bits = context.resolve_type(arg).scalar_width().unwrap();
+ let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::extract_bits(")?;
self.put_expression(arg, context, true)?;
@@ -1960,7 +1950,7 @@ impl<W: Write> Writer<W> {
//
// insertBits(e, newBits, min(offset, w), min(count, w - min(offset, w))))
- let scalar_bits = context.resolve_type(arg).scalar_width().unwrap();
+ let scalar_bits = context.resolve_type(arg).scalar_width().unwrap() * 8;
write!(self.out, "{NAMESPACE}::insert_bits(")?;
self.put_expression(arg, context, true)?;
@@ -2041,6 +2031,8 @@ impl<W: Write> Writer<W> {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
+ | crate::Expression::SubgroupBallotResult
+ | crate::Expression::SubgroupOperationResult { .. }
| crate::Expression::RayQueryProceedResult => {
unreachable!()
}
@@ -2994,43 +2986,8 @@ impl<W: Write> Writer<W> {
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_baking_expression(result, &context.expression, &res_name)?;
self.named_expressions.insert(result, res_name);
- match *fun {
- crate::AtomicFunction::Add => {
- self.put_atomic_fetch(pointer, "add", value, &context.expression)?;
- }
- crate::AtomicFunction::Subtract => {
- self.put_atomic_fetch(pointer, "sub", value, &context.expression)?;
- }
- crate::AtomicFunction::And => {
- self.put_atomic_fetch(pointer, "and", value, &context.expression)?;
- }
- crate::AtomicFunction::InclusiveOr => {
- self.put_atomic_fetch(pointer, "or", value, &context.expression)?;
- }
- crate::AtomicFunction::ExclusiveOr => {
- self.put_atomic_fetch(pointer, "xor", value, &context.expression)?;
- }
- crate::AtomicFunction::Min => {
- self.put_atomic_fetch(pointer, "min", value, &context.expression)?;
- }
- crate::AtomicFunction::Max => {
- self.put_atomic_fetch(pointer, "max", value, &context.expression)?;
- }
- crate::AtomicFunction::Exchange { compare: None } => {
- self.put_atomic_operation(
- pointer,
- "exchange",
- "",
- value,
- &context.expression,
- )?;
- }
- crate::AtomicFunction::Exchange { .. } => {
- return Err(Error::FeatureNotImplemented(
- "atomic CompareExchange".to_string(),
- ));
- }
- }
+ let fun_str = fun.to_msl()?;
+ self.put_atomic_operation(pointer, fun_str, value, &context.expression)?;
// done
writeln!(self.out, ";")?;
}
@@ -3144,6 +3101,121 @@ impl<W: Write> Writer<W> {
}
}
}
+ crate::Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?;
+ if let Some(predicate) = predicate {
+ self.put_expression(predicate, &context.expression, true)?;
+ } else {
+ write!(self.out, "true")?;
+ }
+ writeln!(self.out, "), 0, 0, 0);")?;
+ }
+ crate::Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "{NAMESPACE}::simd_all(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "{NAMESPACE}::simd_any(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "{NAMESPACE}::simd_sum(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "{NAMESPACE}::simd_product(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "{NAMESPACE}::simd_max(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "{NAMESPACE}::simd_min(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "{NAMESPACE}::simd_and(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "{NAMESPACE}::simd_or(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "{NAMESPACE}::simd_xor(")?
+ }
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?,
+ (
+ crate::CollectiveOperation::ExclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?,
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Add,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?,
+ (
+ crate::CollectiveOperation::InclusiveScan,
+ crate::SubgroupOperation::Mul,
+ ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?,
+ _ => unimplemented!(),
+ }
+ self.put_expression(argument, &context.expression, true)?;
+ writeln!(self.out, ");")?;
+ }
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let name = self.namer.call("");
+ self.start_baking_expression(result, &context.expression, &name)?;
+ self.named_expressions.insert(result, name);
+ match mode {
+ crate::GatherMode::BroadcastFirst => {
+ write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?;
+ }
+ crate::GatherMode::Broadcast(_) => {
+ write!(self.out, "{NAMESPACE}::simd_broadcast(")?;
+ }
+ crate::GatherMode::Shuffle(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle(")?;
+ }
+ crate::GatherMode::ShuffleDown(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?;
+ }
+ crate::GatherMode::ShuffleUp(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?;
+ }
+ crate::GatherMode::ShuffleXor(_) => {
+ write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
+ }
+ }
+ self.put_expression(argument, &context.expression, true)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, ", ")?;
+ self.put_expression(index, &context.expression, true)?;
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
}
@@ -3220,6 +3292,10 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
+ if !module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
self.names.clear();
self.namer.reset(
module,
@@ -4487,6 +4563,12 @@ impl<W: Write> Writer<W> {
"{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
)?;
}
+ if flags.contains(crate::Barrier::SUB_GROUP) {
+ writeln!(
+ self.out,
+ "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);",
+ )?;
+ }
Ok(())
}
}
@@ -4757,8 +4839,8 @@ fn test_stack_size() {
}
let stack_size = addresses_end - addresses_start;
// check the size (in debug only)
- // last observed macOS value: 19152 (CI)
- if !(9000..=20000).contains(&stack_size) {
+ // last observed macOS value: 22256 (CI)
+ if !(15000..=25000).contains(&stack_size) {
panic!("`put_block` stack size {stack_size} has changed!");
}
}
diff --git a/third_party/rust/naga/src/back/pipeline_constants.rs b/third_party/rust/naga/src/back/pipeline_constants.rs
new file mode 100644
index 0000000000..0dbe9cf4e8
--- /dev/null
+++ b/third_party/rust/naga/src/back/pipeline_constants.rs
@@ -0,0 +1,957 @@
+use super::PipelineConstants;
+use crate::{
+ proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
+ valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
+ Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
+ Span, Statement, TypeInner, WithSpan,
+};
+use std::{borrow::Cow, collections::HashSet, mem};
+use thiserror::Error;
+
+#[derive(Error, Debug, Clone)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum PipelineConstantError {
+ #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
+ MissingValue(String),
+ #[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")]
+ SrcNeedsToBeFinite,
+ #[error("Source f64 value doesn't fit in destination")]
+ DstRangeTooSmall,
+ #[error(transparent)]
+ ConstantEvaluatorError(#[from] ConstantEvaluatorError),
+ #[error(transparent)]
+ ValidationError(#[from] WithSpan<ValidationError>),
+}
+
+/// Replace all overrides in `module` with constants.
+///
+/// If no changes are needed, this just returns `Cow::Borrowed`
+/// references to `module` and `module_info`. Otherwise, it clones
+/// `module`, edits its [`global_expressions`] arena to contain only
+/// fully-evaluated expressions, and returns `Cow::Owned` values
+/// holding the simplified module and its validation results.
+///
+/// In either case, the module returned has an empty `overrides`
+/// arena, and the `global_expressions` arena contains only
+/// fully-evaluated expressions.
+///
+/// [`global_expressions`]: Module::global_expressions
+pub fn process_overrides<'a>(
+ module: &'a Module,
+ module_info: &'a ModuleInfo,
+ pipeline_constants: &PipelineConstants,
+) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
+ if module.overrides.is_empty() {
+ return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
+ }
+
+ let mut module = module.clone();
+
+ // A map from override handles to the handles of the constants
+ // we've replaced them with.
+ let mut override_map = Vec::with_capacity(module.overrides.len());
+
+ // A map from `module`'s original global expression handles to
+ // handles in the new, simplified global expression arena.
+ let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len());
+
+ // The set of constants whose initializer handles we've already
+ // updated to refer to the newly built global expression arena.
+ //
+ // All constants in `module` must have their `init` handles
+ // updated to point into the new, simplified global expression
+ // arena. Some of these we can most easily handle as a side effect
+ // during the simplification process, but we must handle the rest
+ // in a final fixup pass, guided by `adjusted_global_expressions`. We
+ // add their handles to this set, so that the final fixup step can
+ // leave them alone.
+ let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
+
+ let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
+
+ // An iterator through the original overrides table, consumed in
+ // approximate tandem with the global expressions.
+ let mut override_iter = module.overrides.drain();
+
+ // Do two things in tandem:
+ //
+ // - Rebuild the global expression arena from scratch, fully
+ // evaluating all expressions, and replacing each `Override`
+ // expression in `module.global_expressions` with a `Constant`
+ // expression.
+ //
+ // - Build a new `Constant` in `module.constants` to take the
+ // place of each `Override`.
+ //
+ // Build a map from old global expression handles to their
+ // fully-evaluated counterparts in `adjusted_global_expressions` as we
+ // go.
+ //
+ // Why in tandem? Overrides refer to expressions, and expressions
+ // refer to overrides, so we can't disentangle the two into
+ // separate phases. However, we can take advantage of the fact
+ // that the overrides and expressions must form a DAG, and work
+ // our way from the leaves to the roots, replacing and evaluating
+ // as we go.
+ //
+ // Although the two loops are nested, this is really two
+ // alternating phases: we adjust and evaluate constant expressions
+ // until we hit an `Override` expression, at which point we switch
+ // to building `Constant`s for `Overrides` until we've handled the
+ // one used by the expression. Then we switch back to processing
+ // expressions. Because we know they form a DAG, we know the
+ // `Override` expressions we encounter can only have initializers
+ // referring to global expressions we've already simplified.
+ for (old_h, expr, span) in module.global_expressions.drain() {
+ let mut expr = match expr {
+ Expression::Override(h) => {
+ let c_h = if let Some(new_h) = override_map.get(h.index()) {
+ *new_h
+ } else {
+ let mut new_h = None;
+ for entry in override_iter.by_ref() {
+ let stop = entry.0 == h;
+ new_h = Some(process_override(
+ entry,
+ pipeline_constants,
+ &mut module,
+ &mut override_map,
+ &adjusted_global_expressions,
+ &mut adjusted_constant_initializers,
+ &mut global_expression_kind_tracker,
+ )?);
+ if stop {
+ break;
+ }
+ }
+ new_h.unwrap()
+ };
+ Expression::Constant(c_h)
+ }
+ Expression::Constant(c_h) => {
+ if adjusted_constant_initializers.insert(c_h) {
+ let init = &mut module.constants[c_h].init;
+ *init = adjusted_global_expressions[init.index()];
+ }
+ expr
+ }
+ expr => expr,
+ };
+ let mut evaluator = ConstantEvaluator::for_wgsl_module(
+ &mut module,
+ &mut global_expression_kind_tracker,
+ false,
+ );
+ adjust_expr(&adjusted_global_expressions, &mut expr);
+ let h = evaluator.try_eval_and_append(expr, span)?;
+ debug_assert_eq!(old_h.index(), adjusted_global_expressions.len());
+ adjusted_global_expressions.push(h);
+ }
+
+ // Finish processing any overrides we didn't visit in the loop above.
+ for entry in override_iter {
+ process_override(
+ entry,
+ pipeline_constants,
+ &mut module,
+ &mut override_map,
+ &adjusted_global_expressions,
+ &mut adjusted_constant_initializers,
+ &mut global_expression_kind_tracker,
+ )?;
+ }
+
+ // Update the initialization expression handles of all `Constant`s
+ // and `GlobalVariable`s. Skip `Constant`s we'd already updated en
+ // passant.
+ for (_, c) in module
+ .constants
+ .iter_mut()
+ .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
+ {
+ c.init = adjusted_global_expressions[c.init.index()];
+ }
+
+ for (_, v) in module.global_variables.iter_mut() {
+ if let Some(ref mut init) = v.init {
+ *init = adjusted_global_expressions[init.index()];
+ }
+ }
+
+ let mut functions = mem::take(&mut module.functions);
+ for (_, function) in functions.iter_mut() {
+ process_function(&mut module, &override_map, function)?;
+ }
+ module.functions = functions;
+
+ let mut entry_points = mem::take(&mut module.entry_points);
+ for ep in entry_points.iter_mut() {
+ process_function(&mut module, &override_map, &mut ep.function)?;
+ }
+ module.entry_points = entry_points;
+
+ // Now that we've rewritten all the expressions, we need to
+ // recompute their types and other metadata. For the time being,
+ // do a full re-validation.
+ let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
+ let module_info = validator.validate_no_overrides(&module)?;
+
+ Ok((Cow::Owned(module), Cow::Owned(module_info)))
+}
+
+/// Add a [`Constant`] to `module` for the override `old_h`.
+///
+/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
+fn process_override(
+ (old_h, override_, span): (Handle<Override>, Override, Span),
+ pipeline_constants: &PipelineConstants,
+ module: &mut Module,
+ override_map: &mut Vec<Handle<Constant>>,
+ adjusted_global_expressions: &[Handle<Expression>],
+ adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
+) -> Result<Handle<Constant>, PipelineConstantError> {
+ // Determine which key to use for `override_` in `pipeline_constants`.
+ let key = if let Some(id) = override_.id {
+ Cow::Owned(id.to_string())
+ } else if let Some(ref name) = override_.name {
+ Cow::Borrowed(name)
+ } else {
+ unreachable!();
+ };
+
+ // Generate a global expression for `override_`'s value, either
+ // from the provided `pipeline_constants` table or its initializer
+ // in the module.
+ let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
+ let literal = match module.types[override_.ty].inner {
+ TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
+ _ => unreachable!(),
+ };
+ let expr = module
+ .global_expressions
+ .append(Expression::Literal(literal), Span::UNDEFINED);
+ global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
+ expr
+ } else if let Some(init) = override_.init {
+ adjusted_global_expressions[init.index()]
+ } else {
+ return Err(PipelineConstantError::MissingValue(key.to_string()));
+ };
+
+ // Generate a new `Constant` to represent the override's value.
+ let constant = Constant {
+ name: override_.name,
+ ty: override_.ty,
+ init,
+ };
+ let h = module.constants.append(constant, span);
+ debug_assert_eq!(old_h.index(), override_map.len());
+ override_map.push(h);
+ adjusted_constant_initializers.insert(h);
+ Ok(h)
+}
+
+/// Replace all override expressions in `function` with fully-evaluated constants.
+///
+/// Replace all `Expression::Override`s in `function`'s expression arena with
+/// the corresponding `Expression::Constant`s, as given in `override_map`.
+/// Replace any expressions whose values are now known with their fully
+/// evaluated form.
+///
+/// If `h` is a `Handle<Override>`, then `override_map[h.index()]` is the
+/// `Handle<Constant>` for the override's final value.
+fn process_function(
+ module: &mut Module,
+ override_map: &[Handle<Constant>],
+ function: &mut Function,
+) -> Result<(), ConstantEvaluatorError> {
+ // A map from original local expression handles to
+ // handles in the new, local expression arena.
+ let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len());
+
+ let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
+
+ let mut expressions = mem::take(&mut function.expressions);
+
+ // Dummy `emitter` and `block` for the constant evaluator.
+ // We can ignore the concept of emitting expressions here since
+ // expressions have already been covered by a `Statement::Emit`
+ // in the frontend.
+ // The only thing we might have to do is remove some expressions
+ // that have been covered by a `Statement::Emit`. See the docs of
+ // `filter_emits_in_block` for the reasoning.
+ let mut emitter = Emitter::default();
+ let mut block = Block::new();
+
+ let mut evaluator = ConstantEvaluator::for_wgsl_function(
+ module,
+ &mut function.expressions,
+ &mut local_expression_kind_tracker,
+ &mut emitter,
+ &mut block,
+ );
+
+ for (old_h, mut expr, span) in expressions.drain() {
+ if let Expression::Override(h) = expr {
+ expr = Expression::Constant(override_map[h.index()]);
+ }
+ adjust_expr(&adjusted_local_expressions, &mut expr);
+ let h = evaluator.try_eval_and_append(expr, span)?;
+ debug_assert_eq!(old_h.index(), adjusted_local_expressions.len());
+ adjusted_local_expressions.push(h);
+ }
+
+ adjust_block(&adjusted_local_expressions, &mut function.body);
+
+ filter_emits_in_block(&mut function.body, &function.expressions);
+
+ // Update local expression initializers.
+ for (_, local) in function.local_variables.iter_mut() {
+ if let &mut Some(ref mut init) = &mut local.init {
+ *init = adjusted_local_expressions[init.index()];
+ }
+ }
+
+ // We've changed the keys of `function.named_expression`, so we have to
+ // rebuild it from scratch.
+ let named_expressions = mem::take(&mut function.named_expressions);
+ for (expr_h, name) in named_expressions {
+ function
+ .named_expressions
+ .insert(adjusted_local_expressions[expr_h.index()], name);
+ }
+
+ Ok(())
+}
+
+/// Replace every expression handle in `expr` with its counterpart
+/// given by `new_pos`.
+fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
+ let adjust = |expr: &mut Handle<Expression>| {
+ *expr = new_pos[expr.index()];
+ };
+ match *expr {
+ Expression::Compose {
+ ref mut components,
+ ty: _,
+ } => {
+ for c in components.iter_mut() {
+ adjust(c);
+ }
+ }
+ Expression::Access {
+ ref mut base,
+ ref mut index,
+ } => {
+ adjust(base);
+ adjust(index);
+ }
+ Expression::AccessIndex {
+ ref mut base,
+ index: _,
+ } => {
+ adjust(base);
+ }
+ Expression::Splat {
+ ref mut value,
+ size: _,
+ } => {
+ adjust(value);
+ }
+ Expression::Swizzle {
+ ref mut vector,
+ size: _,
+ pattern: _,
+ } => {
+ adjust(vector);
+ }
+ Expression::Load { ref mut pointer } => {
+ adjust(pointer);
+ }
+ Expression::ImageSample {
+ ref mut image,
+ ref mut sampler,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut offset,
+ ref mut level,
+ ref mut depth_ref,
+ gather: _,
+ } => {
+ adjust(image);
+ adjust(sampler);
+ adjust(coordinate);
+ if let Some(e) = array_index.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = offset.as_mut() {
+ adjust(e);
+ }
+ match *level {
+ crate::SampleLevel::Exact(ref mut expr)
+ | crate::SampleLevel::Bias(ref mut expr) => {
+ adjust(expr);
+ }
+ crate::SampleLevel::Gradient {
+ ref mut x,
+ ref mut y,
+ } => {
+ adjust(x);
+ adjust(y);
+ }
+ _ => {}
+ }
+ if let Some(e) = depth_ref.as_mut() {
+ adjust(e);
+ }
+ }
+ Expression::ImageLoad {
+ ref mut image,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut sample,
+ ref mut level,
+ } => {
+ adjust(image);
+ adjust(coordinate);
+ if let Some(e) = array_index.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = sample.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = level.as_mut() {
+ adjust(e);
+ }
+ }
+ Expression::ImageQuery {
+ ref mut image,
+ ref mut query,
+ } => {
+ adjust(image);
+ match *query {
+ crate::ImageQuery::Size { ref mut level } => {
+ if let Some(e) = level.as_mut() {
+ adjust(e);
+ }
+ }
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => {}
+ }
+ }
+ Expression::Unary {
+ ref mut expr,
+ op: _,
+ } => {
+ adjust(expr);
+ }
+ Expression::Binary {
+ ref mut left,
+ ref mut right,
+ op: _,
+ } => {
+ adjust(left);
+ adjust(right);
+ }
+ Expression::Select {
+ ref mut condition,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ adjust(condition);
+ adjust(accept);
+ adjust(reject);
+ }
+ Expression::Derivative {
+ ref mut expr,
+ axis: _,
+ ctrl: _,
+ } => {
+ adjust(expr);
+ }
+ Expression::Relational {
+ ref mut argument,
+ fun: _,
+ } => {
+ adjust(argument);
+ }
+ Expression::Math {
+ ref mut arg,
+ ref mut arg1,
+ ref mut arg2,
+ ref mut arg3,
+ fun: _,
+ } => {
+ adjust(arg);
+ if let Some(e) = arg1.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = arg2.as_mut() {
+ adjust(e);
+ }
+ if let Some(e) = arg3.as_mut() {
+ adjust(e);
+ }
+ }
+ Expression::As {
+ ref mut expr,
+ kind: _,
+ convert: _,
+ } => {
+ adjust(expr);
+ }
+ Expression::ArrayLength(ref mut expr) => {
+ adjust(expr);
+ }
+ Expression::RayQueryGetIntersection {
+ ref mut query,
+ committed: _,
+ } => {
+ adjust(query);
+ }
+ Expression::Literal(_)
+ | Expression::FunctionArgument(_)
+ | Expression::GlobalVariable(_)
+ | Expression::LocalVariable(_)
+ | Expression::CallResult(_)
+ | Expression::RayQueryProceedResult
+ | Expression::Constant(_)
+ | Expression::Override(_)
+ | Expression::ZeroValue(_)
+ | Expression::AtomicResult {
+ ty: _,
+ comparison: _,
+ }
+ | Expression::WorkGroupUniformLoadResult { ty: _ }
+ | Expression::SubgroupBallotResult
+ | Expression::SubgroupOperationResult { .. } => {}
+ }
+}
+
+/// Replace every expression handle in `block` with its counterpart
+/// given by `new_pos`.
+fn adjust_block(new_pos: &[Handle<Expression>], block: &mut Block) {
+ for stmt in block.iter_mut() {
+ adjust_stmt(new_pos, stmt);
+ }
+}
+
+/// Replace every expression handle in `stmt` with its counterpart
+/// given by `new_pos`.
+fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
+ let adjust = |expr: &mut Handle<Expression>| {
+ *expr = new_pos[expr.index()];
+ };
+ match *stmt {
+ Statement::Emit(ref mut range) => {
+ if let Some((mut first, mut last)) = range.first_and_last() {
+ adjust(&mut first);
+ adjust(&mut last);
+ *range = Range::new_from_bounds(first, last);
+ }
+ }
+ Statement::Block(ref mut block) => {
+ adjust_block(new_pos, block);
+ }
+ Statement::If {
+ ref mut condition,
+ ref mut accept,
+ ref mut reject,
+ } => {
+ adjust(condition);
+ adjust_block(new_pos, accept);
+ adjust_block(new_pos, reject);
+ }
+ Statement::Switch {
+ ref mut selector,
+ ref mut cases,
+ } => {
+ adjust(selector);
+ for case in cases.iter_mut() {
+ adjust_block(new_pos, &mut case.body);
+ }
+ }
+ Statement::Loop {
+ ref mut body,
+ ref mut continuing,
+ ref mut break_if,
+ } => {
+ adjust_block(new_pos, body);
+ adjust_block(new_pos, continuing);
+ if let Some(e) = break_if.as_mut() {
+ adjust(e);
+ }
+ }
+ Statement::Return { ref mut value } => {
+ if let Some(e) = value.as_mut() {
+ adjust(e);
+ }
+ }
+ Statement::Store {
+ ref mut pointer,
+ ref mut value,
+ } => {
+ adjust(pointer);
+ adjust(value);
+ }
+ Statement::ImageStore {
+ ref mut image,
+ ref mut coordinate,
+ ref mut array_index,
+ ref mut value,
+ } => {
+ adjust(image);
+ adjust(coordinate);
+ if let Some(e) = array_index.as_mut() {
+ adjust(e);
+ }
+ adjust(value);
+ }
+ crate::Statement::Atomic {
+ ref mut pointer,
+ ref mut value,
+ ref mut result,
+ ref mut fun,
+ } => {
+ adjust(pointer);
+ adjust(value);
+ adjust(result);
+ match *fun {
+ crate::AtomicFunction::Exchange {
+ compare: Some(ref mut compare),
+ } => {
+ adjust(compare);
+ }
+ crate::AtomicFunction::Add
+ | crate::AtomicFunction::Subtract
+ | crate::AtomicFunction::And
+ | crate::AtomicFunction::ExclusiveOr
+ | crate::AtomicFunction::InclusiveOr
+ | crate::AtomicFunction::Min
+ | crate::AtomicFunction::Max
+ | crate::AtomicFunction::Exchange { compare: None } => {}
+ }
+ }
+ Statement::WorkGroupUniformLoad {
+ ref mut pointer,
+ ref mut result,
+ } => {
+ adjust(pointer);
+ adjust(result);
+ }
+ Statement::SubgroupBallot {
+ ref mut result,
+ ref mut predicate,
+ } => {
+ if let Some(ref mut predicate) = *predicate {
+ adjust(predicate);
+ }
+ adjust(result);
+ }
+ Statement::SubgroupCollectiveOperation {
+ ref mut argument,
+ ref mut result,
+ ..
+ } => {
+ adjust(argument);
+ adjust(result);
+ }
+ Statement::SubgroupGather {
+ ref mut mode,
+ ref mut argument,
+ ref mut result,
+ } => {
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(ref mut index)
+ | crate::GatherMode::Shuffle(ref mut index)
+ | crate::GatherMode::ShuffleDown(ref mut index)
+ | crate::GatherMode::ShuffleUp(ref mut index)
+ | crate::GatherMode::ShuffleXor(ref mut index) => {
+ adjust(index);
+ }
+ }
+ adjust(argument);
+ adjust(result)
+ }
+ Statement::Call {
+ ref mut arguments,
+ ref mut result,
+ function: _,
+ } => {
+ for argument in arguments.iter_mut() {
+ adjust(argument);
+ }
+ if let Some(e) = result.as_mut() {
+ adjust(e);
+ }
+ }
+ Statement::RayQuery {
+ ref mut query,
+ ref mut fun,
+ } => {
+ adjust(query);
+ match *fun {
+ crate::RayQueryFunction::Initialize {
+ ref mut acceleration_structure,
+ ref mut descriptor,
+ } => {
+ adjust(acceleration_structure);
+ adjust(descriptor);
+ }
+ crate::RayQueryFunction::Proceed { ref mut result } => {
+ adjust(result);
+ }
+ crate::RayQueryFunction::Terminate => {}
+ }
+ }
+ Statement::Break | Statement::Continue | Statement::Kill | Statement::Barrier(_) => {}
+ }
+}
+
+/// Adjust [`Emit`] statements in `block` to skip [`needs_pre_emit`] expressions we have introduced.
+///
+/// According to validation, [`Emit`] statements must not cover any expressions
+/// for which [`Expression::needs_pre_emit`] returns true. All expressions built
+/// by successful constant evaluation fall into that category, meaning that
+/// `process_function` will usually rewrite [`Override`] expressions and those
+/// that use their values into pre-emitted expressions, leaving any [`Emit`]
+/// statements that cover them invalid.
+///
+/// This function rewrites all [`Emit`] statements into zero or more new
+/// [`Emit`] statements covering only those expressions in the original range
+/// that are not pre-emitted.
+///
+/// [`Emit`]: Statement::Emit
+/// [`needs_pre_emit`]: Expression::needs_pre_emit
+/// [`Override`]: Expression::Override
+fn filter_emits_in_block(block: &mut Block, expressions: &Arena<Expression>) {
+ let original = std::mem::replace(block, Block::with_capacity(block.len()));
+ for (stmt, span) in original.span_into_iter() {
+ match stmt {
+ Statement::Emit(range) => {
+ let mut current = None;
+ for expr_h in range {
+ if expressions[expr_h].needs_pre_emit() {
+ if let Some((first, last)) = current {
+ block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
+ }
+
+ current = None;
+ } else if let Some((_, ref mut last)) = current {
+ *last = expr_h;
+ } else {
+ current = Some((expr_h, expr_h));
+ }
+ }
+ if let Some((first, last)) = current {
+ block.push(Statement::Emit(Range::new_from_bounds(first, last)), span);
+ }
+ }
+ Statement::Block(mut child) => {
+ filter_emits_in_block(&mut child, expressions);
+ block.push(Statement::Block(child), span);
+ }
+ Statement::If {
+ condition,
+ mut accept,
+ mut reject,
+ } => {
+ filter_emits_in_block(&mut accept, expressions);
+ filter_emits_in_block(&mut reject, expressions);
+ block.push(
+ Statement::If {
+ condition,
+ accept,
+ reject,
+ },
+ span,
+ );
+ }
+ Statement::Switch {
+ selector,
+ mut cases,
+ } => {
+ for case in &mut cases {
+ filter_emits_in_block(&mut case.body, expressions);
+ }
+ block.push(Statement::Switch { selector, cases }, span);
+ }
+ Statement::Loop {
+ mut body,
+ mut continuing,
+ break_if,
+ } => {
+ filter_emits_in_block(&mut body, expressions);
+ filter_emits_in_block(&mut continuing, expressions);
+ block.push(
+ Statement::Loop {
+ body,
+ continuing,
+ break_if,
+ },
+ span,
+ );
+ }
+ stmt => block.push(stmt.clone(), span),
+ }
+ }
+}
+
+fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
+ // note that in rust 0.0 == -0.0
+ match scalar {
+ Scalar::BOOL => {
+ // https://webidl.spec.whatwg.org/#js-boolean
+ let value = value != 0.0 && !value.is_nan();
+ Ok(Literal::Bool(value))
+ }
+ Scalar::I32 => {
+ // https://webidl.spec.whatwg.org/#js-long
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ let value = value.trunc();
+ if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
+ return Err(PipelineConstantError::DstRangeTooSmall);
+ }
+
+ let value = value as i32;
+ Ok(Literal::I32(value))
+ }
+ Scalar::U32 => {
+ // https://webidl.spec.whatwg.org/#js-unsigned-long
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ let value = value.trunc();
+ if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
+ return Err(PipelineConstantError::DstRangeTooSmall);
+ }
+
+ let value = value as u32;
+ Ok(Literal::U32(value))
+ }
+ Scalar::F32 => {
+ // https://webidl.spec.whatwg.org/#js-float
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ let value = value as f32;
+ if !value.is_finite() {
+ return Err(PipelineConstantError::DstRangeTooSmall);
+ }
+
+ Ok(Literal::F32(value))
+ }
+ Scalar::F64 => {
+ // https://webidl.spec.whatwg.org/#js-double
+ if !value.is_finite() {
+ return Err(PipelineConstantError::SrcNeedsToBeFinite);
+ }
+
+ Ok(Literal::F64(value))
+ }
+ _ => unreachable!(),
+ }
+}
+
+#[test]
+fn test_map_value_to_literal() {
+ let bool_test_cases = [
+ (0.0, false),
+ (-0.0, false),
+ (f64::NAN, false),
+ (1.0, true),
+ (f64::INFINITY, true),
+ (f64::NEG_INFINITY, true),
+ ];
+ for (value, out) in bool_test_cases {
+ let res = Ok(Literal::Bool(out));
+ assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
+ }
+
+ for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
+ for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
+ let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
+ assert_eq!(map_value_to_literal(value, scalar), res);
+ }
+ }
+
+ // i32
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
+ Ok(Literal::I32(i32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
+ Ok(Literal::I32(i32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+
+ // u32
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
+ Ok(Literal::U32(u32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
+ Ok(Literal::U32(u32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+
+ // f32
+ assert_eq!(
+ map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
+ Ok(Literal::F32(f32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
+ Ok(Literal::F32(f32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
+ Ok(Literal::F32(f32::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
+ Ok(Literal::F32(f32::MAX))
+ );
+ assert_eq!(
+ map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+ assert_eq!(
+ map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
+ Err(PipelineConstantError::DstRangeTooSmall)
+ );
+
+ // f64
+ assert_eq!(
+ map_value_to_literal(f64::MIN, Scalar::F64),
+ Ok(Literal::F64(f64::MIN))
+ );
+ assert_eq!(
+ map_value_to_literal(f64::MAX, Scalar::F64),
+ Ok(Literal::F64(f64::MAX))
+ );
+}
diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs
index 81f2fc10e0..120d60fc40 100644
--- a/third_party/rust/naga/src/back/spv/block.rs
+++ b/third_party/rust/naga/src/back/spv/block.rs
@@ -239,6 +239,7 @@ impl<'w> BlockContext<'w> {
let init = self.ir_module.constants[handle].init;
self.writer.constant_ids[init.index()]
}
+ crate::Expression::Override(_) => return Err(Error::Override),
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();
@@ -1072,7 +1073,7 @@ impl<'w> BlockContext<'w> {
//
// bitfieldExtract(x, o, c)
- let bit_width = arg_ty.scalar_width().unwrap();
+ let bit_width = arg_ty.scalar_width().unwrap() * 8;
let width_constant = self
.writer
.get_constant_scalar(crate::Literal::U32(bit_width as u32));
@@ -1128,7 +1129,7 @@ impl<'w> BlockContext<'w> {
Mf::InsertBits => {
// The behavior of InsertBits has the same undefined behavior as ExtractBits.
- let bit_width = arg_ty.scalar_width().unwrap();
+ let bit_width = arg_ty.scalar_width().unwrap() * 8;
let width_constant = self
.writer
.get_constant_scalar(crate::Literal::U32(bit_width as u32));
@@ -1184,7 +1185,7 @@ impl<'w> BlockContext<'w> {
}
Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
Mf::FindMsb => {
- if arg_ty.scalar_width() == Some(32) {
+ if arg_ty.scalar_width() == Some(4) {
let thing = match arg_scalar_kind {
Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
@@ -1278,7 +1279,9 @@ impl<'w> BlockContext<'w> {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
- | crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
+ | crate::Expression::RayQueryProceedResult
+ | crate::Expression::SubgroupBallotResult
+ | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle],
crate::Expression::As {
expr,
kind,
@@ -2489,6 +2492,27 @@ impl<'w> BlockContext<'w> {
crate::Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);
}
+ crate::Statement::SubgroupBallot {
+ result,
+ ref predicate,
+ } => {
+ self.write_subgroup_ballot(predicate, result, &mut block)?;
+ }
+ crate::Statement::SubgroupCollectiveOperation {
+ ref op,
+ ref collective_op,
+ argument,
+ result,
+ } => {
+ self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?;
+ }
+ crate::Statement::SubgroupGather {
+ ref mode,
+ argument,
+ result,
+ } => {
+ self.write_subgroup_gather(mode, argument, result, &mut block)?;
+ }
}
}
diff --git a/third_party/rust/naga/src/back/spv/helpers.rs b/third_party/rust/naga/src/back/spv/helpers.rs
index 5b6226db85..1fb447e384 100644
--- a/third_party/rust/naga/src/back/spv/helpers.rs
+++ b/third_party/rust/naga/src/back/spv/helpers.rs
@@ -10,8 +10,12 @@ pub(super) fn bytes_to_words(bytes: &[u8]) -> Vec<Word> {
pub(super) fn string_to_words(input: &str) -> Vec<Word> {
let bytes = input.as_bytes();
- let mut words = bytes_to_words(bytes);
+ str_bytes_to_words(bytes)
+}
+
+pub(super) fn str_bytes_to_words(bytes: &[u8]) -> Vec<Word> {
+ let mut words = bytes_to_words(bytes);
if bytes.len() % 4 == 0 {
// nul-termination
words.push(0x0u32);
@@ -20,6 +24,21 @@ pub(super) fn string_to_words(input: &str) -> Vec<Word> {
words
}
+/// split a string into chunks and keep utf8 valid
+#[allow(unstable_name_collisions)]
+pub(super) fn string_to_byte_chunks(input: &str, limit: usize) -> Vec<&[u8]> {
+ let mut offset: usize = 0;
+ let mut start: usize = 0;
+ let mut words = vec![];
+ while offset < input.len() {
+ offset = input.floor_char_boundary(offset + limit);
+ words.push(input[start..offset].as_bytes());
+ start = offset;
+ }
+
+ words
+}
+
pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::StorageClass {
match space {
crate::AddressSpace::Handle => spirv::StorageClass::UniformConstant,
@@ -107,3 +126,35 @@ pub fn global_needs_wrapper(ir_module: &crate::Module, var: &crate::GlobalVariab
_ => true,
}
}
+
+///HACK: this is taken from std unstable, remove it when std's floor_char_boundary is stable
+trait U8Internal {
+ fn is_utf8_char_boundary(&self) -> bool;
+}
+
+impl U8Internal for u8 {
+ fn is_utf8_char_boundary(&self) -> bool {
+ // This is bit magic equivalent to: b < 128 || b >= 192
+ (*self as i8) >= -0x40
+ }
+}
+
+trait StrUnstable {
+ fn floor_char_boundary(&self, index: usize) -> usize;
+}
+
+impl StrUnstable for str {
+ fn floor_char_boundary(&self, index: usize) -> usize {
+ if index >= self.len() {
+ self.len()
+ } else {
+ let lower_bound = index.saturating_sub(3);
+ let new_index = self.as_bytes()[lower_bound..=index]
+ .iter()
+ .rposition(|b| b.is_utf8_char_boundary());
+
+ // SAFETY: we know that the character boundary will be within four bytes
+ unsafe { lower_bound + new_index.unwrap_unchecked() }
+ }
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/instructions.rs b/third_party/rust/naga/src/back/spv/instructions.rs
index b963793ad3..df2774ab9c 100644
--- a/third_party/rust/naga/src/back/spv/instructions.rs
+++ b/third_party/rust/naga/src/back/spv/instructions.rs
@@ -43,6 +43,42 @@ impl super::Instruction {
instruction
}
+ pub(super) fn source_continued(source: &[u8]) -> Self {
+ let mut instruction = Self::new(Op::SourceContinued);
+ instruction.add_operands(helpers::str_bytes_to_words(source));
+ instruction
+ }
+
+ pub(super) fn source_auto_continued(
+ source_language: spirv::SourceLanguage,
+ version: u32,
+ source: &Option<DebugInfoInner>,
+ ) -> Vec<Self> {
+ let mut instructions = vec![];
+
+ let with_continue = source.as_ref().and_then(|debug_info| {
+ (debug_info.source_code.len() > u16::MAX as usize).then_some(debug_info)
+ });
+ if let Some(debug_info) = with_continue {
+ let mut instruction = Self::new(Op::Source);
+ instruction.add_operand(source_language as u32);
+ instruction.add_operands(helpers::bytes_to_words(&version.to_le_bytes()));
+
+ let words = helpers::string_to_byte_chunks(debug_info.source_code, u16::MAX as usize);
+ instruction.add_operand(debug_info.source_file_id);
+ instruction.add_operands(helpers::str_bytes_to_words(words[0]));
+ instructions.push(instruction);
+ for word_bytes in words[1..].iter() {
+ let instruction_continue = Self::source_continued(word_bytes);
+ instructions.push(instruction_continue);
+ }
+ } else {
+ let instruction = Self::source(source_language, version, source);
+ instructions.push(instruction);
+ }
+ instructions
+ }
+
pub(super) fn name(target_id: Word, name: &str) -> Self {
let mut instruction = Self::new(Op::Name);
instruction.add_operand(target_id);
@@ -1037,6 +1073,73 @@ impl super::Instruction {
instruction.add_operand(semantics_id);
instruction
}
+
+ // Group Instructions
+
+ pub(super) fn group_non_uniform_ballot(
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ predicate: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::GroupNonUniformBallot);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(predicate);
+
+ instruction
+ }
+ pub(super) fn group_non_uniform_broadcast_first(
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(value);
+
+ instruction
+ }
+ pub(super) fn group_non_uniform_gather(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ value: Word,
+ index: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ instruction.add_operand(value);
+ instruction.add_operand(index);
+
+ instruction
+ }
+ pub(super) fn group_non_uniform_arithmetic(
+ op: Op,
+ result_type_id: Word,
+ id: Word,
+ exec_scope_id: Word,
+ group_op: Option<spirv::GroupOperation>,
+ value: Word,
+ ) -> Self {
+ let mut instruction = Self::new(op);
+ instruction.set_type(result_type_id);
+ instruction.set_result(id);
+ instruction.add_operand(exec_scope_id);
+ if let Some(group_op) = group_op {
+ instruction.add_operand(group_op as u32);
+ }
+ instruction.add_operand(value);
+
+ instruction
+ }
}
impl From<crate::StorageFormat> for spirv::ImageFormat {
diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs
index eb29e3cd8b..38a87049e6 100644
--- a/third_party/rust/naga/src/back/spv/mod.rs
+++ b/third_party/rust/naga/src/back/spv/mod.rs
@@ -13,6 +13,7 @@ mod layout;
mod ray;
mod recyclable;
mod selection;
+mod subgroup;
mod writer;
pub use spirv::Capability;
@@ -70,6 +71,8 @@ pub enum Error {
FeatureNotImplemented(&'static str),
#[error("module is not validated properly: {0}")]
Validation(&'static str),
+ #[error("overrides should not be present at this stage")]
+ Override,
}
#[derive(Default)]
@@ -245,7 +248,7 @@ impl LocalImageType {
/// this, by converting everything possible to a `LocalType` before inspecting
/// it.
///
-/// ## `Localtype` equality and SPIR-V `OpType` uniqueness
+/// ## `LocalType` equality and SPIR-V `OpType` uniqueness
///
/// The definition of `Eq` on `LocalType` is carefully chosen to help us follow
/// certain SPIR-V rules. SPIR-V §2.8 requires some classes of `OpType...`
@@ -454,7 +457,7 @@ impl recyclable::Recyclable for CachedExpressions {
#[derive(Eq, Hash, PartialEq)]
enum CachedConstant {
- Literal(crate::Literal),
+ Literal(crate::proc::HashableLiteral),
Composite {
ty: LookupType,
constituent_ids: Vec<Word>,
@@ -527,6 +530,42 @@ struct FunctionArgument {
handle_id: Word,
}
+/// Tracks the expressions for which the backend emits the following instructions:
+/// - OpConstantTrue
+/// - OpConstantFalse
+/// - OpConstant
+/// - OpConstantComposite
+/// - OpConstantNull
+struct ExpressionConstnessTracker {
+ inner: bit_set::BitSet,
+}
+
+impl ExpressionConstnessTracker {
+ fn from_arena(arena: &crate::Arena<crate::Expression>) -> Self {
+ let mut inner = bit_set::BitSet::new();
+ for (handle, expr) in arena.iter() {
+ let insert = match *expr {
+ crate::Expression::Literal(_)
+ | crate::Expression::ZeroValue(_)
+ | crate::Expression::Constant(_) => true,
+ crate::Expression::Compose { ref components, .. } => {
+ components.iter().all(|h| inner.contains(h.index()))
+ }
+ crate::Expression::Splat { value, .. } => inner.contains(value.index()),
+ _ => false,
+ };
+ if insert {
+ inner.insert(handle.index());
+ }
+ }
+ Self { inner }
+ }
+
+ fn is_const(&self, value: Handle<crate::Expression>) -> bool {
+ self.inner.contains(value.index())
+ }
+}
+
/// General information needed to emit SPIR-V for Naga statements.
struct BlockContext<'w> {
/// The writer handling the module to which this code belongs.
@@ -552,7 +591,7 @@ struct BlockContext<'w> {
temp_list: Vec<Word>,
/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
- expression_constness: crate::proc::ExpressionConstnessTracker,
+ expression_constness: ExpressionConstnessTracker,
}
impl BlockContext<'_> {
@@ -725,7 +764,7 @@ impl<'a> Default for Options<'a> {
}
// A subset of options meant to be changed per pipeline.
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
diff --git a/third_party/rust/naga/src/back/spv/subgroup.rs b/third_party/rust/naga/src/back/spv/subgroup.rs
new file mode 100644
index 0000000000..c952cb11a7
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/subgroup.rs
@@ -0,0 +1,207 @@
+use super::{Block, BlockContext, Error, Instruction};
+use crate::{
+ arena::Handle,
+ back::spv::{LocalType, LookupType},
+ TypeInner,
+};
+
+impl<'w> BlockContext<'w> {
+ pub(super) fn write_subgroup_ballot(
+ &mut self,
+ predicate: &Option<Handle<crate::Expression>>,
+ result: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ self.writer.require_any(
+ "GroupNonUniformBallot",
+ &[spirv::Capability::GroupNonUniformBallot],
+ )?;
+ let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+ vector_size: Some(crate::VectorSize::Quad),
+ scalar: crate::Scalar::U32,
+ pointer_space: None,
+ }));
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+ let predicate = if let Some(predicate) = *predicate {
+ self.cached[predicate]
+ } else {
+ self.writer.get_constant_scalar(crate::Literal::Bool(true))
+ };
+ let id = self.gen_id();
+ block.body.push(Instruction::group_non_uniform_ballot(
+ vec4_u32_type_id,
+ id,
+ exec_scope_id,
+ predicate,
+ ));
+ self.cached[result] = id;
+ Ok(())
+ }
+ pub(super) fn write_subgroup_operation(
+ &mut self,
+ op: &crate::SubgroupOperation,
+ collective_op: &crate::CollectiveOperation,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ use crate::SubgroupOperation as sg;
+ match *op {
+ sg::All | sg::Any => {
+ self.writer.require_any(
+ "GroupNonUniformVote",
+ &[spirv::Capability::GroupNonUniformVote],
+ )?;
+ }
+ _ => {
+ self.writer.require_any(
+ "GroupNonUniformArithmetic",
+ &[spirv::Capability::GroupNonUniformArithmetic],
+ )?;
+ }
+ }
+
+ let id = self.gen_id();
+ let result_ty = &self.fun_info[result].ty;
+ let result_type_id = self.get_expression_type_id(result_ty);
+ let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
+
+ let (is_scalar, scalar) = match *result_ty_inner {
+ TypeInner::Scalar(kind) => (true, kind),
+ TypeInner::Vector { scalar: kind, .. } => (false, kind),
+ _ => unimplemented!(),
+ };
+
+ use crate::ScalarKind as sk;
+ let spirv_op = match (scalar.kind, *op) {
+ (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
+ (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
+ (_, sg::All | sg::Any) => unimplemented!(),
+
+ (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
+ (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
+ (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
+ (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
+ (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
+ (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
+ (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
+ (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
+ (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
+ (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
+ (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),
+
+ (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
+ (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
+ (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
+ (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
+ (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
+ (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
+ (_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
+ };
+
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+
+ use crate::CollectiveOperation as c;
+ let group_op = match *op {
+ sg::All | sg::Any => None,
+ _ => Some(match *collective_op {
+ c::Reduce => spirv::GroupOperation::Reduce,
+ c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
+ c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
+ }),
+ };
+
+ let arg_id = self.cached[argument];
+ block.body.push(Instruction::group_non_uniform_arithmetic(
+ spirv_op,
+ result_type_id,
+ id,
+ exec_scope_id,
+ group_op,
+ arg_id,
+ ));
+ self.cached[result] = id;
+ Ok(())
+ }
+ pub(super) fn write_subgroup_gather(
+ &mut self,
+ mode: &crate::GatherMode,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ block: &mut Block,
+ ) -> Result<(), Error> {
+ self.writer.require_any(
+ "GroupNonUniformBallot",
+ &[spirv::Capability::GroupNonUniformBallot],
+ )?;
+ match *mode {
+ crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => {
+ self.writer.require_any(
+ "GroupNonUniformBallot",
+ &[spirv::Capability::GroupNonUniformBallot],
+ )?;
+ }
+ crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => {
+ self.writer.require_any(
+ "GroupNonUniformShuffle",
+ &[spirv::Capability::GroupNonUniformShuffle],
+ )?;
+ }
+ crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
+ self.writer.require_any(
+ "GroupNonUniformShuffleRelative",
+ &[spirv::Capability::GroupNonUniformShuffleRelative],
+ )?;
+ }
+ }
+
+ let id = self.gen_id();
+ let result_ty = &self.fun_info[result].ty;
+ let result_type_id = self.get_expression_type_id(result_ty);
+
+ let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+
+ let arg_id = self.cached[argument];
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {
+ block
+ .body
+ .push(Instruction::group_non_uniform_broadcast_first(
+ result_type_id,
+ id,
+ exec_scope_id,
+ arg_id,
+ ));
+ }
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ let index_id = self.cached[index];
+ let op = match *mode {
+ crate::GatherMode::BroadcastFirst => unreachable!(),
+ // Use shuffle to emit broadcast to allow the index to
+ // be dynamically uniform on Vulkan 1.1. The argument to
+ // OpGroupNonUniformBroadcast must be a constant pre SPIR-V
+ // 1.5 (vulkan 1.2)
+ crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
+ crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
+ crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
+ crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
+ crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
+ };
+ block.body.push(Instruction::group_non_uniform_gather(
+ op,
+ result_type_id,
+ id,
+ exec_scope_id,
+ arg_id,
+ index_id,
+ ));
+ }
+ }
+ self.cached[result] = id;
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs
index a5065e0623..73a16c273e 100644
--- a/third_party/rust/naga/src/back/spv/writer.rs
+++ b/third_party/rust/naga/src/back/spv/writer.rs
@@ -615,7 +615,7 @@ impl Writer {
// Steal the Writer's temp list for a bit.
temp_list: std::mem::take(&mut self.temp_list),
writer: self,
- expression_constness: crate::proc::ExpressionConstnessTracker::from_arena(
+ expression_constness: super::ExpressionConstnessTracker::from_arena(
&ir_function.expressions,
),
};
@@ -970,6 +970,11 @@ impl Writer {
handle: Handle<crate::Type>,
) -> Result<Word, Error> {
let ty = &arena[handle];
+ // If it's a type that needs SPIR-V capabilities, request them now.
+ // This needs to happen regardless of the LocalType lookup succeeding,
+ // because some types which map to the same LocalType have different
+ // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569
+ self.request_type_capabilities(&ty.inner)?;
let id = if let Some(local) = make_local(&ty.inner) {
// This type can be represented as a `LocalType`, so check if we've
// already written an instruction for it. If not, do so now, with
@@ -985,10 +990,6 @@ impl Writer {
self.write_type_declaration_local(id, local);
- // If it's a type that needs SPIR-V capabilities, request them now,
- // so write_type_declaration_local can stay infallible.
- self.request_type_capabilities(&ty.inner)?;
-
id
}
}
@@ -1150,7 +1151,7 @@ impl Writer {
}
pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word {
- let scalar = CachedConstant::Literal(value);
+ let scalar = CachedConstant::Literal(value.into());
if let Some(&id) = self.cached_constants.get(&scalar) {
return id;
}
@@ -1258,7 +1259,7 @@ impl Writer {
ir_module: &crate::Module,
mod_info: &ModuleInfo,
) -> Result<Word, Error> {
- let id = match ir_module.const_expressions[handle] {
+ let id = match ir_module.global_expressions[handle] {
crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
crate::Expression::Constant(constant) => {
let constant = &ir_module.constants[constant];
@@ -1272,7 +1273,7 @@ impl Writer {
let component_ids: Vec<_> = crate::proc::flatten_compose(
ty,
components,
- &ir_module.const_expressions,
+ &ir_module.global_expressions,
&ir_module.types,
)
.map(|component| self.constant_ids[component.index()])
@@ -1310,7 +1311,11 @@ impl Writer {
spirv::MemorySemantics::WORKGROUP_MEMORY,
flags.contains(crate::Barrier::WORK_GROUP),
);
- let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
+ let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) {
+ self.get_index_constant(spirv::Scope::Subgroup as u32)
+ } else {
+ self.get_index_constant(spirv::Scope::Workgroup as u32)
+ };
let mem_scope_id = self.get_index_constant(memory_scope as u32);
let semantics_id = self.get_index_constant(semantics.bits());
block.body.push(Instruction::control_barrier(
@@ -1585,6 +1590,41 @@ impl Writer {
Bi::WorkGroupId => BuiltIn::WorkgroupId,
Bi::WorkGroupSize => BuiltIn::WorkgroupSize,
Bi::NumWorkGroups => BuiltIn::NumWorkgroups,
+ // Subgroup
+ Bi::NumSubgroups => {
+ self.require_any(
+ "`num_subgroups` built-in",
+ &[spirv::Capability::GroupNonUniform],
+ )?;
+ BuiltIn::NumSubgroups
+ }
+ Bi::SubgroupId => {
+ self.require_any(
+ "`subgroup_id` built-in",
+ &[spirv::Capability::GroupNonUniform],
+ )?;
+ BuiltIn::SubgroupId
+ }
+ Bi::SubgroupSize => {
+ self.require_any(
+ "`subgroup_size` built-in",
+ &[
+ spirv::Capability::GroupNonUniform,
+ spirv::Capability::SubgroupBallotKHR,
+ ],
+ )?;
+ BuiltIn::SubgroupSize
+ }
+ Bi::SubgroupInvocationId => {
+ self.require_any(
+ "`subgroup_invocation_id` built-in",
+ &[
+ spirv::Capability::GroupNonUniform,
+ spirv::Capability::SubgroupBallotKHR,
+ ],
+ )?;
+ BuiltIn::SubgroupLocalInvocationId
+ }
};
self.decorate(id, Decoration::BuiltIn, &[built_in as u32]);
@@ -1899,7 +1939,7 @@ impl Writer {
source_code: debug_info.source_code,
source_file_id,
});
- self.debugs.push(Instruction::source(
+ self.debugs.append(&mut Instruction::source_auto_continued(
spirv::SourceLanguage::Unknown,
0,
&debug_info_inner,
@@ -1914,8 +1954,8 @@ impl Writer {
// write all const-expressions as constants
self.constant_ids
- .resize(ir_module.const_expressions.len(), 0);
- for (handle, _) in ir_module.const_expressions.iter() {
+ .resize(ir_module.global_expressions.len(), 0);
+ for (handle, _) in ir_module.global_expressions.iter() {
self.write_constant_expr(handle, ir_module, mod_info)?;
}
debug_assert!(self.constant_ids.iter().all(|&id| id != 0));
@@ -2029,6 +2069,10 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
+ if !ir_module.overrides.is_empty() {
+ return Err(Error::Override);
+ }
+
self.reset();
// Try to find the entry point and corresponding index
diff --git a/third_party/rust/naga/src/back/wgsl/writer.rs b/third_party/rust/naga/src/back/wgsl/writer.rs
index 3039cbbbe4..789f6f62bf 100644
--- a/third_party/rust/naga/src/back/wgsl/writer.rs
+++ b/third_party/rust/naga/src/back/wgsl/writer.rs
@@ -106,6 +106,12 @@ impl<W: Write> Writer<W> {
}
pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
+ if !module.overrides.is_empty() {
+ return Err(Error::Unimplemented(
+ "Pipeline constants are not yet supported for this back-end".to_string(),
+ ));
+ }
+
self.reset(module);
// Save all ep result types
@@ -918,8 +924,124 @@ impl<W: Write> Writer<W> {
if barrier.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}workgroupBarrier();")?;
}
+
+ if barrier.contains(crate::Barrier::SUB_GROUP) {
+ writeln!(self.out, "{level}subgroupBarrier();")?;
+ }
}
Statement::RayQuery { .. } => unreachable!(),
+ Statement::SubgroupBallot { result, predicate } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ write!(self.out, "subgroupBallot(")?;
+ if let Some(predicate) = predicate {
+ self.write_expr(module, predicate, func_ctx)?;
+ }
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ match (collective_op, op) {
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
+ write!(self.out, "subgroupAll(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
+ write!(self.out, "subgroupAny(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupAdd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupMul(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
+ write!(self.out, "subgroupMax(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
+ write!(self.out, "subgroupMin(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
+ write!(self.out, "subgroupAnd(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
+ write!(self.out, "subgroupOr(")?
+ }
+ (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
+ write!(self.out, "subgroupXor(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupExclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupExclusiveMul(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
+ write!(self.out, "subgroupInclusiveAdd(")?
+ }
+ (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
+ write!(self.out, "subgroupInclusiveMul(")?
+ }
+ _ => unimplemented!(),
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ writeln!(self.out, ");")?;
+ }
+ Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ write!(self.out, "{level}")?;
+ let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
+ self.start_named_expr(module, result, func_ctx, &res_name)?;
+ self.named_expressions.insert(result, res_name);
+
+ match mode {
+ crate::GatherMode::BroadcastFirst => {
+ write!(self.out, "subgroupBroadcastFirst(")?;
+ }
+ crate::GatherMode::Broadcast(_) => {
+ write!(self.out, "subgroupBroadcast(")?;
+ }
+ crate::GatherMode::Shuffle(_) => {
+ write!(self.out, "subgroupShuffle(")?;
+ }
+ crate::GatherMode::ShuffleDown(_) => {
+ write!(self.out, "subgroupShuffleDown(")?;
+ }
+ crate::GatherMode::ShuffleUp(_) => {
+ write!(self.out, "subgroupShuffleUp(")?;
+ }
+ crate::GatherMode::ShuffleXor(_) => {
+ write!(self.out, "subgroupShuffleXor(")?;
+ }
+ }
+ self.write_expr(module, argument, func_ctx)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ write!(self.out, ", ")?;
+ self.write_expr(module, index, func_ctx)?;
+ }
+ }
+ writeln!(self.out, ");")?;
+ }
}
Ok(())
@@ -1070,7 +1192,7 @@ impl<W: Write> Writer<W> {
self.write_possibly_const_expression(
module,
expr,
- &module.const_expressions,
+ &module.global_expressions,
|writer, expr| writer.write_const_expression(module, expr),
)
}
@@ -1199,6 +1321,7 @@ impl<W: Write> Writer<W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
+ Expression::Override(_) => unreachable!(),
Expression::FunctionArgument(pos) => {
let name_key = func_ctx.argument_key(pos);
let name = &self.names[&name_key];
@@ -1691,6 +1814,8 @@ impl<W: Write> Writer<W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
+ | Expression::SubgroupBallotResult
+ | Expression::SubgroupOperationResult { .. }
| Expression::WorkGroupUniformLoadResult { .. } => {}
}
@@ -1792,6 +1917,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
Bi::SampleMask => "sample_mask",
Bi::PrimitiveIndex => "primitive_index",
Bi::ViewIndex => "view_index",
+ Bi::NumSubgroups => "num_subgroups",
+ Bi::SubgroupId => "subgroup_id",
+ Bi::SubgroupSize => "subgroup_size",
+ Bi::SubgroupInvocationId => "subgroup_invocation_id",
Bi::BaseInstance
| Bi::BaseVertex
| Bi::ClipDistance
diff --git a/third_party/rust/naga/src/block.rs b/third_party/rust/naga/src/block.rs
index 0abda9da7c..2e86a928f1 100644
--- a/third_party/rust/naga/src/block.rs
+++ b/third_party/rust/naga/src/block.rs
@@ -65,6 +65,12 @@ impl Block {
self.span_info.splice(range.clone(), other.span_info);
self.body.splice(range, other.body);
}
+
+ pub fn span_into_iter(self) -> impl Iterator<Item = (Statement, Span)> {
+ let Block { body, span_info } = self;
+ body.into_iter().zip(span_info)
+ }
+
pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> {
let span_iter = self.span_info.iter();
self.body.iter().zip(span_iter)
diff --git a/third_party/rust/naga/src/compact/expressions.rs b/third_party/rust/naga/src/compact/expressions.rs
index 301bbe3240..a418bde301 100644
--- a/third_party/rust/naga/src/compact/expressions.rs
+++ b/third_party/rust/naga/src/compact/expressions.rs
@@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle};
pub struct ExpressionTracer<'tracer> {
pub constants: &'tracer Arena<crate::Constant>,
+ pub overrides: &'tracer Arena<crate::Override>,
/// The arena in which we are currently tracing expressions.
pub expressions: &'tracer Arena<crate::Expression>,
@@ -20,11 +21,11 @@ pub struct ExpressionTracer<'tracer> {
/// the module's constant expression arena.
pub expressions_used: &'tracer mut HandleSet<crate::Expression>,
- /// The used set for the module's `const_expressions` arena.
+ /// The used set for the module's `global_expressions` arena.
///
/// If `None`, we are already tracing the constant expressions,
/// and `expressions_used` already refers to their handle set.
- pub const_expressions_used: Option<&'tracer mut HandleSet<crate::Expression>>,
+ pub global_expressions_used: Option<&'tracer mut HandleSet<crate::Expression>>,
}
impl<'tracer> ExpressionTracer<'tracer> {
@@ -39,11 +40,11 @@ impl<'tracer> ExpressionTracer<'tracer> {
/// marked.
///
/// [fe]: crate::Function::expressions
- /// [ce]: crate::Module::const_expressions
+ /// [ce]: crate::Module::global_expressions
pub fn trace_expressions(&mut self) {
log::trace!(
"entering trace_expression of {}",
- if self.const_expressions_used.is_some() {
+ if self.global_expressions_used.is_some() {
"function expressions"
} else {
"const expressions"
@@ -71,6 +72,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
| Ex::GlobalVariable(_)
| Ex::LocalVariable(_)
| Ex::CallResult(_)
+ | Ex::SubgroupBallotResult
| Ex::RayQueryProceedResult => {}
Ex::Constant(handle) => {
@@ -83,11 +85,16 @@ impl<'tracer> ExpressionTracer<'tracer> {
// and the constant refers to the initializer, it must
// precede `expr` in the arena.
let init = self.constants[handle].init;
- match self.const_expressions_used {
+ match self.global_expressions_used {
Some(ref mut used) => used.insert(init),
None => self.expressions_used.insert(init),
}
}
+ Ex::Override(_) => {
+ // All overrides are considered used by definition. We mark
+ // their types and initialization expressions as used in
+ // `compact::compact`, so we have no more work to do here.
+ }
Ex::ZeroValue(ty) => self.types_used.insert(ty),
Ex::Compose { ty, ref components } => {
self.types_used.insert(ty);
@@ -116,7 +123,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
self.expressions_used
.insert_iter([image, sampler, coordinate]);
self.expressions_used.insert_iter(array_index);
- match self.const_expressions_used {
+ match self.global_expressions_used {
Some(ref mut used) => used.insert_iter(offset),
None => self.expressions_used.insert_iter(offset),
}
@@ -186,6 +193,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty),
Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty),
Ex::ArrayLength(expr) => self.expressions_used.insert(expr),
+ Ex::SubgroupOperationResult { ty } => self.types_used.insert(ty),
Ex::RayQueryGetIntersection {
query,
committed: _,
@@ -217,8 +225,12 @@ impl ModuleMap {
| Ex::GlobalVariable(_)
| Ex::LocalVariable(_)
| Ex::CallResult(_)
+ | Ex::SubgroupBallotResult
| Ex::RayQueryProceedResult => {}
+ // All overrides are retained, so their handles never change.
+ Ex::Override(_) => {}
+
// Expressions that contain handles that need to be adjusted.
Ex::Constant(ref mut constant) => self.constants.adjust(constant),
Ex::ZeroValue(ref mut ty) => self.types.adjust(ty),
@@ -267,7 +279,7 @@ impl ModuleMap {
adjust(coordinate);
operand_map.adjust_option(array_index);
if let Some(ref mut offset) = *offset {
- self.const_expressions.adjust(offset);
+ self.global_expressions.adjust(offset);
}
self.adjust_sample_level(level, operand_map);
operand_map.adjust_option(depth_ref);
@@ -344,6 +356,7 @@ impl ModuleMap {
comparison: _,
} => self.types.adjust(ty),
Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty),
+ Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty),
Ex::ArrayLength(ref mut expr) => adjust(expr),
Ex::RayQueryGetIntersection {
ref mut query,
diff --git a/third_party/rust/naga/src/compact/functions.rs b/third_party/rust/naga/src/compact/functions.rs
index b0d08c7e96..4ac2223eb7 100644
--- a/third_party/rust/naga/src/compact/functions.rs
+++ b/third_party/rust/naga/src/compact/functions.rs
@@ -4,10 +4,11 @@ use super::{FunctionMap, ModuleMap};
pub struct FunctionTracer<'a> {
pub function: &'a crate::Function,
pub constants: &'a crate::Arena<crate::Constant>,
+ pub overrides: &'a crate::Arena<crate::Override>,
pub types_used: &'a mut HandleSet<crate::Type>,
pub constants_used: &'a mut HandleSet<crate::Constant>,
- pub const_expressions_used: &'a mut HandleSet<crate::Expression>,
+ pub global_expressions_used: &'a mut HandleSet<crate::Expression>,
/// Function-local expressions used.
pub expressions_used: HandleSet<crate::Expression>,
@@ -47,12 +48,13 @@ impl<'a> FunctionTracer<'a> {
fn as_expression(&mut self) -> super::expressions::ExpressionTracer {
super::expressions::ExpressionTracer {
constants: self.constants,
+ overrides: self.overrides,
expressions: &self.function.expressions,
types_used: self.types_used,
constants_used: self.constants_used,
expressions_used: &mut self.expressions_used,
- const_expressions_used: Some(&mut self.const_expressions_used),
+ global_expressions_used: Some(&mut self.global_expressions_used),
}
}
}
diff --git a/third_party/rust/naga/src/compact/mod.rs b/third_party/rust/naga/src/compact/mod.rs
index b4e57ed5c9..0d7a37b579 100644
--- a/third_party/rust/naga/src/compact/mod.rs
+++ b/third_party/rust/naga/src/compact/mod.rs
@@ -38,7 +38,7 @@ pub fn compact(module: &mut crate::Module) {
log::trace!("tracing global {:?}", global.name);
module_tracer.types_used.insert(global.ty);
if let Some(init) = global.init {
- module_tracer.const_expressions_used.insert(init);
+ module_tracer.global_expressions_used.insert(init);
}
}
}
@@ -50,7 +50,15 @@ pub fn compact(module: &mut crate::Module) {
for (handle, constant) in module.constants.iter() {
if constant.name.is_some() {
module_tracer.constants_used.insert(handle);
- module_tracer.const_expressions_used.insert(constant.init);
+ module_tracer.global_expressions_used.insert(constant.init);
+ }
+ }
+
+ // We treat all overrides as used by definition.
+ for (_, override_) in module.overrides.iter() {
+ module_tracer.types_used.insert(override_.ty);
+ if let Some(init) = override_.init {
+ module_tracer.global_expressions_used.insert(init);
}
}
@@ -137,9 +145,9 @@ pub fn compact(module: &mut crate::Module) {
// Drop unused constant expressions, reusing existing storage.
log::trace!("adjusting constant expressions");
- module.const_expressions.retain_mut(|handle, expr| {
- if module_map.const_expressions.used(handle) {
- module_map.adjust_expression(expr, &module_map.const_expressions);
+ module.global_expressions.retain_mut(|handle, expr| {
+ if module_map.global_expressions.used(handle) {
+ module_map.adjust_expression(expr, &module_map.global_expressions);
true
} else {
false
@@ -151,20 +159,29 @@ pub fn compact(module: &mut crate::Module) {
module.constants.retain_mut(|handle, constant| {
if module_map.constants.used(handle) {
module_map.types.adjust(&mut constant.ty);
- module_map.const_expressions.adjust(&mut constant.init);
+ module_map.global_expressions.adjust(&mut constant.init);
true
} else {
false
}
});
+ // Adjust override types and initializers.
+ log::trace!("adjusting overrides");
+ for (_, override_) in module.overrides.iter_mut() {
+ module_map.types.adjust(&mut override_.ty);
+ if let Some(init) = override_.init.as_mut() {
+ module_map.global_expressions.adjust(init);
+ }
+ }
+
// Adjust global variables' types and initializers.
log::trace!("adjusting global variables");
for (_, global) in module.global_variables.iter_mut() {
log::trace!("adjusting global {:?}", global.name);
module_map.types.adjust(&mut global.ty);
if let Some(ref mut init) = global.init {
- module_map.const_expressions.adjust(init);
+ module_map.global_expressions.adjust(init);
}
}
@@ -193,7 +210,7 @@ struct ModuleTracer<'module> {
module: &'module crate::Module,
types_used: HandleSet<crate::Type>,
constants_used: HandleSet<crate::Constant>,
- const_expressions_used: HandleSet<crate::Expression>,
+ global_expressions_used: HandleSet<crate::Expression>,
}
impl<'module> ModuleTracer<'module> {
@@ -202,7 +219,7 @@ impl<'module> ModuleTracer<'module> {
module,
types_used: HandleSet::for_arena(&module.types),
constants_used: HandleSet::for_arena(&module.constants),
- const_expressions_used: HandleSet::for_arena(&module.const_expressions),
+ global_expressions_used: HandleSet::for_arena(&module.global_expressions),
}
}
@@ -233,12 +250,13 @@ impl<'module> ModuleTracer<'module> {
fn as_const_expression(&mut self) -> expressions::ExpressionTracer {
expressions::ExpressionTracer {
- expressions: &self.module.const_expressions,
+ expressions: &self.module.global_expressions,
constants: &self.module.constants,
+ overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
- expressions_used: &mut self.const_expressions_used,
- const_expressions_used: None,
+ expressions_used: &mut self.global_expressions_used,
+ global_expressions_used: None,
}
}
@@ -249,9 +267,10 @@ impl<'module> ModuleTracer<'module> {
FunctionTracer {
function,
constants: &self.module.constants,
+ overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
- const_expressions_used: &mut self.const_expressions_used,
+ global_expressions_used: &mut self.global_expressions_used,
expressions_used: HandleSet::for_arena(&function.expressions),
}
}
@@ -260,7 +279,7 @@ impl<'module> ModuleTracer<'module> {
struct ModuleMap {
types: HandleMap<crate::Type>,
constants: HandleMap<crate::Constant>,
- const_expressions: HandleMap<crate::Expression>,
+ global_expressions: HandleMap<crate::Expression>,
}
impl From<ModuleTracer<'_>> for ModuleMap {
@@ -268,7 +287,7 @@ impl From<ModuleTracer<'_>> for ModuleMap {
ModuleMap {
types: HandleMap::from_set(used.types_used),
constants: HandleMap::from_set(used.constants_used),
- const_expressions: HandleMap::from_set(used.const_expressions_used),
+ global_expressions: HandleMap::from_set(used.global_expressions_used),
}
}
}
diff --git a/third_party/rust/naga/src/compact/statements.rs b/third_party/rust/naga/src/compact/statements.rs
index 0698b57258..a124281bc1 100644
--- a/third_party/rust/naga/src/compact/statements.rs
+++ b/third_party/rust/naga/src/compact/statements.rs
@@ -97,6 +97,39 @@ impl FunctionTracer<'_> {
self.expressions_used.insert(query);
self.trace_ray_query_function(fun);
}
+ St::SubgroupBallot { result, predicate } => {
+ if let Some(predicate) = predicate {
+ self.expressions_used.insert(predicate)
+ }
+ self.expressions_used.insert(result)
+ }
+ St::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: _,
+ argument,
+ result,
+ } => {
+ self.expressions_used.insert(argument);
+ self.expressions_used.insert(result)
+ }
+ St::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ self.expressions_used.insert(index)
+ }
+ }
+ self.expressions_used.insert(argument);
+ self.expressions_used.insert(result)
+ }
// Trivial statements.
St::Break
@@ -250,6 +283,40 @@ impl FunctionMap {
adjust(query);
self.adjust_ray_query_function(fun);
}
+ St::SubgroupBallot {
+ ref mut result,
+ ref mut predicate,
+ } => {
+ if let Some(ref mut predicate) = *predicate {
+ adjust(predicate);
+ }
+ adjust(result);
+ }
+ St::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: _,
+ ref mut argument,
+ ref mut result,
+ } => {
+ adjust(argument);
+ adjust(result);
+ }
+ St::SubgroupGather {
+ ref mut mode,
+ ref mut argument,
+ ref mut result,
+ } => {
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(ref mut index)
+ | crate::GatherMode::Shuffle(ref mut index)
+ | crate::GatherMode::ShuffleDown(ref mut index)
+ | crate::GatherMode::ShuffleUp(ref mut index)
+ | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index),
+ }
+ adjust(argument);
+ adjust(result);
+ }
// Trivial statements.
St::Break
diff --git a/third_party/rust/naga/src/error.rs b/third_party/rust/naga/src/error.rs
new file mode 100644
index 0000000000..5f2e28360b
--- /dev/null
+++ b/third_party/rust/naga/src/error.rs
@@ -0,0 +1,74 @@
+use std::{error::Error, fmt};
+
+#[derive(Clone, Debug)]
+pub struct ShaderError<E> {
+ /// The source code of the shader.
+ pub source: String,
+ pub label: Option<String>,
+ pub inner: Box<E>,
+}
+
+#[cfg(feature = "wgsl-in")]
+impl fmt::Display for ShaderError<crate::front::wgsl::ParseError> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let label = self.label.as_deref().unwrap_or_default();
+ let string = self.inner.emit_to_string(&self.source);
+ write!(f, "\nShader '{label}' parsing {string}")
+ }
+}
+#[cfg(feature = "glsl-in")]
+impl fmt::Display for ShaderError<crate::front::glsl::ParseErrors> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let label = self.label.as_deref().unwrap_or_default();
+ let string = self.inner.emit_to_string(&self.source);
+ write!(f, "\nShader '{label}' parsing {string}")
+ }
+}
+#[cfg(feature = "spv-in")]
+impl fmt::Display for ShaderError<crate::front::spv::Error> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let label = self.label.as_deref().unwrap_or_default();
+ let string = self.inner.emit_to_string(&self.source);
+ write!(f, "\nShader '{label}' parsing {string}")
+ }
+}
+impl fmt::Display for ShaderError<crate::WithSpan<crate::valid::ValidationError>> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ use codespan_reporting::{
+ diagnostic::{Diagnostic, Label},
+ files::SimpleFile,
+ term,
+ };
+
+ let label = self.label.as_deref().unwrap_or_default();
+ let files = SimpleFile::new(label, &self.source);
+ let config = term::Config::default();
+ let mut writer = term::termcolor::NoColor::new(Vec::new());
+
+ let diagnostic = Diagnostic::error().with_labels(
+ self.inner
+ .spans()
+ .map(|&(span, ref desc)| {
+ Label::primary((), span.to_range().unwrap()).with_message(desc.to_owned())
+ })
+ .collect(),
+ );
+
+ term::emit(&mut writer, &config, &files, &diagnostic).expect("cannot write error");
+
+ write!(
+ f,
+ "\nShader validation {}",
+ String::from_utf8_lossy(&writer.into_inner())
+ )
+ }
+}
+impl<E> Error for ShaderError<E>
+where
+ ShaderError<E>: fmt::Display,
+ E: Error + 'static,
+{
+ fn source(&self) -> Option<&(dyn Error + 'static)> {
+ Some(&self.inner)
+ }
+}
diff --git a/third_party/rust/naga/src/front/glsl/context.rs b/third_party/rust/naga/src/front/glsl/context.rs
index f26c57965d..6ba7df593a 100644
--- a/third_party/rust/naga/src/front/glsl/context.rs
+++ b/third_party/rust/naga/src/front/glsl/context.rs
@@ -77,12 +77,19 @@ pub struct Context<'a> {
pub body: Block,
pub module: &'a mut crate::Module,
pub is_const: bool,
- /// Tracks the constness of `Expression`s residing in `self.expressions`
- pub expression_constness: crate::proc::ExpressionConstnessTracker,
+ /// Tracks the expression kind of `Expression`s residing in `self.expressions`
+ pub local_expression_kind_tracker: crate::proc::ExpressionKindTracker,
+ /// Tracks the expression kind of `Expression`s residing in `self.module.global_expressions`
+ pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
}
impl<'a> Context<'a> {
- pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> {
+ pub fn new(
+ frontend: &Frontend,
+ module: &'a mut crate::Module,
+ is_const: bool,
+ global_expression_kind_tracker: &'a mut crate::proc::ExpressionKindTracker,
+ ) -> Result<Self> {
let mut this = Context {
expressions: Arena::new(),
locals: Arena::new(),
@@ -101,7 +108,8 @@ impl<'a> Context<'a> {
body: Block::new(),
module,
is_const: false,
- expression_constness: crate::proc::ExpressionConstnessTracker::new(),
+ local_expression_kind_tracker: crate::proc::ExpressionKindTracker::new(),
+ global_expression_kind_tracker,
};
this.emit_start();
@@ -249,40 +257,24 @@ impl<'a> Context<'a> {
pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
let mut eval = if self.is_const {
- crate::proc::ConstantEvaluator::for_glsl_module(self.module)
+ crate::proc::ConstantEvaluator::for_glsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ )
} else {
crate::proc::ConstantEvaluator::for_glsl_function(
self.module,
&mut self.expressions,
- &mut self.expression_constness,
+ &mut self.local_expression_kind_tracker,
&mut self.emitter,
&mut self.body,
)
};
- let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error {
+ eval.try_eval_and_append(expr, meta).map_err(|e| Error {
kind: e.into(),
meta,
- });
-
- match res {
- Ok(expr) => Ok(expr),
- Err(e) => {
- if self.is_const {
- Err(e)
- } else {
- let needs_pre_emit = expr.needs_pre_emit();
- if needs_pre_emit {
- self.body.extend(self.emitter.finish(&self.expressions));
- }
- let h = self.expressions.append(expr, meta);
- if needs_pre_emit {
- self.emitter.start(&self.expressions);
- }
- Ok(h)
- }
- }
- }
+ })
}
/// Add variable to current scope
@@ -1479,7 +1471,7 @@ impl Index<Handle<Expression>> for Context<'_> {
fn index(&self, index: Handle<Expression>) -> &Self::Output {
if self.is_const {
- &self.module.const_expressions[index]
+ &self.module.global_expressions[index]
} else {
&self.expressions[index]
}
diff --git a/third_party/rust/naga/src/front/glsl/error.rs b/third_party/rust/naga/src/front/glsl/error.rs
index bd16ee30bc..e0771437e6 100644
--- a/third_party/rust/naga/src/front/glsl/error.rs
+++ b/third_party/rust/naga/src/front/glsl/error.rs
@@ -1,4 +1,5 @@
use super::token::TokenValue;
+use crate::SourceLocation;
use crate::{proc::ConstantEvaluatorError, Span};
use codespan_reporting::diagnostic::{Diagnostic, Label};
use codespan_reporting::files::SimpleFile;
@@ -137,14 +138,21 @@ pub struct Error {
pub meta: Span,
}
+impl Error {
+ /// Returns a [`SourceLocation`] for the error message.
+ pub fn location(&self, source: &str) -> Option<SourceLocation> {
+ Some(self.meta.location(source))
+ }
+}
+
/// A collection of errors returned during shader parsing.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(PartialEq))]
-pub struct ParseError {
+pub struct ParseErrors {
pub errors: Vec<Error>,
}
-impl ParseError {
+impl ParseErrors {
pub fn emit_to_writer(&self, writer: &mut impl WriteColor, source: &str) {
self.emit_to_writer_with_path(writer, source, "glsl");
}
@@ -172,19 +180,19 @@ impl ParseError {
}
}
-impl std::fmt::Display for ParseError {
+impl std::fmt::Display for ParseErrors {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
self.errors.iter().try_for_each(|e| write!(f, "{e:?}"))
}
}
-impl std::error::Error for ParseError {
+impl std::error::Error for ParseErrors {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
None
}
}
-impl From<Vec<Error>> for ParseError {
+impl From<Vec<Error>> for ParseErrors {
fn from(errors: Vec<Error>) -> Self {
Self { errors }
}
diff --git a/third_party/rust/naga/src/front/glsl/functions.rs b/third_party/rust/naga/src/front/glsl/functions.rs
index 01846eb814..fa1bbef56b 100644
--- a/third_party/rust/naga/src/front/glsl/functions.rs
+++ b/third_party/rust/naga/src/front/glsl/functions.rs
@@ -1236,6 +1236,8 @@ impl Frontend {
let pointer = ctx
.expressions
.append(Expression::GlobalVariable(arg.handle), Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(pointer, crate::proc::ExpressionKind::Runtime);
let ty = ctx.module.global_variables[arg.handle].ty;
@@ -1256,6 +1258,8 @@ impl Frontend {
let value = ctx
.expressions
.append(Expression::FunctionArgument(idx), Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(value, crate::proc::ExpressionKind::Runtime);
ctx.body
.push(Statement::Store { pointer, value }, Default::default());
},
@@ -1285,6 +1289,8 @@ impl Frontend {
let pointer = ctx
.expressions
.append(Expression::GlobalVariable(arg.handle), Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(pointer, crate::proc::ExpressionKind::Runtime);
let ty = ctx.module.global_variables[arg.handle].ty;
@@ -1307,6 +1313,8 @@ impl Frontend {
let load = ctx
.expressions
.append(Expression::Load { pointer }, Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(load, crate::proc::ExpressionKind::Runtime);
ctx.body.push(
Statement::Emit(ctx.expressions.range_from(len)),
Default::default(),
@@ -1329,6 +1337,8 @@ impl Frontend {
let res = ctx
.expressions
.append(Expression::Compose { ty, components }, Default::default());
+ ctx.local_expression_kind_tracker
+ .insert(res, crate::proc::ExpressionKind::Runtime);
ctx.body.push(
Statement::Emit(ctx.expressions.range_from(len)),
Default::default(),
diff --git a/third_party/rust/naga/src/front/glsl/mod.rs b/third_party/rust/naga/src/front/glsl/mod.rs
index 75f3929db4..ea202b2445 100644
--- a/third_party/rust/naga/src/front/glsl/mod.rs
+++ b/third_party/rust/naga/src/front/glsl/mod.rs
@@ -13,7 +13,7 @@ To begin, take a look at the documentation for the [`Frontend`].
*/
pub use ast::{Precision, Profile};
-pub use error::{Error, ErrorKind, ExpectedToken, ParseError};
+pub use error::{Error, ErrorKind, ExpectedToken, ParseErrors};
pub use token::TokenValue;
use crate::{proc::Layouter, FastHashMap, FastHashSet, Handle, Module, ShaderStage, Span, Type};
@@ -196,7 +196,7 @@ impl Frontend {
&mut self,
options: &Options,
source: &str,
- ) -> std::result::Result<Module, ParseError> {
+ ) -> std::result::Result<Module, ParseErrors> {
self.reset(options.stage);
let lexer = lex::Lexer::new(source, &options.defines);
diff --git a/third_party/rust/naga/src/front/glsl/parser.rs b/third_party/rust/naga/src/front/glsl/parser.rs
index 851d2e1d79..28e0808063 100644
--- a/third_party/rust/naga/src/front/glsl/parser.rs
+++ b/third_party/rust/naga/src/front/glsl/parser.rs
@@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> {
pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> {
let mut module = Module::default();
+ let mut global_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
// Body and expression arena for global initialization
- let mut ctx = Context::new(frontend, &mut module, false)?;
+ let mut ctx = Context::new(
+ frontend,
+ &mut module,
+ false,
+ &mut global_expression_kind_tracker,
+ )?;
while self.peek(frontend).is_some() {
self.parse_external_declaration(frontend, &mut ctx)?;
@@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> {
frontend: &mut Frontend,
ctx: &mut Context,
) -> Result<(u32, Span)> {
- let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?;
+ let (const_expr, meta) = self.parse_constant_expression(
+ frontend,
+ ctx.module,
+ ctx.global_expression_kind_tracker,
+ )?;
let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr);
@@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> {
&mut self,
frontend: &mut Frontend,
module: &mut Module,
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
) -> Result<(Handle<Expression>, Span)> {
- let mut ctx = Context::new(frontend, module, true)?;
+ let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?;
let mut stmt_ctx = ctx.stmt_ctx();
let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?;
diff --git a/third_party/rust/naga/src/front/glsl/parser/declarations.rs b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
index f5e38fb016..2d253a378d 100644
--- a/third_party/rust/naga/src/front/glsl/parser/declarations.rs
+++ b/third_party/rust/naga/src/front/glsl/parser/declarations.rs
@@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> {
init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok());
late_initializer = None;
} else if let Some(init) = init {
- if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) {
+ if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) {
decl_initializer = None;
late_initializer = Some(init);
} else {
@@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> {
let result = ty.map(|ty| FunctionResult { ty, binding: None });
- let mut context = Context::new(frontend, ctx.module, false)?;
+ let mut context = Context::new(
+ frontend,
+ ctx.module,
+ false,
+ ctx.global_expression_kind_tracker,
+ )?;
self.parse_function_args(frontend, &mut context)?;
diff --git a/third_party/rust/naga/src/front/glsl/parser/functions.rs b/third_party/rust/naga/src/front/glsl/parser/functions.rs
index d428d74761..d0c889e4d3 100644
--- a/third_party/rust/naga/src/front/glsl/parser/functions.rs
+++ b/third_party/rust/naga/src/front/glsl/parser/functions.rs
@@ -192,10 +192,13 @@ impl<'source> ParsingContext<'source> {
TokenValue::Case => {
self.bump(frontend)?;
- let (const_expr, meta) =
- self.parse_constant_expression(frontend, ctx.module)?;
+ let (const_expr, meta) = self.parse_constant_expression(
+ frontend,
+ ctx.module,
+ ctx.global_expression_kind_tracker,
+ )?;
- match ctx.module.const_expressions[const_expr] {
+ match ctx.module.global_expressions[const_expr] {
Expression::Literal(Literal::I32(value)) => match uint {
// This unchecked cast isn't good, but since
// we only reach this code when the selector
diff --git a/third_party/rust/naga/src/front/glsl/parser_tests.rs b/third_party/rust/naga/src/front/glsl/parser_tests.rs
index 259052cd27..135765ca58 100644
--- a/third_party/rust/naga/src/front/glsl/parser_tests.rs
+++ b/third_party/rust/naga/src/front/glsl/parser_tests.rs
@@ -1,7 +1,7 @@
use super::{
ast::Profile,
error::ExpectedToken,
- error::{Error, ErrorKind, ParseError},
+ error::{Error, ErrorKind, ParseErrors},
token::TokenValue,
Frontend, Options, Span,
};
@@ -21,7 +21,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::InvalidVersion(99000),
meta: Span::new(9, 14)
@@ -37,7 +37,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::InvalidVersion(449),
meta: Span::new(9, 12)
@@ -53,7 +53,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::InvalidProfile("smart".into()),
meta: Span::new(13, 18),
@@ -69,7 +69,7 @@ fn version() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![
Error {
kind: ErrorKind::PreprocessorError(PreprocessorError::UnexpectedHash,),
@@ -455,7 +455,7 @@ fn functions() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::SemanticError("Function already defined".into()),
meta: Span::new(134, 152),
@@ -539,7 +539,7 @@ fn constants() {
let mut types = module.types.iter();
let mut constants = module.constants.iter();
- let mut const_expressions = module.const_expressions.iter();
+ let mut global_expressions = module.global_expressions.iter();
let (ty_handle, ty) = types.next().unwrap();
assert_eq!(
@@ -550,14 +550,13 @@ fn constants() {
}
);
- let (init_handle, init) = const_expressions.next().unwrap();
+ let (init_handle, init) = global_expressions.next().unwrap();
assert_eq!(init, &Expression::Literal(crate::Literal::F32(1.0)));
assert_eq!(
constants.next().unwrap().1,
&Constant {
name: Some("a".to_owned()),
- r#override: crate::Override::None,
ty: ty_handle,
init: init_handle
}
@@ -567,7 +566,6 @@ fn constants() {
constants.next().unwrap().1,
&Constant {
name: Some("b".to_owned()),
- r#override: crate::Override::None,
ty: ty_handle,
init: init_handle
}
@@ -636,7 +634,7 @@ fn implicit_conversions() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::SemanticError("Unknown function \'test\'".into()),
meta: Span::new(156, 165),
@@ -660,7 +658,7 @@ fn implicit_conversions() {
)
.err()
.unwrap(),
- ParseError {
+ ParseErrors {
errors: vec![Error {
kind: ErrorKind::SemanticError("Ambiguous best function for \'test\'".into()),
meta: Span::new(158, 165),
diff --git a/third_party/rust/naga/src/front/glsl/types.rs b/third_party/rust/naga/src/front/glsl/types.rs
index e87d76fffc..f6836169c0 100644
--- a/third_party/rust/naga/src/front/glsl/types.rs
+++ b/third_party/rust/naga/src/front/glsl/types.rs
@@ -233,7 +233,7 @@ impl Context<'_> {
};
let expressions = if self.is_const {
- &self.module.const_expressions
+ &self.module.global_expressions
} else {
&self.expressions
};
@@ -330,23 +330,25 @@ impl Context<'_> {
expr: Handle<Expression>,
) -> Result<Handle<Expression>> {
let meta = self.expressions.get_span(expr);
- Ok(match self.expressions[expr] {
+ let h = match self.expressions[expr] {
ref expr @ (Expression::Literal(_)
| Expression::Constant(_)
- | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta),
+ | Expression::ZeroValue(_)) => {
+ self.module.global_expressions.append(expr.clone(), meta)
+ }
Expression::Compose { ty, ref components } => {
let mut components = components.clone();
for component in &mut components {
*component = self.lift_up_const_expression(*component)?;
}
self.module
- .const_expressions
+ .global_expressions
.append(Expression::Compose { ty, components }, meta)
}
Expression::Splat { size, value } => {
let value = self.lift_up_const_expression(value)?;
self.module
- .const_expressions
+ .global_expressions
.append(Expression::Splat { size, value }, meta)
}
_ => {
@@ -355,6 +357,9 @@ impl Context<'_> {
meta,
})
}
- })
+ };
+ self.global_expression_kind_tracker
+ .insert(h, crate::proc::ExpressionKind::Const);
+ Ok(h)
}
}
diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs
index 9d2e7a0e7b..0725fbd94f 100644
--- a/third_party/rust/naga/src/front/glsl/variables.rs
+++ b/third_party/rust/naga/src/front/glsl/variables.rs
@@ -472,7 +472,6 @@ impl Frontend {
let constant = Constant {
name: name.clone(),
- r#override: crate::Override::None,
ty,
init,
};
diff --git a/third_party/rust/naga/src/front/spv/convert.rs b/third_party/rust/naga/src/front/spv/convert.rs
index f0a714fbeb..a6bf0e0451 100644
--- a/third_party/rust/naga/src/front/spv/convert.rs
+++ b/third_party/rust/naga/src/front/spv/convert.rs
@@ -153,6 +153,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result<crate::B
Some(Bi::WorkgroupId) => crate::BuiltIn::WorkGroupId,
Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize,
Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups,
+ // subgroup
+ Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups,
+ Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId,
+ Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize,
+ Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId,
_ => return Err(Error::UnsupportedBuiltIn(word)),
})
}
diff --git a/third_party/rust/naga/src/front/spv/error.rs b/third_party/rust/naga/src/front/spv/error.rs
index af025636c0..44beadce98 100644
--- a/third_party/rust/naga/src/front/spv/error.rs
+++ b/third_party/rust/naga/src/front/spv/error.rs
@@ -5,7 +5,7 @@ use codespan_reporting::files::SimpleFile;
use codespan_reporting::term;
use termcolor::{NoColor, WriteColor};
-#[derive(Debug, thiserror::Error)]
+#[derive(Clone, Debug, thiserror::Error)]
pub enum Error {
#[error("invalid header")]
InvalidHeader,
@@ -58,6 +58,8 @@ pub enum Error {
UnknownBinaryOperator(spirv::Op),
#[error("unknown relational function {0:?}")]
UnknownRelationalFunction(spirv::Op),
+ #[error("unsupported group operation %{0}")]
+ UnsupportedGroupOperation(spirv::Word),
#[error("invalid parameter {0:?}")]
InvalidParameter(spirv::Op),
#[error("invalid operand count {1} for {0:?}")]
@@ -118,8 +120,8 @@ pub enum Error {
ControlFlowGraphCycle(crate::front::spv::BlockId),
#[error("recursive function call %{0}")]
FunctionCallCycle(spirv::Word),
- #[error("invalid array size {0:?}")]
- InvalidArraySize(Handle<crate::Constant>),
+ #[error("invalid array size %{0}")]
+ InvalidArraySize(spirv::Word),
#[error("invalid barrier scope %{0}")]
InvalidBarrierScope(spirv::Word),
#[error("invalid barrier memory semantics %{0}")]
@@ -130,6 +132,8 @@ pub enum Error {
come from a binding)"
)]
NonBindingArrayOfImageOrSamplers,
+ #[error("naga only supports specialization constant IDs up to 65535 but was given {0}")]
+ SpecIdTooHigh(u32),
}
impl Error {
diff --git a/third_party/rust/naga/src/front/spv/function.rs b/third_party/rust/naga/src/front/spv/function.rs
index e81ecf5c9b..113ca56313 100644
--- a/third_party/rust/naga/src/front/spv/function.rs
+++ b/third_party/rust/naga/src/front/spv/function.rs
@@ -59,8 +59,11 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
})
},
local_variables: Arena::new(),
- expressions: self
- .make_expression_storage(&module.global_variables, &module.constants),
+ expressions: self.make_expression_storage(
+ &module.global_variables,
+ &module.constants,
+ &module.overrides,
+ ),
named_expressions: crate::NamedExpressions::default(),
body: crate::Block::new(),
}
@@ -128,7 +131,8 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
expressions: &mut fun.expressions,
local_arena: &mut fun.local_variables,
const_arena: &mut module.constants,
- const_expressions: &mut module.const_expressions,
+ overrides: &mut module.overrides,
+ global_expressions: &mut module.global_expressions,
type_arena: &module.types,
global_arena: &module.global_variables,
arguments: &fun.arguments,
@@ -581,7 +585,8 @@ impl<'function> BlockContext<'function> {
crate::proc::GlobalCtx {
types: self.type_arena,
constants: self.const_arena,
- const_expressions: self.const_expressions,
+ overrides: self.overrides,
+ global_expressions: self.global_expressions,
}
}
diff --git a/third_party/rust/naga/src/front/spv/image.rs b/third_party/rust/naga/src/front/spv/image.rs
index 0f25dd626b..284c4cf7fd 100644
--- a/third_party/rust/naga/src/front/spv/image.rs
+++ b/third_party/rust/naga/src/front/spv/image.rs
@@ -507,11 +507,14 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
}
spirv::ImageOperands::CONST_OFFSET => {
let offset_constant = self.next()?;
- let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle;
- let offset_handle = ctx.const_expressions.append(
- crate::Expression::Constant(offset_handle),
- Default::default(),
- );
+ let offset_expr = self
+ .lookup_constant
+ .lookup(offset_constant)?
+ .inner
+ .to_expr();
+ let offset_handle = ctx
+ .global_expressions
+ .append(offset_expr, Default::default());
offset = Some(offset_handle);
words_left -= 1;
}
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs
index b793448597..7ac5a18cd6 100644
--- a/third_party/rust/naga/src/front/spv/mod.rs
+++ b/third_party/rust/naga/src/front/spv/mod.rs
@@ -196,7 +196,7 @@ struct Decoration {
location: Option<spirv::Word>,
desc_set: Option<spirv::Word>,
desc_index: Option<spirv::Word>,
- specialization: Option<spirv::Word>,
+ specialization_constant_id: Option<spirv::Word>,
storage_buffer: bool,
offset: Option<spirv::Word>,
array_stride: Option<NonZeroU32>,
@@ -216,11 +216,6 @@ impl Decoration {
}
}
- fn specialization(&self) -> crate::Override {
- self.specialization
- .map_or(crate::Override::None, crate::Override::ByNameOrId)
- }
-
const fn resource_binding(&self) -> Option<crate::ResourceBinding> {
match *self {
Decoration {
@@ -284,8 +279,23 @@ struct LookupType {
}
#[derive(Debug)]
+enum Constant {
+ Constant(Handle<crate::Constant>),
+ Override(Handle<crate::Override>),
+}
+
+impl Constant {
+ const fn to_expr(&self) -> crate::Expression {
+ match *self {
+ Self::Constant(c) => crate::Expression::Constant(c),
+ Self::Override(o) => crate::Expression::Override(o),
+ }
+ }
+}
+
+#[derive(Debug)]
struct LookupConstant {
- handle: Handle<crate::Constant>,
+ inner: Constant,
type_id: spirv::Word,
}
@@ -537,7 +547,8 @@ struct BlockContext<'function> {
local_arena: &'function mut Arena<crate::LocalVariable>,
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
- const_expressions: &'function mut Arena<crate::Expression>,
+ overrides: &'function mut Arena<crate::Override>,
+ global_expressions: &'function mut Arena<crate::Expression>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
/// Global arena of the module being processed
@@ -757,7 +768,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
dec.matrix_major = Some(Majority::Row);
}
spirv::Decoration::SpecId => {
- dec.specialization = Some(self.next()?);
+ dec.specialization_constant_id = Some(self.next()?);
}
other => {
log::warn!("Unknown decoration {:?}", other);
@@ -1393,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
inst.expect(5)?;
let init_id = self.next()?;
let lconst = self.lookup_constant.lookup(init_id)?;
- Some(
- ctx.expressions
- .append(crate::Expression::Constant(lconst.handle), span),
- )
+ Some(ctx.expressions.append(lconst.inner.to_expr(), span))
} else {
None
};
@@ -3650,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let semantics_const = self.lookup_constant.lookup(semantics_id)?;
- let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle)
+ let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
- let semantics = resolve_constant(ctx.gctx(), semantics_const.handle)
+ let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;
if exec_scope == spirv::Scope::Workgroup as u32 {
@@ -3692,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
},
);
}
+ Op::GroupNonUniformBallot => {
+ inst.expect(5)?;
+ block.extend(emitter.finish(ctx.expressions));
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let predicate_id = self.next()?;
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let predicate = if self
+ .lookup_constant
+ .lookup(predicate_id)
+ .ok()
+ .filter(|predicate_const| match predicate_const.inner {
+ Constant::Constant(constant) => matches!(
+ ctx.gctx().global_expressions[ctx.gctx().constants[constant].init],
+ crate::Expression::Literal(crate::Literal::Bool(true)),
+ ),
+ Constant::Override(_) => false,
+ })
+ .is_some()
+ {
+ None
+ } else {
+ let predicate_lookup = self.lookup_expression.lookup(predicate_id)?;
+ let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup);
+ Some(predicate_handle)
+ };
+
+ let result_handle = ctx
+ .expressions
+ .append(crate::Expression::SubgroupBallotResult, span);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupBallot {
+ result: result_handle,
+ predicate,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ spirv::Op::GroupNonUniformAll
+ | spirv::Op::GroupNonUniformAny
+ | spirv::Op::GroupNonUniformIAdd
+ | spirv::Op::GroupNonUniformFAdd
+ | spirv::Op::GroupNonUniformIMul
+ | spirv::Op::GroupNonUniformFMul
+ | spirv::Op::GroupNonUniformSMax
+ | spirv::Op::GroupNonUniformUMax
+ | spirv::Op::GroupNonUniformFMax
+ | spirv::Op::GroupNonUniformSMin
+ | spirv::Op::GroupNonUniformUMin
+ | spirv::Op::GroupNonUniformFMin
+ | spirv::Op::GroupNonUniformBitwiseAnd
+ | spirv::Op::GroupNonUniformBitwiseOr
+ | spirv::Op::GroupNonUniformBitwiseXor
+ | spirv::Op::GroupNonUniformLogicalAnd
+ | spirv::Op::GroupNonUniformLogicalOr
+ | spirv::Op::GroupNonUniformLogicalXor => {
+ block.extend(emitter.finish(ctx.expressions));
+ inst.expect(
+ if matches!(
+ inst.op,
+ spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny
+ ) {
+ 5
+ } else {
+ 6
+ },
+ )?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let collective_op_id = match inst.op {
+ spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => {
+ crate::CollectiveOperation::Reduce
+ }
+ _ => {
+ let group_op_id = self.next()?;
+ match spirv::GroupOperation::from_u32(group_op_id) {
+ Some(spirv::GroupOperation::Reduce) => {
+ crate::CollectiveOperation::Reduce
+ }
+ Some(spirv::GroupOperation::InclusiveScan) => {
+ crate::CollectiveOperation::InclusiveScan
+ }
+ Some(spirv::GroupOperation::ExclusiveScan) => {
+ crate::CollectiveOperation::ExclusiveScan
+ }
+ _ => return Err(Error::UnsupportedGroupOperation(group_op_id)),
+ }
+ }
+ };
+ let argument_id = self.next()?;
+
+ let argument_lookup = self.lookup_expression.lookup(argument_id)?;
+ let argument_handle = get_expr_handle!(argument_id, argument_lookup);
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let op_id = match inst.op {
+ spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All,
+ spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any,
+ spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => {
+ crate::SubgroupOperation::Add
+ }
+ spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => {
+ crate::SubgroupOperation::Mul
+ }
+ spirv::Op::GroupNonUniformSMax
+ | spirv::Op::GroupNonUniformUMax
+ | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max,
+ spirv::Op::GroupNonUniformSMin
+ | spirv::Op::GroupNonUniformUMin
+ | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min,
+ spirv::Op::GroupNonUniformBitwiseAnd
+ | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And,
+ spirv::Op::GroupNonUniformBitwiseOr
+ | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or,
+ spirv::Op::GroupNonUniformBitwiseXor
+ | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor,
+ _ => unreachable!(),
+ };
+
+ let result_type = self.lookup_type.lookup(result_type_id)?;
+
+ let result_handle = ctx.expressions.append(
+ crate::Expression::SubgroupOperationResult {
+ ty: result_type.handle,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupCollectiveOperation {
+ result: result_handle,
+ op: op_id,
+ collective_op: collective_op_id,
+ argument: argument_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::GroupNonUniformBroadcastFirst
+ | Op::GroupNonUniformBroadcast
+ | Op::GroupNonUniformShuffle
+ | Op::GroupNonUniformShuffleDown
+ | Op::GroupNonUniformShuffleUp
+ | Op::GroupNonUniformShuffleXor => {
+ inst.expect(
+ if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
+ 5
+ } else {
+ 6
+ },
+ )?;
+ block.extend(emitter.finish(ctx.expressions));
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let argument_id = self.next()?;
+
+ let argument_lookup = self.lookup_expression.lookup(argument_id)?;
+ let argument_handle = get_expr_handle!(argument_id, argument_lookup);
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
+ crate::GatherMode::BroadcastFirst
+ } else {
+ let index_id = self.next()?;
+ let index_lookup = self.lookup_expression.lookup(index_id)?;
+ let index_handle = get_expr_handle!(index_id, index_lookup);
+ match inst.op {
+ spirv::Op::GroupNonUniformBroadcast => {
+ crate::GatherMode::Broadcast(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffle => {
+ crate::GatherMode::Shuffle(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleDown => {
+ crate::GatherMode::ShuffleDown(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleUp => {
+ crate::GatherMode::ShuffleUp(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleXor => {
+ crate::GatherMode::ShuffleXor(index_handle)
+ }
+ _ => unreachable!(),
+ }
+ };
+
+ let result_type = self.lookup_type.lookup(result_type_id)?;
+
+ let result_handle = ctx.expressions.append(
+ crate::Expression::SubgroupOperationResult {
+ ty: result_type.handle,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupGather {
+ result: result_handle,
+ mode,
+ argument: argument_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
}
};
@@ -3713,6 +3969,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
&mut self,
globals: &Arena<crate::GlobalVariable>,
constants: &Arena<crate::Constant>,
+ overrides: &Arena<crate::Override>,
) -> Arena<crate::Expression> {
let mut expressions = Arena::new();
#[allow(clippy::panic)]
@@ -3737,8 +3994,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
// register constants
for (&id, con) in self.lookup_constant.iter() {
- let span = constants.get_span(con.handle);
- let handle = expressions.append(crate::Expression::Constant(con.handle), span);
+ let (expr, span) = match con.inner {
+ Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)),
+ Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)),
+ };
+ let handle = expressions.append(expr, span);
self.lookup_expression.insert(
id,
LookupExpression {
@@ -3812,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
| S::Store { .. }
| S::ImageStore { .. }
| S::Atomic { .. }
- | S::RayQuery { .. } => {}
+ | S::RayQuery { .. }
+ | S::SubgroupBallot { .. }
+ | S::SubgroupCollectiveOperation { .. }
+ | S::SubgroupGather { .. } => {}
S::Call {
function: ref mut callee,
ref arguments,
@@ -3944,10 +4207,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
- Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
+ Op::ConstantComposite | Op::SpecConstantComposite => {
+ self.parse_composite_constant(inst, &mut module)
+ }
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
- Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
- Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
+ Op::ConstantTrue | Op::SpecConstantTrue => {
+ self.parse_bool_constant(inst, true, &mut module)
+ }
+ Op::ConstantFalse | Op::SpecConstantFalse => {
+ self.parse_bool_constant(inst, false, &mut module)
+ }
Op::Variable => self.parse_global_variable(inst, &mut module),
Op::Function => {
self.switch(ModuleState::Function, inst.op)?;
@@ -4504,9 +4773,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let length_id = self.next()?;
let length_const = self.lookup_constant.lookup(length_id)?;
- let size = resolve_constant(module.to_ctx(), length_const.handle)
+ let size = resolve_constant(module.to_ctx(), &length_const.inner)
.and_then(NonZeroU32::new)
- .ok_or(Error::InvalidArraySize(length_const.handle))?;
+ .ok_or(Error::InvalidArraySize(length_id))?;
let decor = self.future_decor.remove(&id).unwrap_or_default();
let base = self.lookup_type.lookup(type_id)?.handle;
@@ -4919,29 +5188,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
_ => return Err(Error::UnsupportedType(type_lookup.handle)),
};
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let span = self.span_from_with_op(start);
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Literal(literal), span);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_composite_constant(
@@ -4965,34 +5218,18 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let span = self.span_from_with_op(start);
let constant = self.lookup_constant.lookup(component_id)?;
let expr = module
- .const_expressions
- .append(crate::Expression::Constant(constant.handle), span);
+ .global_expressions
+ .append(constant.inner.to_expr(), span);
components.push(expr);
}
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let span = self.span_from_with_op(start);
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Compose { ty, components }, span);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_null_constant(
@@ -5010,23 +5247,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::ZeroValue(ty), span);
- let handle = module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- );
- self.lookup_constant
- .insert(id, LookupConstant { handle, type_id });
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_bool_constant(
@@ -5045,27 +5270,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
- let init = module.const_expressions.append(
+ let init = module.global_expressions.append(
crate::Expression::Literal(crate::Literal::Bool(value)),
span,
);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
+ }
+
+ fn insert_parsed_constant(
+ &mut self,
+ module: &mut crate::Module,
+ id: u32,
+ type_id: u32,
+ ty: Handle<crate::Type>,
+ init: Handle<crate::Expression>,
+ span: crate::Span,
+ ) -> Result<(), Error> {
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let inner = if let Some(id) = decor.specialization_constant_id {
+ let o = crate::Override {
+ name: decor.name,
+ id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?),
+ ty,
+ init: Some(init),
+ };
+ Constant::Override(module.overrides.append(o, span))
+ } else {
+ let c = crate::Constant {
+ name: decor.name,
+ ty,
+ init,
+ };
+ Constant::Constant(module.constants.append(c, span))
+ };
+
+ self.lookup_constant
+ .insert(id, LookupConstant { inner, type_id });
Ok(())
}
@@ -5087,8 +5329,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let span = self.span_from_with_op(start);
let lconst = self.lookup_constant.lookup(init_id)?;
let expr = module
- .const_expressions
- .append(crate::Expression::Constant(lconst.handle), span);
+ .global_expressions
+ .append(lconst.inner.to_expr(), span);
Some(expr)
} else {
None
@@ -5209,7 +5451,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
match null::generate_default_built_in(
Some(built_in),
ty,
- &mut module.const_expressions,
+ &mut module.global_expressions,
span,
) {
Ok(handle) => Some(handle),
@@ -5231,14 +5473,14 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let handle = null::generate_default_built_in(
built_in,
member.ty,
- &mut module.const_expressions,
+ &mut module.global_expressions,
span,
)?;
components.push(handle);
}
Some(
module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Compose { ty, components }, span),
)
}
@@ -5303,11 +5545,12 @@ fn make_index_literal(
Ok(expr)
}
-fn resolve_constant(
- gctx: crate::proc::GlobalCtx,
- constant: Handle<crate::Constant>,
-) -> Option<u32> {
- match gctx.const_expressions[gctx.constants[constant].init] {
+fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> {
+ let constant = match *constant {
+ Constant::Constant(constant) => constant,
+ Constant::Override(_) => return None,
+ };
+ match gctx.global_expressions[gctx.constants[constant].init] {
crate::Expression::Literal(crate::Literal::U32(id)) => Some(id),
crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32),
_ => None,
diff --git a/third_party/rust/naga/src/front/spv/null.rs b/third_party/rust/naga/src/front/spv/null.rs
index 42cccca80a..c7d3776841 100644
--- a/third_party/rust/naga/src/front/spv/null.rs
+++ b/third_party/rust/naga/src/front/spv/null.rs
@@ -5,14 +5,14 @@ use crate::arena::{Arena, Handle};
pub fn generate_default_built_in(
built_in: Option<crate::BuiltIn>,
ty: Handle<crate::Type>,
- const_expressions: &mut Arena<crate::Expression>,
+ global_expressions: &mut Arena<crate::Expression>,
span: crate::Span,
) -> Result<Handle<crate::Expression>, Error> {
let expr = match built_in {
Some(crate::BuiltIn::Position { .. }) => {
- let zero = const_expressions
+ let zero = global_expressions
.append(crate::Expression::Literal(crate::Literal::F32(0.0)), span);
- let one = const_expressions
+ let one = global_expressions
.append(crate::Expression::Literal(crate::Literal::F32(1.0)), span);
crate::Expression::Compose {
ty,
@@ -27,5 +27,5 @@ pub fn generate_default_built_in(
// Note: `crate::BuiltIn::ClipDistance` is intentionally left for the default path
_ => crate::Expression::ZeroValue(ty),
};
- Ok(const_expressions.append(expr, span))
+ Ok(global_expressions.append(expr, span))
}
diff --git a/third_party/rust/naga/src/front/wgsl/error.rs b/third_party/rust/naga/src/front/wgsl/error.rs
index 54aa8296b1..dc1339521c 100644
--- a/third_party/rust/naga/src/front/wgsl/error.rs
+++ b/third_party/rust/naga/src/front/wgsl/error.rs
@@ -13,6 +13,7 @@ use thiserror::Error;
#[derive(Clone, Debug)]
pub struct ParseError {
message: String,
+ // The first span should be the primary span, and the other ones should be complementary.
labels: Vec<(Span, Cow<'static, str>)>,
notes: Vec<String>,
}
@@ -190,7 +191,7 @@ pub enum Error<'a> {
expected: String,
got: String,
},
- MissingType(Span),
+ DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span),
InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span),
@@ -269,6 +270,11 @@ pub enum Error<'a> {
scalar: String,
inner: ConstantEvaluatorError,
},
+ ExceededLimitForNestedBraces {
+ span: Span,
+ limit: u8,
+ },
+ PipelineConstantIDValue(Span),
}
impl<'a> Error<'a> {
@@ -518,11 +524,11 @@ impl<'a> Error<'a> {
notes: vec![],
}
}
- Error::MissingType(name_span) => ParseError {
- message: format!("variable `{}` needs a type", &source[name_span]),
+ Error::DeclMissingTypeAndInit(name_span) => ParseError {
+ message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]),
labels: vec![(
name_span,
- format!("definition of `{}`", &source[name_span]).into(),
+ "needs a type specifier or initializer".into(),
)],
notes: vec![],
},
@@ -770,6 +776,21 @@ impl<'a> Error<'a> {
format!("the expression should have been converted to have {} scalar type", scalar),
]
},
+ Error::ExceededLimitForNestedBraces { span, limit } => ParseError {
+ message: "brace nesting limit reached".into(),
+ labels: vec![(span, "limit reached at this brace".into())],
+ notes: vec![
+ format!("nesting limit is currently set to {limit}"),
+ ],
+ },
+ Error::PipelineConstantIDValue(span) => ParseError {
+ message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(),
+ labels: vec![(
+ span,
+ "must be between 0 and 65535 inclusive".into(),
+ )],
+ notes: vec![],
+ },
}
}
}
diff --git a/third_party/rust/naga/src/front/wgsl/index.rs b/third_party/rust/naga/src/front/wgsl/index.rs
index a5524fe8f1..593405508f 100644
--- a/third_party/rust/naga/src/front/wgsl/index.rs
+++ b/third_party/rust/naga/src/front/wgsl/index.rs
@@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
+ ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
}
diff --git a/third_party/rust/naga/src/front/wgsl/lower/mod.rs b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
index 2ca6c182b7..e7cce17723 100644
--- a/third_party/rust/naga/src/front/wgsl/lower/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/lower/mod.rs
@@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> {
module: &'out mut crate::Module,
const_typifier: &'temp mut Typifier,
+
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'source> GlobalContext<'source, '_, '_> {
@@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> {
module: self.module,
const_typifier: self.const_typifier,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
+ }
+ }
+
+ fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> {
+ ExpressionContext {
+ ast_expressions: self.ast_expressions,
+ globals: self.globals,
+ types: self.types,
+ module: self.module,
+ const_typifier: self.const_typifier,
+ expr_type: ExpressionContextType::Override,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -164,7 +179,8 @@ pub struct StatementContext<'source, 'temp, 'out> {
/// with the form of the expressions; it is also tracking whether WGSL says
/// we should consider them to be const. See the use of `force_non_const` in
/// the code for lowering `let` bindings.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
@@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
ast_expressions: self.ast_expressions,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
module: self.module,
expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
local_table: self.local_table,
@@ -188,7 +205,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
block,
emitter,
typifier: self.typifier,
- expression_constness: self.expression_constness,
+ local_expression_kind_tracker: self.local_expression_kind_tracker,
}),
}
}
@@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -232,8 +250,8 @@ pub struct RuntimeExpressionContext<'temp, 'out> {
/// Which `Expression`s in `self.naga_expressions` are const expressions, in
/// the WGSL sense.
///
- /// See [`StatementContext::expression_constness`] for details.
- expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
+ /// See [`StatementContext::local_expression_kind_tracker`] for details.
+ local_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
}
/// The type of Naga IR expression we are lowering an [`ast::Expression`] to.
@@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> {
/// available in the [`ExpressionContext`], so this variant
/// carries no further information.
Constant,
+
+ /// We are lowering to an override expression, to be included in the module's
+ /// constant expression arena.
+ ///
+ /// Everything override expressions are allowed to refer to is
+ /// available in the [`ExpressionContext`], so this variant
+ /// carries no further information.
+ Override,
}
/// State for lowering an [`ast::Expression`] to Naga IR.
@@ -307,10 +333,11 @@ pub struct ExpressionContext<'source, 'temp, 'out> {
/// [`Module`]: crate::Module
module: &'out mut crate::Module,
- /// Type judgments for [`module::const_expressions`].
+ /// Type judgments for [`module::global_expressions`].
///
- /// [`module::const_expressions`]: crate::Module::const_expressions
+ /// [`module::global_expressions`]: crate::Module::global_expressions
const_typifier: &'temp mut Typifier,
+ global_expression_kind_tracker: &'temp mut crate::proc::ExpressionKindTracker,
/// Whether we are lowering a constant expression or a general
/// runtime expression, and the data needed in each case.
@@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
const_typifier: self.const_typifier,
module: self.module,
expr_type: ExpressionContextType::Constant,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
+ global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -344,11 +373,20 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
ExpressionContextType::Runtime(ref mut rctx) => ConstantEvaluator::for_wgsl_function(
self.module,
&mut rctx.function.expressions,
- rctx.expression_constness,
+ rctx.local_expression_kind_tracker,
rctx.emitter,
rctx.block,
),
- ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
+ ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ false,
+ ),
+ ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module(
+ self.module,
+ self.global_expression_kind_tracker,
+ true,
+ ),
}
}
@@ -358,24 +396,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
span: Span,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let mut eval = self.as_const_evaluator();
- match eval.try_eval_and_append(&expr, span) {
- Ok(expr) => Ok(expr),
-
- // `expr` is not a constant expression. This is fine as
- // long as we're not building `Module::const_expressions`.
- Err(err) => match self.expr_type {
- ExpressionContextType::Runtime(ref mut rctx) => {
- Ok(rctx.function.expressions.append(expr, span))
- }
- ExpressionContextType::Constant => Err(Error::ConstantEvaluatorError(err, span)),
- },
- }
+ eval.try_eval_and_append(expr, span)
+ .map_err(|e| Error::ConstantEvaluatorError(e, span))
}
fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
- if !ctx.expression_constness.is_const(handle) {
+ if !ctx.local_expression_kind_tracker.is_const(handle) {
return None;
}
@@ -385,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.ok()
}
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
+ ExpressionContextType::Override => None,
}
}
fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
- ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.module.global_expressions.get_span(handle)
+ }
}
}
fn typifier(&self) -> &Typifier {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ self.const_typifier
+ }
}
}
@@ -408,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
- ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ Err(Error::UnexpectedOperationInConstContext(span))
+ }
}
}
@@ -420,7 +455,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<crate::SwizzleComponent, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref rctx) => {
- if !rctx.expression_constness.is_const(expr) {
+ if !rctx.local_expression_kind_tracker.is_const(expr) {
return Err(Error::ExpectedConstExprConcreteIntegerScalar(
component_span,
));
@@ -445,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
// This means a `gather` operation appeared in a constant expression.
// This error refers to the `gather` itself, not its "component" argument.
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(gather_span))
}
}
@@ -471,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
// to also borrow self.module.types mutably below.
let typifier = match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
- ExpressionContextType::Constant => &*self.const_typifier,
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
+ &*self.const_typifier
+ }
};
Ok(typifier.register_type(handle, &mut self.module.types))
}
@@ -514,10 +551,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
typifier = &mut *ctx.typifier;
expressions = &ctx.function.expressions;
}
- ExpressionContextType::Constant => {
+ ExpressionContextType::Constant | ExpressionContextType::Override => {
resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
typifier = self.const_typifier;
- expressions = &self.module.const_expressions;
+ expressions = &self.module.global_expressions;
}
};
typifier
@@ -610,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
let result = self.append_expression(expression, span);
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
rctx.emitter.start(&rctx.function.expressions);
}
- ExpressionContextType::Constant => {}
+ ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
result
}
@@ -786,6 +823,7 @@ enum LoweredGlobalDecl {
Function(Handle<crate::Function>),
Var(Handle<crate::GlobalVariable>),
Const(Handle<crate::Constant>),
+ Override(Handle<crate::Override>),
Type(Handle<crate::Type>),
EntryPoint,
}
@@ -836,6 +874,29 @@ impl Texture {
}
}
+enum SubgroupGather {
+ BroadcastFirst,
+ Broadcast,
+ Shuffle,
+ ShuffleDown,
+ ShuffleUp,
+ ShuffleXor,
+}
+
+impl SubgroupGather {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "subgroupBroadcastFirst" => Self::BroadcastFirst,
+ "subgroupBroadcast" => Self::Broadcast,
+ "subgroupShuffle" => Self::Shuffle,
+ "subgroupShuffleDown" => Self::ShuffleDown,
+ "subgroupShuffleUp" => Self::ShuffleUp,
+ "subgroupShuffleXor" => Self::ShuffleXor,
+ _ => return None,
+ })
+ }
+}
+
pub struct Lowerer<'source, 'temp> {
index: &'temp Index<'source>,
layouter: Layouter,
@@ -861,6 +922,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
types: &tu.types,
module: &mut module,
const_typifier: &mut Typifier::new(),
+ global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker::new(),
};
for decl_handle in self.index.visit_ordered() {
@@ -877,7 +939,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let init;
if let Some(init_ast) = v.init {
- let mut ectx = ctx.as_const();
+ let mut ectx = ctx.as_override();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
@@ -956,7 +1018,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let handle = ctx.module.constants.append(
crate::Constant {
name: Some(c.name.name.to_string()),
- r#override: crate::Override::None,
ty,
init,
},
@@ -966,6 +1027,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
}
+ ast::GlobalDeclKind::Override(ref o) => {
+ let init = o
+ .init
+ .map(|init| self.expression(init, &mut ctx.as_override()))
+ .transpose()?;
+ let inferred_type = init
+ .map(|init| ctx.as_const().register_type(init))
+ .transpose()?;
+
+ let explicit_ty =
+ o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
+ .transpose()?;
+
+ let id =
+ o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
+ .transpose()?;
+
+ let id = if let Some((id, id_span)) = id {
+ Some(
+ u16::try_from(id)
+ .map_err(|_| Error::PipelineConstantIDValue(id_span))?,
+ )
+ } else {
+ None
+ };
+
+ let ty = match (explicit_ty, inferred_type) {
+ (Some(explicit_ty), Some(inferred_type)) => {
+ if explicit_ty == inferred_type {
+ explicit_ty
+ } else {
+ let gctx = ctx.module.to_ctx();
+ return Err(Error::InitializationTypeMismatch {
+ name: o.name.span,
+ expected: explicit_ty.to_wgsl(&gctx),
+ got: inferred_type.to_wgsl(&gctx),
+ });
+ }
+ }
+ (Some(explicit_ty), None) => explicit_ty,
+ (None, Some(inferred_type)) => inferred_type,
+ (None, None) => {
+ return Err(Error::DeclMissingTypeAndInit(o.name.span));
+ }
+ };
+
+ let handle = ctx.module.overrides.append(
+ crate::Override {
+ name: Some(o.name.name.to_string()),
+ id,
+ ty,
+ init,
+ },
+ span,
+ );
+
+ ctx.globals
+ .insert(o.name.name, LoweredGlobalDecl::Override(handle));
+ }
ast::GlobalDeclKind::Struct(ref s) => {
let handle = self.r#struct(s, span, &mut ctx)?;
ctx.globals
@@ -1000,6 +1120,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut local_table = FastHashMap::default();
let mut expressions = Arena::new();
let mut named_expressions = FastIndexMap::default();
+ let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();
let arguments = f
.arguments
@@ -1011,6 +1132,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Typed::Plain(expr));
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
+ local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime);
Ok(crate::FunctionArgument {
name: Some(arg.name.name.to_string()),
@@ -1053,7 +1175,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
named_expressions: &mut named_expressions,
types: ctx.types,
module: ctx.module,
- expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
+ local_expression_kind_tracker: &mut local_expression_kind_tracker,
+ global_expression_kind_tracker: ctx.global_expression_kind_tracker,
};
let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
ensure_block_returns(&mut body);
@@ -1132,7 +1255,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// affects when errors must be reported, so we can't even
// treat suitable `let` bindings as constant as an
// optimization.
- ctx.expression_constness.force_non_const(value);
+ ctx.local_expression_kind_tracker.force_non_const(value);
let explicit_ty =
l.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx.as_global()))
@@ -1203,7 +1326,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty = explicit_ty;
initializer = None;
}
- (None, None) => return Err(Error::MissingType(v.name.span)),
+ (None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)),
}
let (const_initializer, initializer) = {
@@ -1216,7 +1339,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// - the initialization is not a constant
// expression, so its value depends on the
// state at the point of initialization.
- if is_inside_loop || !ctx.expression_constness.is_const(init) {
+ if is_inside_loop
+ || !ctx.local_expression_kind_tracker.is_const_or_override(init)
+ {
(None, Some(init))
} else {
(Some(init), None)
@@ -1469,6 +1594,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.function
.expressions
.append(crate::Expression::Binary { op, left, right }, stmt.span);
+ rctx.local_expression_kind_tracker
+ .insert(left, crate::proc::ExpressionKind::Runtime);
+ rctx.local_expression_kind_tracker
+ .insert(value, crate::proc::ExpressionKind::Runtime);
block.extend(emitter.finish(&ctx.function.expressions));
crate::Statement::Store {
@@ -1562,7 +1691,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
LoweredGlobalDecl::Const(handle) => {
Typed::Plain(crate::Expression::Constant(handle))
}
- _ => {
+ LoweredGlobalDecl::Override(handle) => {
+ Typed::Plain(crate::Expression::Override(handle))
+ }
+ LoweredGlobalDecl::Function(_)
+ | LoweredGlobalDecl::Type(_)
+ | LoweredGlobalDecl::EntryPoint => {
return Err(Error::Unexpected(span, ExpectedToken::Variable));
}
};
@@ -1819,9 +1953,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
Ok(Some(handle))
}
- Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
- Err(Error::Unexpected(function.span, ExpectedToken::Function))
- }
+ Some(
+ &LoweredGlobalDecl::Const(_)
+ | &LoweredGlobalDecl::Override(_)
+ | &LoweredGlobalDecl::Var(_),
+ ) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
Some(&LoweredGlobalDecl::Function(function)) => {
let arguments = arguments
@@ -1835,9 +1971,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
let result = has_result.then(|| {
- rctx.function
+ let result = rctx
+ .function
.expressions
- .append(crate::Expression::CallResult(function), span)
+ .append(crate::Expression::CallResult(function), span);
+ rctx.local_expression_kind_tracker
+ .insert(result, crate::proc::ExpressionKind::Runtime);
+ result
});
rctx.emitter.start(&rctx.function.expressions);
rctx.block.push(
@@ -1937,6 +2077,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
} else if let Some(fun) = Texture::map(function.name) {
self.texture_sample_helper(fun, arguments, span, ctx)?
+ } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) {
+ return Ok(Some(
+ self.subgroup_operation_helper(span, op, cop, arguments, ctx)?,
+ ));
+ } else if let Some(mode) = SubgroupGather::map(function.name) {
+ return Ok(Some(
+ self.subgroup_gather_helper(span, mode, arguments, ctx)?,
+ ));
+ } else if let Some(fun) = crate::AtomicFunction::map(function.name) {
+ return Ok(Some(self.atomic_helper(span, fun, arguments, ctx)?));
} else {
match function.name {
"select" => {
@@ -1982,70 +2132,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Store { pointer, value }, span);
return Ok(None);
}
- "atomicAdd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Add,
- arguments,
- ctx,
- )?))
- }
- "atomicSub" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Subtract,
- arguments,
- ctx,
- )?))
- }
- "atomicAnd" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::And,
- arguments,
- ctx,
- )?))
- }
- "atomicOr" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::InclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicXor" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::ExclusiveOr,
- arguments,
- ctx,
- )?))
- }
- "atomicMin" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Min,
- arguments,
- ctx,
- )?))
- }
- "atomicMax" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Max,
- arguments,
- ctx,
- )?))
- }
- "atomicExchange" => {
- return Ok(Some(self.atomic_helper(
- span,
- crate::AtomicFunction::Exchange { compare: None },
- arguments,
- ctx,
- )?))
- }
"atomicCompareExchangeWeak" => {
let mut args = ctx.prepare_args(arguments, 3, span);
@@ -2104,6 +2190,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
return Ok(None);
}
+ "subgroupBarrier" => {
+ ctx.prepare_args(arguments, 0, span).finish()?;
+
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span);
+ return Ok(None);
+ }
"workgroupUniformLoad" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = args.next()?;
@@ -2311,6 +2405,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
return Ok(Some(handle));
}
+ "subgroupBallot" => {
+ let mut args = ctx.prepare_args(arguments, 0, span);
+ let predicate = if arguments.len() == 1 {
+ Some(self.expression(args.next()?, ctx)?)
+ } else {
+ None
+ };
+ args.finish()?;
+
+ let result = ctx
+ .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block
+ .push(crate::Statement::SubgroupBallot { result, predicate }, span);
+ return Ok(Some(result));
+ }
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
};
@@ -2502,6 +2612,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
})
}
+ fn subgroup_operation_helper(
+ &mut self,
+ span: Span,
+ op: crate::SubgroupOperation,
+ collective_op: crate::CollectiveOperation,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 1, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupCollectiveOperation {
+ op,
+ collective_op,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
+ fn subgroup_gather_helper(
+ &mut self,
+ span: Span,
+ mode: SubgroupGather,
+ arguments: &[Handle<ast::Expression<'source>>],
+ ctx: &mut ExpressionContext<'source, '_, '_>,
+ ) -> Result<Handle<crate::Expression>, Error<'source>> {
+ let mut args = ctx.prepare_args(arguments, 2, span);
+
+ let argument = self.expression(args.next()?, ctx)?;
+
+ use SubgroupGather as Sg;
+ let mode = if let Sg::BroadcastFirst = mode {
+ crate::GatherMode::BroadcastFirst
+ } else {
+ let index = self.expression(args.next()?, ctx)?;
+ match mode {
+ Sg::Broadcast => crate::GatherMode::Broadcast(index),
+ Sg::Shuffle => crate::GatherMode::Shuffle(index),
+ Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index),
+ Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index),
+ Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index),
+ Sg::BroadcastFirst => unreachable!(),
+ }
+ };
+
+ args.finish()?;
+
+ let ty = ctx.register_type(argument)?;
+
+ let result =
+ ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?;
+ let rctx = ctx.runtime_expression_ctx(span)?;
+ rctx.block.push(
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ },
+ span,
+ );
+ Ok(result)
+ }
+
fn r#struct(
&mut self,
s: &ast::Struct<'source>,
@@ -2760,3 +2944,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}
}
+
+impl crate::AtomicFunction {
+ pub fn map(word: &str) -> Option<Self> {
+ Some(match word {
+ "atomicAdd" => crate::AtomicFunction::Add,
+ "atomicSub" => crate::AtomicFunction::Subtract,
+ "atomicAnd" => crate::AtomicFunction::And,
+ "atomicOr" => crate::AtomicFunction::InclusiveOr,
+ "atomicXor" => crate::AtomicFunction::ExclusiveOr,
+ "atomicMin" => crate::AtomicFunction::Min,
+ "atomicMax" => crate::AtomicFunction::Max,
+ "atomicExchange" => crate::AtomicFunction::Exchange { compare: None },
+ _ => return None,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/front/wgsl/mod.rs b/third_party/rust/naga/src/front/wgsl/mod.rs
index b6151fe1c0..aec1e657fc 100644
--- a/third_party/rust/naga/src/front/wgsl/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/mod.rs
@@ -44,6 +44,17 @@ impl Frontend {
}
}
+/// <div class="warning">
+// NOTE: Keep this in sync with `wgpu::Device::create_shader_module`!
+// NOTE: Keep this in sync with `wgpu_core::Global::device_create_shader_module`!
+///
+/// This function may consume a lot of stack space. Compiler-enforced limits for parsing recursion
+/// exist; if shader compilation runs into them, it will return an error gracefully. However, on
+/// some build profiles and platforms, the default stack size for a thread may be exceeded before
+/// this limit is reached during parsing. Callers should ensure that there is enough stack space
+/// for this, particularly if calls to this method are exposed to user input.
+///
+/// </div>
pub fn parse_str(source: &str) -> Result<crate::Module, ParseError> {
Frontend::new().parse(source)
}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/ast.rs b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
index dbaac523cb..ea8013ee7c 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/ast.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/ast.rs
@@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> {
Fn(Function<'a>),
Var(GlobalVariable<'a>),
Const(Const<'a>),
+ Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
}
@@ -200,6 +201,14 @@ pub struct Const<'a> {
pub init: Handle<Expression<'a>>,
}
+#[derive(Debug)]
+pub struct Override<'a> {
+ pub name: Ident<'a>,
+ pub id: Option<Handle<Expression<'a>>>,
+ pub ty: Option<Handle<Type<'a>>>,
+ pub init: Option<Handle<Expression<'a>>>,
+}
+
/// The size of an [`Array`] or [`BindingArray`].
///
/// [`Array`]: Type::Array
diff --git a/third_party/rust/naga/src/front/wgsl/parse/conv.rs b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
index 1a4911a3bd..207f0eda41 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/conv.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/conv.rs
@@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result<crate::BuiltIn, Error<'_>>
"local_invocation_index" => crate::BuiltIn::LocalInvocationIndex,
"workgroup_id" => crate::BuiltIn::WorkGroupId,
"num_workgroups" => crate::BuiltIn::NumWorkGroups,
+ // subgroup
+ "num_subgroups" => crate::BuiltIn::NumSubgroups,
+ "subgroup_id" => crate::BuiltIn::SubgroupId,
+ "subgroup_size" => crate::BuiltIn::SubgroupSize,
+ "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId,
_ => return Err(Error::UnknownBuiltin(span)),
})
}
@@ -260,3 +265,26 @@ pub fn map_conservative_depth(
_ => Err(Error::UnknownConservativeDepth(span)),
}
}
+
+pub fn map_subgroup_operation(
+ word: &str,
+) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> {
+ use crate::CollectiveOperation as co;
+ use crate::SubgroupOperation as sg;
+ Some(match word {
+ "subgroupAll" => (sg::All, co::Reduce),
+ "subgroupAny" => (sg::Any, co::Reduce),
+ "subgroupAdd" => (sg::Add, co::Reduce),
+ "subgroupMul" => (sg::Mul, co::Reduce),
+ "subgroupMin" => (sg::Min, co::Reduce),
+ "subgroupMax" => (sg::Max, co::Reduce),
+ "subgroupAnd" => (sg::And, co::Reduce),
+ "subgroupOr" => (sg::Or, co::Reduce),
+ "subgroupXor" => (sg::Xor, co::Reduce),
+ "subgroupExclusiveAdd" => (sg::Add, co::ExclusiveScan),
+ "subgroupExclusiveMul" => (sg::Mul, co::ExclusiveScan),
+ "subgroupInclusiveAdd" => (sg::Add, co::InclusiveScan),
+ "subgroupInclusiveMul" => (sg::Mul, co::InclusiveScan),
+ _ => return None,
+ })
+}
diff --git a/third_party/rust/naga/src/front/wgsl/parse/mod.rs b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
index 51fc2f013b..79ea1ae609 100644
--- a/third_party/rust/naga/src/front/wgsl/parse/mod.rs
+++ b/third_party/rust/naga/src/front/wgsl/parse/mod.rs
@@ -1619,22 +1619,21 @@ impl Parser {
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
block: &mut ast::Block<'a>,
+ brace_nesting_level: u8,
) -> Result<(), Error<'a>> {
self.push_rule_span(Rule::Statement, lexer);
match lexer.peek() {
(Token::Separator(';'), _) => {
let _ = lexer.next();
self.pop_rule_span(lexer);
- return Ok(());
}
(Token::Paren('{'), _) => {
- let (inner, span) = self.block(lexer, ctx)?;
+ let (inner, span) = self.block(lexer, ctx, brace_nesting_level)?;
block.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(inner),
span,
});
self.pop_rule_span(lexer);
- return Ok(());
}
(Token::Word(word), _) => {
let kind = match word {
@@ -1711,7 +1710,7 @@ impl Parser {
let _ = lexer.next();
let condition = self.general_expression(lexer, ctx)?;
- let accept = self.block(lexer, ctx)?.0;
+ let accept = self.block(lexer, ctx, brace_nesting_level)?.0;
let mut elsif_stack = Vec::new();
let mut elseif_span_start = lexer.start_byte_offset();
@@ -1722,12 +1721,12 @@ impl Parser {
if !lexer.skip(Token::Word("if")) {
// ... else { ... }
- break self.block(lexer, ctx)?.0;
+ break self.block(lexer, ctx, brace_nesting_level)?.0;
}
// ... else if (...) { ... }
let other_condition = self.general_expression(lexer, ctx)?;
- let other_block = self.block(lexer, ctx)?;
+ let other_block = self.block(lexer, ctx, brace_nesting_level)?;
elsif_stack.push((elseif_span_start, other_condition, other_block));
elseif_span_start = lexer.start_byte_offset();
};
@@ -1759,7 +1758,9 @@ impl Parser {
"switch" => {
let _ = lexer.next();
let selector = self.general_expression(lexer, ctx)?;
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level =
+ Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
let mut cases = Vec::new();
loop {
@@ -1784,7 +1785,7 @@ impl Parser {
});
};
- let body = self.block(lexer, ctx)?.0;
+ let body = self.block(lexer, ctx, brace_nesting_level)?.0;
cases.push(ast::SwitchCase {
value,
@@ -1794,7 +1795,7 @@ impl Parser {
}
(Token::Word("default"), _) => {
lexer.skip(Token::Separator(':'));
- let body = self.block(lexer, ctx)?.0;
+ let body = self.block(lexer, ctx, brace_nesting_level)?.0;
cases.push(ast::SwitchCase {
value: ast::SwitchValue::Default,
body,
@@ -1810,7 +1811,7 @@ impl Parser {
ast::StatementKind::Switch { selector, cases }
}
- "loop" => self.r#loop(lexer, ctx)?,
+ "loop" => self.r#loop(lexer, ctx, brace_nesting_level)?,
"while" => {
let _ = lexer.next();
let mut body = ast::Block::default();
@@ -1834,7 +1835,7 @@ impl Parser {
span,
});
- let (block, span) = self.block(lexer, ctx)?;
+ let (block, span) = self.block(lexer, ctx, brace_nesting_level)?;
body.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(block),
span,
@@ -1857,7 +1858,9 @@ impl Parser {
let (_, span) = {
let ctx = &mut *ctx;
let block = &mut *block;
- lexer.capture_span(|lexer| self.statement(lexer, ctx, block))?
+ lexer.capture_span(|lexer| {
+ self.statement(lexer, ctx, block, brace_nesting_level)
+ })?
};
if block.stmts.len() != num_statements {
@@ -1902,7 +1905,7 @@ impl Parser {
lexer.expect(Token::Paren(')'))?;
}
- let (block, span) = self.block(lexer, ctx)?;
+ let (block, span) = self.block(lexer, ctx, brace_nesting_level)?;
body.stmts.push(ast::Statement {
kind: ast::StatementKind::Block(block),
span,
@@ -1964,13 +1967,15 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
+ brace_nesting_level: u8,
) -> Result<ast::StatementKind<'a>, Error<'a>> {
let _ = lexer.next();
let mut body = ast::Block::default();
let mut continuing = ast::Block::default();
let mut break_if = None;
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
ctx.local_table.push_scope();
@@ -1980,7 +1985,9 @@ impl Parser {
// the last thing in the loop body
// Expect a opening brace to start the continuing block
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level =
+ Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
loop {
if lexer.skip(Token::Word("break")) {
// Branch for the `break if` statement, this statement
@@ -2009,7 +2016,7 @@ impl Parser {
break;
} else {
// Otherwise try to parse a statement
- self.statement(lexer, ctx, &mut continuing)?;
+ self.statement(lexer, ctx, &mut continuing, brace_nesting_level)?;
}
}
// Since the continuing block must be the last part of the loop body,
@@ -2023,7 +2030,7 @@ impl Parser {
break;
}
// Otherwise try to parse a statement
- self.statement(lexer, ctx, &mut body)?;
+ self.statement(lexer, ctx, &mut body, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2040,15 +2047,17 @@ impl Parser {
&mut self,
lexer: &mut Lexer<'a>,
ctx: &mut ExpressionContext<'a, '_, '_>,
+ brace_nesting_level: u8,
) -> Result<(ast::Block<'a>, Span), Error<'a>> {
self.push_rule_span(Rule::Block, lexer);
ctx.local_table.push_scope();
- lexer.expect(Token::Paren('{'))?;
+ let brace_span = lexer.expect_span(Token::Paren('{'))?;
+ let brace_nesting_level = Self::increase_brace_nesting(brace_nesting_level, brace_span)?;
let mut block = ast::Block::default();
while !lexer.skip(Token::Paren('}')) {
- self.statement(lexer, ctx, &mut block)?;
+ self.statement(lexer, ctx, &mut block, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2135,9 +2144,10 @@ impl Parser {
// do not use `self.block` here, since we must not push a new scope
lexer.expect(Token::Paren('{'))?;
+ let brace_nesting_level = 1;
let mut body = ast::Block::default();
while !lexer.skip(Token::Paren('}')) {
- self.statement(lexer, &mut ctx, &mut body)?;
+ self.statement(lexer, &mut ctx, &mut body, brace_nesting_level)?;
}
ctx.local_table.pop_scope();
@@ -2170,6 +2180,7 @@ impl Parser {
let mut early_depth_test = ParsedAttribute::default();
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());
+ let mut id = ParsedAttribute::default();
let mut dependencies = FastIndexSet::default();
let mut ctx = ExpressionContext {
@@ -2193,6 +2204,11 @@ impl Parser {
bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
+ ("id", name_span) => {
+ lexer.expect(Token::Paren('('))?;
+ id.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
+ lexer.expect(Token::Paren(')'))?;
+ }
("vertex", name_span) => {
stage.set(crate::ShaderStage::Vertex, name_span)?;
}
@@ -2283,6 +2299,30 @@ impl Parser {
Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
}
+ (Token::Word("override"), _) => {
+ let name = lexer.next_ident()?;
+
+ let ty = if lexer.skip(Token::Separator(':')) {
+ Some(self.type_decl(lexer, &mut ctx)?)
+ } else {
+ None
+ };
+
+ let init = if lexer.skip(Token::Operation('=')) {
+ Some(self.general_expression(lexer, &mut ctx)?)
+ } else {
+ None
+ };
+
+ lexer.expect(Token::Separator(';'))?;
+
+ Some(ast::GlobalDeclKind::Override(ast::Override {
+ name,
+ id: id.value,
+ ty,
+ init,
+ }))
+ }
(Token::Word("var"), _) => {
let mut var = self.variable_decl(lexer, &mut ctx)?;
var.binding = binding.take();
@@ -2347,4 +2387,30 @@ impl Parser {
Ok(tu)
}
+
+ const fn increase_brace_nesting(
+ brace_nesting_level: u8,
+ brace_span: Span,
+ ) -> Result<u8, Error<'static>> {
+ // From [spec.](https://gpuweb.github.io/gpuweb/wgsl/#limits):
+ //
+ // > § 2.4. Limits
+ // >
+ // > …
+ // >
+ // > Maximum nesting depth of brace-enclosed statements in a function[:] 127
+ //
+ // _However_, we choose 64 instead because (a) it avoids stack overflows in CI and
+ // (b) we expect the limit to be decreased to 63 based on this conversation in
+ // WebGPU CTS upstream:
+ // <https://github.com/gpuweb/cts/pull/3389#discussion_r1543742701>
+ const BRACE_NESTING_MAXIMUM: u8 = 64;
+ if brace_nesting_level + 1 > BRACE_NESTING_MAXIMUM {
+ return Err(Error::ExceededLimitForNestedBraces {
+ span: brace_span,
+ limit: BRACE_NESTING_MAXIMUM,
+ });
+ }
+ Ok(brace_nesting_level + 1)
+ }
}
diff --git a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
index c8331ace09..63bc9f7317 100644
--- a/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
+++ b/third_party/rust/naga/src/front/wgsl/to_wgsl.rs
@@ -226,7 +226,8 @@ mod tests {
let gctx = crate::proc::GlobalCtx {
types: &types,
constants: &crate::Arena::new(),
- const_expressions: &crate::Arena::new(),
+ overrides: &crate::Arena::new(),
+ global_expressions: &crate::Arena::new(),
};
let array = crate::TypeInner::Array {
base: mytype1,
diff --git a/third_party/rust/naga/src/lib.rs b/third_party/rust/naga/src/lib.rs
index 4b45174300..24e1b02c76 100644
--- a/third_party/rust/naga/src/lib.rs
+++ b/third_party/rust/naga/src/lib.rs
@@ -34,9 +34,6 @@ with optional span info, representing a series of statements executed in order.
`EntryPoint`s or `Function` is a `Block`, and `Statement` has a
[`Block`][Statement::Block] variant.
-If the `clone` feature is enabled, [`Arena`], [`UniqueArena`], [`Type`], [`TypeInner`],
-[`Constant`], [`Function`], [`EntryPoint`] and [`Module`] can be cloned.
-
## Arenas
To improve translator performance and reduce memory usage, most structures are
@@ -175,7 +172,7 @@ tree.
A Naga *constant expression* is one of the following [`Expression`]
variants, whose operands (if any) are also constant expressions:
- [`Literal`]
-- [`Constant`], for [`Constant`s][const_type] whose [`override`] is [`None`]
+- [`Constant`], for [`Constant`]s
- [`ZeroValue`], for fixed-size types
- [`Compose`]
- [`Access`]
@@ -194,8 +191,7 @@ A constant expression can be evaluated at module translation time.
## Override expressions
A Naga *override expression* is the same as a [constant expression],
-except that it is also allowed to refer to [`Constant`s][const_type]
-whose [`override`] is something other than [`None`].
+except that it is also allowed to reference other [`Override`]s.
An override expression can be evaluated at pipeline creation time.
@@ -238,10 +234,6 @@ An override expression can be evaluated at pipeline creation time.
[`Math`]: Expression::Math
[`As`]: Expression::As
-[const_type]: Constant
-[`override`]: Constant::override
-[`None`]: Override::None
-
[constant expression]: index.html#constant-expressions
*/
@@ -282,6 +274,7 @@ pub mod back;
mod block;
#[cfg(feature = "compact")]
pub mod compact;
+pub mod error;
pub mod front;
pub mod keywords;
pub mod proc;
@@ -439,6 +432,11 @@ pub enum BuiltIn {
WorkGroupId,
WorkGroupSize,
NumWorkGroups,
+ // subgroup
+ NumSubgroups,
+ SubgroupId,
+ SubgroupSize,
+ SubgroupInvocationId,
}
/// Number of bytes per scalar.
@@ -874,7 +872,7 @@ pub enum TypeInner {
BindingArray { base: Handle<Type>, size: ArraySize },
}
-#[derive(Debug, Clone, Copy, PartialOrd)]
+#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
@@ -892,41 +890,37 @@ pub enum Literal {
AbstractFloat(f64),
}
-#[derive(Debug, PartialEq)]
-#[cfg_attr(feature = "clone", derive(Clone))]
+/// Pipeline-overridable constant.
+#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
-pub enum Override {
- None,
- ByName,
- ByNameOrId(u32),
+pub struct Override {
+ pub name: Option<String>,
+ /// Pipeline Constant ID.
+ pub id: Option<u16>,
+ pub ty: Handle<Type>,
+
+ /// The default value of the pipeline-overridable constant.
+ ///
+ /// This [`Handle`] refers to [`Module::global_expressions`], not
+ /// any [`Function::expressions`] arena.
+ pub init: Option<Handle<Expression>>,
}
/// Constant value.
-#[derive(Debug, PartialEq)]
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct Constant {
pub name: Option<String>,
- pub r#override: Override,
pub ty: Handle<Type>,
/// The value of the constant.
///
- /// This [`Handle`] refers to [`Module::const_expressions`], not
+ /// This [`Handle`] refers to [`Module::global_expressions`], not
/// any [`Function::expressions`] arena.
- ///
- /// If [`override`] is [`None`], then this must be a Naga
- /// [constant expression]. Otherwise, this may be a Naga
- /// [override expression] or [constant expression].
- ///
- /// [`override`]: Constant::override
- /// [`None`]: Override::None
- /// [constant expression]: index.html#constant-expressions
- /// [override expression]: index.html#override-expressions
pub init: Handle<Expression>,
}
@@ -992,7 +986,7 @@ pub struct GlobalVariable {
pub ty: Handle<Type>,
/// Initial value for this variable.
///
- /// Expression handle lives in const_expressions
+ /// Expression handle lives in global_expressions
pub init: Option<Handle<Expression>>,
}
@@ -1010,7 +1004,7 @@ pub struct LocalVariable {
///
/// This handle refers to this `LocalVariable`'s function's
/// [`expressions`] arena, but it is required to be an evaluated
- /// constant expression.
+ /// override expression.
///
/// [`expressions`]: Function::expressions
pub init: Option<Handle<Expression>>,
@@ -1289,6 +1283,51 @@ pub enum SwizzleComponent {
W = 3,
}
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum GatherMode {
+ /// All gather from the active lane with the smallest index
+ BroadcastFirst,
+ /// All gather from the same lane at the index given by the expression
+ Broadcast(Handle<Expression>),
+ /// Each gathers from a different lane at the index given by the expression
+ Shuffle(Handle<Expression>),
+ /// Each gathers from their lane plus the shift given by the expression
+ ShuffleDown(Handle<Expression>),
+ /// Each gathers from their lane minus the shift given by the expression
+ ShuffleUp(Handle<Expression>),
+ /// Each gathers from their lane xored with the given by the expression
+ ShuffleXor(Handle<Expression>),
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum SubgroupOperation {
+ All = 0,
+ Any = 1,
+ Add = 2,
+ Mul = 3,
+ Min = 4,
+ Max = 5,
+ And = 6,
+ Or = 7,
+ Xor = 8,
+}
+
+#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(feature = "serialize", derive(Serialize))]
+#[cfg_attr(feature = "deserialize", derive(Deserialize))]
+#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
+pub enum CollectiveOperation {
+ Reduce = 0,
+ InclusiveScan = 1,
+ ExclusiveScan = 2,
+}
+
bitflags::bitflags! {
/// Memory barrier flags.
#[cfg_attr(feature = "serialize", derive(Serialize))]
@@ -1297,9 +1336,11 @@ bitflags::bitflags! {
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct Barrier: u32 {
/// Barrier affects all `AddressSpace::Storage` accesses.
- const STORAGE = 0x1;
+ const STORAGE = 1 << 0;
/// Barrier affects all `AddressSpace::WorkGroup` accesses.
- const WORK_GROUP = 0x2;
+ const WORK_GROUP = 1 << 1;
+ /// Barrier synchronizes execution across all invocations within a subgroup that exectue this instruction.
+ const SUB_GROUP = 1 << 2;
}
}
@@ -1315,6 +1356,8 @@ pub enum Expression {
Literal(Literal),
/// Constant value.
Constant(Handle<Constant>),
+ /// Pipeline-overridable constant.
+ Override(Handle<Override>),
/// Zero value of a type.
ZeroValue(Handle<Type>),
/// Composite expression.
@@ -1440,7 +1483,7 @@ pub enum Expression {
gather: Option<SwizzleComponent>,
coordinate: Handle<Expression>,
array_index: Option<Handle<Expression>>,
- /// Expression handle lives in const_expressions
+ /// Expression handle lives in global_expressions
offset: Option<Handle<Expression>>,
level: SampleLevel,
depth_ref: Option<Handle<Expression>>,
@@ -1598,6 +1641,15 @@ pub enum Expression {
query: Handle<Expression>,
committed: bool,
},
+ /// Result of a [`SubgroupBallot`] statement.
+ ///
+ /// [`SubgroupBallot`]: Statement::SubgroupBallot
+ SubgroupBallotResult,
+ /// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement.
+ ///
+ /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation
+ /// [`SubgroupGather`]: Statement::SubgroupGather
+ SubgroupOperationResult { ty: Handle<Type> },
}
pub use block::Block;
@@ -1882,6 +1934,39 @@ pub enum Statement {
/// The specific operation we're performing on `query`.
fun: RayQueryFunction,
},
+ /// Calculate a bitmask using a boolean from each active thread in the subgroup
+ SubgroupBallot {
+ /// The [`SubgroupBallotResult`] expression representing this load's result.
+ ///
+ /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult
+ result: Handle<Expression>,
+ /// The value from this thread to store in the ballot
+ predicate: Option<Handle<Expression>>,
+ },
+ /// Gather a value from another active thread in the subgroup
+ SubgroupGather {
+ /// Specifies which thread to gather from
+ mode: GatherMode,
+ /// The value to broadcast over
+ argument: Handle<Expression>,
+ /// The [`SubgroupOperationResult`] expression representing this load's result.
+ ///
+ /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
+ result: Handle<Expression>,
+ },
+ /// Compute a collective operation across all active threads in the subgroup
+ SubgroupCollectiveOperation {
+ /// What operation to compute
+ op: SubgroupOperation,
+ /// How to combine the results
+ collective_op: CollectiveOperation,
+ /// The value to compute over
+ argument: Handle<Expression>,
+ /// The [`SubgroupOperationResult`] expression representing this load's result.
+ ///
+ /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult
+ result: Handle<Expression>,
+ },
}
/// A function argument.
@@ -1913,8 +1998,7 @@ pub struct FunctionResult {
}
/// A function defined in the module.
-#[derive(Debug, Default)]
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
@@ -1978,8 +2062,7 @@ pub struct Function {
/// [`Location`]: Binding::Location
/// [`function`]: EntryPoint::function
/// [`stage`]: EntryPoint::stage
-#[derive(Debug)]
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
@@ -2003,8 +2086,7 @@ pub struct EntryPoint {
/// These cannot be spelled in WGSL source.
///
/// Stored in [`SpecialTypes::predeclared_types`] and created by [`Module::generate_predeclared_type`].
-#[derive(Debug, PartialEq, Eq, Hash)]
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
@@ -2021,8 +2103,7 @@ pub enum PredeclaredType {
}
/// Set of special types that can be optionally generated by the frontends.
-#[derive(Debug, Default)]
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
@@ -2057,8 +2138,7 @@ pub struct SpecialTypes {
/// Alternatively, you can load an existing shader using one of the [available front ends][front].
///
/// When finished, you can export modules using one of the [available backends][back].
-#[derive(Debug, Default)]
-#[cfg_attr(feature = "clone", derive(Clone))]
+#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
@@ -2069,6 +2149,8 @@ pub struct Module {
pub special_types: SpecialTypes,
/// Arena for the constants defined in this module.
pub constants: Arena<Constant>,
+ /// Arena for the pipeline-overridable constants defined in this module.
+ pub overrides: Arena<Override>,
/// Arena for the global variables defined in this module.
pub global_variables: Arena<GlobalVariable>,
/// [Constant expressions] and [override expressions] used by this module.
@@ -2078,7 +2160,7 @@ pub struct Module {
///
/// [Constant expressions]: index.html#constant-expressions
/// [override expressions]: index.html#override-expressions
- pub const_expressions: Arena<Expression>,
+ pub global_expressions: Arena<Expression>,
/// Arena for the functions defined in this module.
///
/// Each function must appear in this arena strictly before all its callers.
diff --git a/third_party/rust/naga/src/proc/constant_evaluator.rs b/third_party/rust/naga/src/proc/constant_evaluator.rs
index 983af3718c..ead3d00980 100644
--- a/third_party/rust/naga/src/proc/constant_evaluator.rs
+++ b/third_party/rust/naga/src/proc/constant_evaluator.rs
@@ -4,8 +4,8 @@ use arrayvec::ArrayVec;
use crate::{
arena::{Arena, Handle, UniqueArena},
- ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner,
- UnaryOperator,
+ ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
+ TypeInner, UnaryOperator,
};
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
@@ -253,9 +253,20 @@ gen_component_wise_extractor! {
}
#[derive(Debug)]
-enum Behavior {
- Wgsl,
- Glsl,
+enum Behavior<'a> {
+ Wgsl(WgslRestrictions<'a>),
+ Glsl(GlslRestrictions<'a>),
+}
+
+impl Behavior<'_> {
+ /// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
+ const fn has_runtime_restrictions(&self) -> bool {
+ matches!(
+ self,
+ &Behavior::Wgsl(WgslRestrictions::Runtime(_))
+ | &Behavior::Glsl(GlslRestrictions::Runtime(_))
+ )
+ }
}
/// A context for evaluating constant expressions.
@@ -278,7 +289,7 @@ enum Behavior {
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
/// Which language's evaluation rules we should follow.
- behavior: Behavior,
+ behavior: Behavior<'a>,
/// The module's type arena.
///
@@ -291,71 +302,155 @@ pub struct ConstantEvaluator<'a> {
/// The module's constant arena.
constants: &'a Arena<Constant>,
+ /// The module's override arena.
+ overrides: &'a Arena<Override>,
+
/// The arena to which we are contributing expressions.
expressions: &'a mut Arena<Expression>,
- /// When `self.expressions` refers to a function's local expression
- /// arena, this needs to be populated
- function_local_data: Option<FunctionLocalData<'a>>,
+ /// Tracks the constness of expressions residing in [`Self::expressions`]
+ expression_kind_tracker: &'a mut ExpressionKindTracker,
+}
+
+#[derive(Debug)]
+enum WgslRestrictions<'a> {
+ /// - const-expressions will be evaluated and inserted in the arena
+ Const,
+ /// - const-expressions will be evaluated and inserted in the arena
+ /// - override-expressions will be inserted in the arena
+ Override,
+ /// - const-expressions will be evaluated and inserted in the arena
+ /// - override-expressions will be inserted in the arena
+ /// - runtime-expressions will be inserted in the arena
+ Runtime(FunctionLocalData<'a>),
+}
+
+#[derive(Debug)]
+enum GlslRestrictions<'a> {
+ /// - const-expressions will be evaluated and inserted in the arena
+ Const,
+ /// - const-expressions will be evaluated and inserted in the arena
+ /// - override-expressions will be inserted in the arena
+ /// - runtime-expressions will be inserted in the arena
+ Runtime(FunctionLocalData<'a>),
}
#[derive(Debug)]
struct FunctionLocalData<'a> {
/// Global constant expressions
- const_expressions: &'a Arena<Expression>,
- /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions`
- expression_constness: &'a mut ExpressionConstnessTracker,
+ global_expressions: &'a Arena<Expression>,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
}
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
+pub enum ExpressionKind {
+ Const,
+ Override,
+ Runtime,
+}
+
#[derive(Debug)]
-pub struct ExpressionConstnessTracker {
- inner: bit_set::BitSet,
+pub struct ExpressionKindTracker {
+ inner: Vec<ExpressionKind>,
}
-impl ExpressionConstnessTracker {
- pub fn new() -> Self {
- Self {
- inner: bit_set::BitSet::new(),
- }
+impl ExpressionKindTracker {
+ pub const fn new() -> Self {
+ Self { inner: Vec::new() }
}
/// Forces the the expression to not be const
pub fn force_non_const(&mut self, value: Handle<Expression>) {
- self.inner.remove(value.index());
+ self.inner[value.index()] = ExpressionKind::Runtime;
}
- fn insert(&mut self, value: Handle<Expression>) {
- self.inner.insert(value.index());
+ pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
+ assert_eq!(self.inner.len(), value.index());
+ self.inner.push(expr_type);
+ }
+ pub fn is_const(&self, h: Handle<Expression>) -> bool {
+ matches!(self.type_of(h), ExpressionKind::Const)
+ }
+
+ pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
+ matches!(
+ self.type_of(h),
+ ExpressionKind::Const | ExpressionKind::Override
+ )
}
- pub fn is_const(&self, value: Handle<Expression>) -> bool {
- self.inner.contains(value.index())
+ fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
+ self.inner[value.index()]
}
pub fn from_arena(arena: &Arena<Expression>) -> Self {
- let mut tracker = Self::new();
- for (handle, expr) in arena.iter() {
- let insert = match *expr {
- crate::Expression::Literal(_)
- | crate::Expression::ZeroValue(_)
- | crate::Expression::Constant(_) => true,
- crate::Expression::Compose { ref components, .. } => {
- components.iter().all(|h| tracker.is_const(*h))
- }
- crate::Expression::Splat { value, .. } => tracker.is_const(value),
- _ => false,
- };
- if insert {
- tracker.insert(handle);
- }
+ let mut tracker = Self {
+ inner: Vec::with_capacity(arena.len()),
+ };
+ for (_, expr) in arena.iter() {
+ tracker.inner.push(tracker.type_of_with_expr(expr));
}
tracker
}
+
+ fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
+ match *expr {
+ Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
+ ExpressionKind::Const
+ }
+ Expression::Override(_) => ExpressionKind::Override,
+ Expression::Compose { ref components, .. } => {
+ let mut expr_type = ExpressionKind::Const;
+ for component in components {
+ expr_type = expr_type.max(self.type_of(*component))
+ }
+ expr_type
+ }
+ Expression::Splat { value, .. } => self.type_of(value),
+ Expression::AccessIndex { base, .. } => self.type_of(base),
+ Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
+ Expression::Swizzle { vector, .. } => self.type_of(vector),
+ Expression::Unary { expr, .. } => self.type_of(expr),
+ Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
+ Expression::Math {
+ arg,
+ arg1,
+ arg2,
+ arg3,
+ ..
+ } => self
+ .type_of(arg)
+ .max(
+ arg1.map(|arg| self.type_of(arg))
+ .unwrap_or(ExpressionKind::Const),
+ )
+ .max(
+ arg2.map(|arg| self.type_of(arg))
+ .unwrap_or(ExpressionKind::Const),
+ )
+ .max(
+ arg3.map(|arg| self.type_of(arg))
+ .unwrap_or(ExpressionKind::Const),
+ ),
+ Expression::As { expr, .. } => self.type_of(expr),
+ Expression::Select {
+ condition,
+ accept,
+ reject,
+ } => self
+ .type_of(condition)
+ .max(self.type_of(accept))
+ .max(self.type_of(reject)),
+ Expression::Relational { argument, .. } => self.type_of(argument),
+ Expression::ArrayLength(expr) => self.type_of(expr),
+ _ => ExpressionKind::Runtime,
+ }
+ }
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
#[error("Constants cannot access function arguments")]
FunctionArg,
@@ -381,6 +476,8 @@ pub enum ConstantEvaluatorError {
ImageExpression,
#[error("Constants don't support ray query expressions")]
RayQueryExpression,
+ #[error("Constants don't support subgroup expressions")]
+ SubgroupExpression,
#[error("Cannot access the type")]
InvalidAccessBase,
#[error("Cannot access at the index")]
@@ -432,6 +529,12 @@ pub enum ConstantEvaluatorError {
ShiftedMoreThan32Bits,
#[error(transparent)]
Literal(#[from] crate::valid::LiteralError),
+ #[error("Can't use pipeline-overridable constants in const-expressions")]
+ Override,
+ #[error("Unexpected runtime-expression")]
+ RuntimeExpr,
+ #[error("Unexpected override-expression")]
+ OverrideExpr,
}
impl<'a> ConstantEvaluator<'a> {
@@ -439,25 +542,49 @@ impl<'a> ConstantEvaluator<'a> {
/// constant expression arena.
///
/// Report errors according to WGSL's rules for constant evaluation.
- pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self {
- Self::for_module(Behavior::Wgsl, module)
+ pub fn for_wgsl_module(
+ module: &'a mut crate::Module,
+ global_expression_kind_tracker: &'a mut ExpressionKindTracker,
+ in_override_ctx: bool,
+ ) -> Self {
+ Self::for_module(
+ Behavior::Wgsl(if in_override_ctx {
+ WgslRestrictions::Override
+ } else {
+ WgslRestrictions::Const
+ }),
+ module,
+ global_expression_kind_tracker,
+ )
}
/// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
/// constant expression arena.
///
/// Report errors according to GLSL's rules for constant evaluation.
- pub fn for_glsl_module(module: &'a mut crate::Module) -> Self {
- Self::for_module(Behavior::Glsl, module)
+ pub fn for_glsl_module(
+ module: &'a mut crate::Module,
+ global_expression_kind_tracker: &'a mut ExpressionKindTracker,
+ ) -> Self {
+ Self::for_module(
+ Behavior::Glsl(GlslRestrictions::Const),
+ module,
+ global_expression_kind_tracker,
+ )
}
- fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self {
+ fn for_module(
+ behavior: Behavior<'a>,
+ module: &'a mut crate::Module,
+ global_expression_kind_tracker: &'a mut ExpressionKindTracker,
+ ) -> Self {
Self {
behavior,
types: &mut module.types,
constants: &module.constants,
- expressions: &mut module.const_expressions,
- function_local_data: None,
+ overrides: &module.overrides,
+ expressions: &mut module.global_expressions,
+ expression_kind_tracker: global_expression_kind_tracker,
}
}
@@ -468,18 +595,22 @@ impl<'a> ConstantEvaluator<'a> {
pub fn for_wgsl_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
- expression_constness: &'a mut ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'a mut ExpressionKindTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
- Self::for_function(
- Behavior::Wgsl,
- module,
+ Self {
+ behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData {
+ global_expressions: &module.global_expressions,
+ emitter,
+ block,
+ })),
+ types: &mut module.types,
+ constants: &module.constants,
+ overrides: &module.overrides,
expressions,
- expression_constness,
- emitter,
- block,
- )
+ expression_kind_tracker: local_expression_kind_tracker,
+ }
}
/// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
@@ -489,39 +620,21 @@ impl<'a> ConstantEvaluator<'a> {
pub fn for_glsl_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
- expression_constness: &'a mut ExpressionConstnessTracker,
- emitter: &'a mut super::Emitter,
- block: &'a mut crate::Block,
- ) -> Self {
- Self::for_function(
- Behavior::Glsl,
- module,
- expressions,
- expression_constness,
- emitter,
- block,
- )
- }
-
- fn for_function(
- behavior: Behavior,
- module: &'a mut crate::Module,
- expressions: &'a mut Arena<Expression>,
- expression_constness: &'a mut ExpressionConstnessTracker,
+ local_expression_kind_tracker: &'a mut ExpressionKindTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
Self {
- behavior,
+ behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
+ global_expressions: &module.global_expressions,
+ emitter,
+ block,
+ })),
types: &mut module.types,
constants: &module.constants,
+ overrides: &module.overrides,
expressions,
- function_local_data: Some(FunctionLocalData {
- const_expressions: &module.const_expressions,
- expression_constness,
- emitter,
- block,
- }),
+ expression_kind_tracker: local_expression_kind_tracker,
}
}
@@ -529,19 +642,18 @@ impl<'a> ConstantEvaluator<'a> {
crate::proc::GlobalCtx {
types: self.types,
constants: self.constants,
- const_expressions: match self.function_local_data {
- Some(ref data) => data.const_expressions,
+ overrides: self.overrides,
+ global_expressions: match self.function_local_data() {
+ Some(data) => data.global_expressions,
None => self.expressions,
},
}
}
fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
- if let Some(ref function_local_data) = self.function_local_data {
- if !function_local_data.expression_constness.is_const(expr) {
- log::debug!("check: SubexpressionsAreNotConstant");
- return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
- }
+ if !self.expression_kind_tracker.is_const(expr) {
+ log::debug!("check: SubexpressionsAreNotConstant");
+ return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
}
Ok(())
}
@@ -554,11 +666,11 @@ impl<'a> ConstantEvaluator<'a> {
Expression::Constant(c) => {
// Are we working in a function's expression arena, or the
// module's constant expression arena?
- if let Some(ref function_local_data) = self.function_local_data {
+ if let Some(function_local_data) = self.function_local_data() {
// Deep-copy the constant's value into our arena.
self.copy_from(
self.constants[c].init,
- function_local_data.const_expressions,
+ function_local_data.global_expressions,
)
} else {
// "See through" the constant and use its initializer.
@@ -580,9 +692,11 @@ impl<'a> ConstantEvaluator<'a> {
/// [`ZeroValue`], and [`Swizzle`] expressions - to the expression arena
/// `self` contributes to.
///
- /// If `expr`'s value cannot be determined at compile time, return a an
- /// error. If it's acceptable to evaluate `expr` at runtime, this error can
- /// be ignored, and the caller can append `expr` to the arena itself.
+ /// If `expr`'s value cannot be determined at compile time, and `self` is
+ /// contributing to some function's expression arena, then append `expr` to
+ /// that arena unchanged (and thus unevaluated). Otherwise, `self` must be
+ /// contributing to the module's constant expression arena; since `expr`'s
+ /// value is not a constant, return an error.
///
/// We only consider `expr` itself, without recursing into its operands. Its
/// operands must all have been produced by prior calls to
@@ -595,16 +709,81 @@ impl<'a> ConstantEvaluator<'a> {
/// [`Swizzle`]: Expression::Swizzle
pub fn try_eval_and_append(
&mut self,
+ expr: Expression,
+ span: Span,
+ ) -> Result<Handle<Expression>, ConstantEvaluatorError> {
+ match self.expression_kind_tracker.type_of_with_expr(&expr) {
+ ExpressionKind::Const => {
+ let eval_result = self.try_eval_and_append_impl(&expr, span);
+ // We should be able to evaluate `Const` expressions at this
+ // point. If we failed to, then that probably means we just
+ // haven't implemented that part of constant evaluation. Work
+ // around this by simply emitting it as a run-time expression.
+ if self.behavior.has_runtime_restrictions()
+ && matches!(
+ eval_result,
+ Err(ConstantEvaluatorError::NotImplemented(_)
+ | ConstantEvaluatorError::InvalidBinaryOpArgs,)
+ )
+ {
+ Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
+ } else {
+ eval_result
+ }
+ }
+ ExpressionKind::Override => match self.behavior {
+ Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
+ Ok(self.append_expr(expr, span, ExpressionKind::Override))
+ }
+ Behavior::Wgsl(WgslRestrictions::Const) => {
+ Err(ConstantEvaluatorError::OverrideExpr)
+ }
+ Behavior::Glsl(_) => {
+ unreachable!()
+ }
+ },
+ ExpressionKind::Runtime => {
+ if self.behavior.has_runtime_restrictions() {
+ Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
+ } else {
+ Err(ConstantEvaluatorError::RuntimeExpr)
+ }
+ }
+ }
+ }
+
+ /// Is the [`Self::expressions`] arena the global module expression arena?
+ const fn is_global_arena(&self) -> bool {
+ matches!(
+ self.behavior,
+ Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override)
+ | Behavior::Glsl(GlslRestrictions::Const)
+ )
+ }
+
+ const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
+ match self.behavior {
+ Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data))
+ | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
+ Some(function_local_data)
+ }
+ _ => None,
+ }
+ }
+
+ fn try_eval_and_append_impl(
+ &mut self,
expr: &Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
log::trace!("try_eval_and_append: {:?}", expr);
match *expr {
- Expression::Constant(c) if self.function_local_data.is_none() => {
+ Expression::Constant(c) if self.is_global_arena() => {
// "See through" the constant and use its initializer.
// This is mainly done to avoid having constants pointing to other constants.
Ok(self.constants[c].init)
}
+ Expression::Override(_) => Err(ConstantEvaluatorError::Override),
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
self.register_evaluated_expr(expr.clone(), span)
}
@@ -685,8 +864,8 @@ impl<'a> ConstantEvaluator<'a> {
format!("{fun:?} built-in function"),
)),
Expression::ArrayLength(expr) => match self.behavior {
- Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength),
- Behavior::Glsl => {
+ Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
+ Behavior::Glsl(_) => {
let expr = self.check_and_get(expr)?;
self.array_length(expr, span)
}
@@ -707,6 +886,12 @@ impl<'a> ConstantEvaluator<'a> {
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
Err(ConstantEvaluatorError::RayQueryExpression)
}
+ Expression::SubgroupBallotResult { .. } => {
+ Err(ConstantEvaluatorError::SubgroupExpression)
+ }
+ Expression::SubgroupOperationResult { .. } => {
+ Err(ConstantEvaluatorError::SubgroupExpression)
+ }
}
}
@@ -765,10 +950,10 @@ impl<'a> ConstantEvaluator<'a> {
pattern: [crate::SwizzleComponent; 4],
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let mut get_dst_ty = |ty| match self.types[ty].inner {
- crate::TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
+ TypeInner::Vector { size: _, scalar } => Ok(self.types.insert(
Type {
name: None,
- inner: crate::TypeInner::Vector { size, scalar },
+ inner: TypeInner::Vector { size, scalar },
},
span,
)),
@@ -1059,13 +1244,11 @@ impl<'a> ConstantEvaluator<'a> {
Expression::ZeroValue(ty) | Expression::Compose { ty, .. } => {
match self.types[ty].inner {
TypeInner::Array { size, .. } => match size {
- crate::ArraySize::Constant(len) => {
+ ArraySize::Constant(len) => {
let expr = Expression::Literal(Literal::U32(len.get()));
self.register_evaluated_expr(expr, span)
}
- crate::ArraySize::Dynamic => {
- Err(ConstantEvaluatorError::ArrayLengthDynamic)
- }
+ ArraySize::Dynamic => Err(ConstantEvaluatorError::ArrayLengthDynamic),
},
_ => Err(ConstantEvaluatorError::InvalidArrayLengthArg),
}
@@ -1128,7 +1311,7 @@ impl<'a> ConstantEvaluator<'a> {
Expression::ZeroValue(ty)
if matches!(
self.types[ty].inner,
- crate::TypeInner::Scalar(crate::Scalar {
+ TypeInner::Scalar(crate::Scalar {
kind: ScalarKind::Uint,
..
})
@@ -1443,7 +1626,7 @@ impl<'a> ConstantEvaluator<'a> {
return self.cast(expr, target, span);
};
- let crate::TypeInner::Array {
+ let TypeInner::Array {
base: _,
size,
stride: _,
@@ -1853,29 +2036,35 @@ impl<'a> ConstantEvaluator<'a> {
crate::valid::check_literal_value(literal)?;
}
- if let Some(FunctionLocalData {
- ref mut emitter,
- ref mut block,
- ref mut expression_constness,
- ..
- }) = self.function_local_data
- {
- let is_running = emitter.is_running();
- let needs_pre_emit = expr.needs_pre_emit();
- if is_running && needs_pre_emit {
- block.extend(emitter.finish(self.expressions));
- let h = self.expressions.append(expr, span);
- emitter.start(self.expressions);
- expression_constness.insert(h);
- Ok(h)
- } else {
- let h = self.expressions.append(expr, span);
- expression_constness.insert(h);
- Ok(h)
+ Ok(self.append_expr(expr, span, ExpressionKind::Const))
+ }
+
+ fn append_expr(
+ &mut self,
+ expr: Expression,
+ span: Span,
+ expr_type: ExpressionKind,
+ ) -> Handle<Expression> {
+ let h = match self.behavior {
+ Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data))
+ | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
+ let is_running = function_local_data.emitter.is_running();
+ let needs_pre_emit = expr.needs_pre_emit();
+ if is_running && needs_pre_emit {
+ function_local_data
+ .block
+ .extend(function_local_data.emitter.finish(self.expressions));
+ let h = self.expressions.append(expr, span);
+ function_local_data.emitter.start(self.expressions);
+ h
+ } else {
+ self.expressions.append(expr, span)
+ }
}
- } else {
- Ok(self.expressions.append(expr, span))
- }
+ _ => self.expressions.append(expr, span),
+ };
+ self.expression_kind_tracker.insert(h, expr_type);
+ h
}
fn resolve_type(
@@ -2029,13 +2218,14 @@ mod tests {
UniqueArena, VectorSize,
};
- use super::{Behavior, ConstantEvaluator};
+ use super::{Behavior, ConstantEvaluator, ExpressionKindTracker, WgslRestrictions};
#[test]
fn unary_op() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let scalar_ty = types.insert(
Type {
@@ -2059,9 +2249,8 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: scalar_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
@@ -2070,9 +2259,8 @@ mod tests {
let h1 = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: scalar_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(8)), Default::default()),
},
Default::default(),
@@ -2081,9 +2269,8 @@ mod tests {
let vec_h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: vec_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: vec_ty,
components: vec![constants[h].init, constants[h1].init],
@@ -2094,8 +2281,8 @@ mod tests {
Default::default(),
);
- let expr = const_expressions.append(Expression::Constant(h), Default::default());
- let expr1 = const_expressions.append(Expression::Constant(vec_h), Default::default());
+ let expr = global_expressions.append(Expression::Constant(h), Default::default());
+ let expr1 = global_expressions.append(Expression::Constant(vec_h), Default::default());
let expr2 = Expression::Unary {
op: UnaryOperator::Negate,
@@ -2112,35 +2299,37 @@ mod tests {
expr: expr1,
};
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let res1 = solver
- .try_eval_and_append(&expr2, Default::default())
+ .try_eval_and_append(expr2, Default::default())
.unwrap();
let res2 = solver
- .try_eval_and_append(&expr3, Default::default())
+ .try_eval_and_append(expr3, Default::default())
.unwrap();
let res3 = solver
- .try_eval_and_append(&expr4, Default::default())
+ .try_eval_and_append(expr4, Default::default())
.unwrap();
assert_eq!(
- const_expressions[res1],
+ global_expressions[res1],
Expression::Literal(Literal::I32(-4))
);
assert_eq!(
- const_expressions[res2],
+ global_expressions[res2],
Expression::Literal(Literal::I32(!4))
);
- let res3_inner = &const_expressions[res3];
+ let res3_inner = &global_expressions[res3];
match *res3_inner {
Expression::Compose {
@@ -2150,11 +2339,11 @@ mod tests {
assert_eq!(*ty, vec_ty);
let mut components_iter = components.iter().copied();
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::I32(!4))
);
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::I32(!8))
);
assert!(components_iter.next().is_none());
@@ -2167,7 +2356,8 @@ mod tests {
fn cast() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let scalar_ty = types.insert(
Type {
@@ -2180,15 +2370,14 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: scalar_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);
- let expr = const_expressions.append(Expression::Constant(h), Default::default());
+ let expr = global_expressions.append(Expression::Constant(h), Default::default());
let root = Expression::As {
expr,
@@ -2196,20 +2385,22 @@ mod tests {
convert: Some(crate::BOOL_WIDTH),
};
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let res = solver
- .try_eval_and_append(&root, Default::default())
+ .try_eval_and_append(root, Default::default())
.unwrap();
assert_eq!(
- const_expressions[res],
+ global_expressions[res],
Expression::Literal(Literal::Bool(true))
);
}
@@ -2218,7 +2409,8 @@ mod tests {
fn access() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let matrix_ty = types.insert(
Type {
@@ -2247,7 +2439,7 @@ mod tests {
let mut vec2_components = Vec::with_capacity(3);
for i in 0..3 {
- let h = const_expressions.append(
+ let h = global_expressions.append(
Expression::Literal(Literal::F32(i as f32)),
Default::default(),
);
@@ -2256,7 +2448,7 @@ mod tests {
}
for i in 3..6 {
- let h = const_expressions.append(
+ let h = global_expressions.append(
Expression::Literal(Literal::F32(i as f32)),
Default::default(),
);
@@ -2267,9 +2459,8 @@ mod tests {
let vec1 = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: vec_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: vec_ty,
components: vec1_components,
@@ -2283,9 +2474,8 @@ mod tests {
let vec2 = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: vec_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: vec_ty,
components: vec2_components,
@@ -2299,9 +2489,8 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: matrix_ty,
- init: const_expressions.append(
+ init: global_expressions.append(
Expression::Compose {
ty: matrix_ty,
components: vec![constants[vec1].init, constants[vec2].init],
@@ -2312,20 +2501,22 @@ mod tests {
Default::default(),
);
- let base = const_expressions.append(Expression::Constant(h), Default::default());
+ let base = global_expressions.append(Expression::Constant(h), Default::default());
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let root1 = Expression::AccessIndex { base, index: 1 };
let res1 = solver
- .try_eval_and_append(&root1, Default::default())
+ .try_eval_and_append(root1, Default::default())
.unwrap();
let root2 = Expression::AccessIndex {
@@ -2334,10 +2525,10 @@ mod tests {
};
let res2 = solver
- .try_eval_and_append(&root2, Default::default())
+ .try_eval_and_append(root2, Default::default())
.unwrap();
- match const_expressions[res1] {
+ match global_expressions[res1] {
Expression::Compose {
ref ty,
ref components,
@@ -2345,15 +2536,15 @@ mod tests {
assert_eq!(*ty, vec_ty);
let mut components_iter = components.iter().copied();
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::F32(3.))
);
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::F32(4.))
);
assert_eq!(
- const_expressions[components_iter.next().unwrap()],
+ global_expressions[components_iter.next().unwrap()],
Expression::Literal(Literal::F32(5.))
);
assert!(components_iter.next().is_none());
@@ -2362,7 +2553,7 @@ mod tests {
}
assert_eq!(
- const_expressions[res2],
+ global_expressions[res2],
Expression::Literal(Literal::F32(5.))
);
}
@@ -2371,7 +2562,8 @@ mod tests {
fn compose_of_constants() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let i32_ty = types.insert(
Type {
@@ -2395,27 +2587,28 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: i32_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);
- let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
+ let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let solved_compose = solver
.try_eval_and_append(
- &Expression::Compose {
+ Expression::Compose {
ty: vec2_i32_ty,
components: vec![h_expr, h_expr],
},
@@ -2424,7 +2617,7 @@ mod tests {
.unwrap();
let solved_negate = solver
.try_eval_and_append(
- &Expression::Unary {
+ Expression::Unary {
op: UnaryOperator::Negate,
expr: solved_compose,
},
@@ -2432,11 +2625,11 @@ mod tests {
)
.unwrap();
- let pass = match const_expressions[solved_negate] {
+ let pass = match global_expressions[solved_negate] {
Expression::Compose { ty, ref components } => {
ty == vec2_i32_ty
&& components.iter().all(|&component| {
- let component = &const_expressions[component];
+ let component = &global_expressions[component];
matches!(*component, Expression::Literal(Literal::I32(-4)))
})
}
@@ -2451,7 +2644,8 @@ mod tests {
fn splat_of_constant() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
- let mut const_expressions = Arena::new();
+ let overrides = Arena::new();
+ let mut global_expressions = Arena::new();
let i32_ty = types.insert(
Type {
@@ -2475,27 +2669,28 @@ mod tests {
let h = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: i32_ty,
- init: const_expressions
+ init: global_expressions
.append(Expression::Literal(Literal::I32(4)), Default::default()),
},
Default::default(),
);
- let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
+ let h_expr = global_expressions.append(Expression::Constant(h), Default::default());
+ let expression_kind_tracker = &mut ExpressionKindTracker::from_arena(&global_expressions);
let mut solver = ConstantEvaluator {
- behavior: Behavior::Wgsl,
+ behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
- expressions: &mut const_expressions,
- function_local_data: None,
+ overrides: &overrides,
+ expressions: &mut global_expressions,
+ expression_kind_tracker,
};
let solved_compose = solver
.try_eval_and_append(
- &Expression::Splat {
+ Expression::Splat {
size: VectorSize::Bi,
value: h_expr,
},
@@ -2504,7 +2699,7 @@ mod tests {
.unwrap();
let solved_negate = solver
.try_eval_and_append(
- &Expression::Unary {
+ Expression::Unary {
op: UnaryOperator::Negate,
expr: solved_compose,
},
@@ -2512,11 +2707,11 @@ mod tests {
)
.unwrap();
- let pass = match const_expressions[solved_negate] {
+ let pass = match global_expressions[solved_negate] {
Expression::Compose { ty, ref components } => {
ty == vec2_i32_ty
&& components.iter().all(|&component| {
- let component = &const_expressions[component];
+ let component = &global_expressions[component];
matches!(*component, Expression::Literal(Literal::I32(-4)))
})
}
diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs
index af3221c0fe..e2c3de8eb0 100644
--- a/third_party/rust/naga/src/proc/index.rs
+++ b/third_party/rust/naga/src/proc/index.rs
@@ -239,7 +239,7 @@ pub enum GuardedIndex {
pub fn find_checked_indexes(
module: &crate::Module,
function: &crate::Function,
- info: &crate::valid::FunctionInfo,
+ info: &valid::FunctionInfo,
policies: BoundsCheckPolicies,
) -> BitSet {
use crate::Expression as Ex;
@@ -321,7 +321,7 @@ pub fn access_needs_check(
mut index: GuardedIndex,
module: &crate::Module,
function: &crate::Function,
- info: &crate::valid::FunctionInfo,
+ info: &valid::FunctionInfo,
) -> Option<IndexableLength> {
let base_inner = info[base].ty.inner_with(&module.types);
// Unwrap safety: `Err` here indicates unindexable base types and invalid
diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs
index 46cbb6c3b3..93aac5b3e5 100644
--- a/third_party/rust/naga/src/proc/mod.rs
+++ b/third_party/rust/naga/src/proc/mod.rs
@@ -11,7 +11,7 @@ mod terminator;
mod typifier;
pub use constant_evaluator::{
- ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker,
+ ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker,
};
pub use emitter::Emitter;
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
@@ -153,56 +153,31 @@ impl super::Scalar {
}
}
-impl PartialEq for crate::Literal {
- fn eq(&self, other: &Self) -> bool {
- match (*self, *other) {
- (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(),
- (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(),
- (Self::U32(a), Self::U32(b)) => a == b,
- (Self::I32(a), Self::I32(b)) => a == b,
- (Self::U64(a), Self::U64(b)) => a == b,
- (Self::I64(a), Self::I64(b)) => a == b,
- (Self::Bool(a), Self::Bool(b)) => a == b,
- _ => false,
- }
- }
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum HashableLiteral {
+ F64(u64),
+ F32(u32),
+ U32(u32),
+ I32(i32),
+ U64(u64),
+ I64(i64),
+ Bool(bool),
+ AbstractInt(i64),
+ AbstractFloat(u64),
}
-impl Eq for crate::Literal {}
-impl std::hash::Hash for crate::Literal {
- fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
- match *self {
- Self::F64(v) | Self::AbstractFloat(v) => {
- hasher.write_u8(0);
- v.to_bits().hash(hasher);
- }
- Self::F32(v) => {
- hasher.write_u8(1);
- v.to_bits().hash(hasher);
- }
- Self::U32(v) => {
- hasher.write_u8(2);
- v.hash(hasher);
- }
- Self::I32(v) => {
- hasher.write_u8(3);
- v.hash(hasher);
- }
- Self::Bool(v) => {
- hasher.write_u8(4);
- v.hash(hasher);
- }
- Self::I64(v) => {
- hasher.write_u8(5);
- v.hash(hasher);
- }
- Self::U64(v) => {
- hasher.write_u8(6);
- v.hash(hasher);
- }
- Self::AbstractInt(v) => {
- hasher.write_u8(7);
- v.hash(hasher);
- }
+
+impl From<crate::Literal> for HashableLiteral {
+ fn from(l: crate::Literal) -> Self {
+ match l {
+ crate::Literal::F64(v) => Self::F64(v.to_bits()),
+ crate::Literal::F32(v) => Self::F32(v.to_bits()),
+ crate::Literal::U32(v) => Self::U32(v),
+ crate::Literal::I32(v) => Self::I32(v),
+ crate::Literal::U64(v) => Self::U64(v),
+ crate::Literal::I64(v) => Self::I64(v),
+ crate::Literal::Bool(v) => Self::Bool(v),
+ crate::Literal::AbstractInt(v) => Self::AbstractInt(v),
+ crate::Literal::AbstractFloat(v) => Self::AbstractFloat(v.to_bits()),
}
}
}
@@ -216,8 +191,8 @@ impl crate::Literal {
(value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
(value, crate::ScalarKind::Uint, 8) => Some(Self::U64(value as _)),
(value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
- (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)),
- (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)),
+ (1, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(true)),
+ (0, crate::ScalarKind::Bool, crate::BOOL_WIDTH) => Some(Self::Bool(false)),
_ => None,
}
}
@@ -279,8 +254,9 @@ impl super::TypeInner {
self.scalar().map(|scalar| scalar.kind)
}
+ /// Returns the scalar width in bytes
pub fn scalar_width(&self) -> Option<u8> {
- self.scalar().map(|scalar| scalar.width * 8)
+ self.scalar().map(|scalar| scalar.width)
}
pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
@@ -532,6 +508,7 @@ impl crate::Expression {
match *self {
Self::Literal(_)
| Self::Constant(_)
+ | Self::Override(_)
| Self::ZeroValue(_)
| Self::FunctionArgument(_)
| Self::GlobalVariable(_)
@@ -553,13 +530,9 @@ impl crate::Expression {
///
/// [`Access`]: crate::Expression::Access
/// [`ResolveContext`]: crate::proc::ResolveContext
- pub fn is_dynamic_index(&self, module: &crate::Module) -> bool {
+ pub const fn is_dynamic_index(&self) -> bool {
match *self {
- Self::Literal(_) | Self::ZeroValue(_) => false,
- Self::Constant(handle) => {
- let constant = &module.constants[handle];
- !matches!(constant.r#override, crate::Override::None)
- }
+ Self::Literal(_) | Self::ZeroValue(_) | Self::Constant(_) => false,
_ => true,
}
}
@@ -652,7 +625,8 @@ impl crate::Module {
GlobalCtx {
types: &self.types,
constants: &self.constants,
- const_expressions: &self.const_expressions,
+ overrides: &self.overrides,
+ global_expressions: &self.global_expressions,
}
}
}
@@ -667,17 +641,18 @@ pub(super) enum U32EvalError {
pub struct GlobalCtx<'a> {
pub types: &'a crate::UniqueArena<crate::Type>,
pub constants: &'a crate::Arena<crate::Constant>,
- pub const_expressions: &'a crate::Arena<crate::Expression>,
+ pub overrides: &'a crate::Arena<crate::Override>,
+ pub global_expressions: &'a crate::Arena<crate::Expression>,
}
impl GlobalCtx<'_> {
- /// Try to evaluate the expression in `self.const_expressions` using its `handle` and return it as a `u32`.
+ /// Try to evaluate the expression in `self.global_expressions` using its `handle` and return it as a `u32`.
#[allow(dead_code)]
pub(super) fn eval_expr_to_u32(
&self,
handle: crate::Handle<crate::Expression>,
) -> Result<u32, U32EvalError> {
- self.eval_expr_to_u32_from(handle, self.const_expressions)
+ self.eval_expr_to_u32_from(handle, self.global_expressions)
}
/// Try to evaluate the expression in the `arena` using its `handle` and return it as a `u32`.
@@ -700,7 +675,7 @@ impl GlobalCtx<'_> {
&self,
handle: crate::Handle<crate::Expression>,
) -> Option<crate::Literal> {
- self.eval_expr_to_literal_from(handle, self.const_expressions)
+ self.eval_expr_to_literal_from(handle, self.global_expressions)
}
fn eval_expr_to_literal_from(
@@ -724,7 +699,7 @@ impl GlobalCtx<'_> {
}
match arena[handle] {
crate::Expression::Constant(c) => {
- get(*self, self.constants[c].init, self.const_expressions)
+ get(*self, self.constants[c].init, self.global_expressions)
}
_ => get(*self, handle, arena),
}
diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs
index a5239d4eca..5edf55cb73 100644
--- a/third_party/rust/naga/src/proc/terminator.rs
+++ b/third_party/rust/naga/src/proc/terminator.rs
@@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
| S::RayQuery { .. }
| S::Atomic { .. }
| S::WorkGroupUniformLoad { .. }
+ | S::SubgroupBallot { .. }
+ | S::SubgroupCollectiveOperation { .. }
+ | S::SubgroupGather { .. }
| S::Barrier(_)),
)
| None => block.push(S::Return { value: None }, Default::default()),
diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs
index 9c4403445c..3936e7efbe 100644
--- a/third_party/rust/naga/src/proc/typifier.rs
+++ b/third_party/rust/naga/src/proc/typifier.rs
@@ -185,6 +185,7 @@ pub enum ResolveError {
pub struct ResolveContext<'a> {
pub constants: &'a Arena<crate::Constant>,
+ pub overrides: &'a Arena<crate::Override>,
pub types: &'a UniqueArena<crate::Type>,
pub special_types: &'a crate::SpecialTypes,
pub global_vars: &'a Arena<crate::GlobalVariable>,
@@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> {
) -> Self {
Self {
constants: &module.constants,
+ overrides: &module.overrides,
types: &module.types,
special_types: &module.special_types,
global_vars: &module.global_variables,
@@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> {
},
crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
+ crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty),
crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::FunctionArgument(index) => {
@@ -595,6 +598,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
+ crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty),
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
@@ -882,6 +886,10 @@ impl<'a> ResolveContext<'a> {
.ok_or(ResolveError::MissingSpecialType)?;
TypeResolution::Handle(result)
}
+ crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
+ scalar: crate::Scalar::U32,
+ size: crate::VectorSize::Quad,
+ }),
})
}
}
diff --git a/third_party/rust/naga/src/span.rs b/third_party/rust/naga/src/span.rs
index 10744647e9..82cfbe5a4b 100644
--- a/third_party/rust/naga/src/span.rs
+++ b/third_party/rust/naga/src/span.rs
@@ -72,8 +72,8 @@ impl Span {
pub fn location(&self, source: &str) -> SourceLocation {
let prefix = &source[..self.start as usize];
let line_number = prefix.matches('\n').count() as u32 + 1;
- let line_start = prefix.rfind('\n').map(|pos| pos + 1).unwrap_or(0);
- let line_position = source[line_start..self.start as usize].chars().count() as u32 + 1;
+ let line_start = prefix.rfind('\n').map(|pos| pos + 1).unwrap_or(0) as u32;
+ let line_position = self.start - line_start + 1;
SourceLocation {
line_number,
@@ -107,14 +107,14 @@ impl std::ops::Index<Span> for str {
/// Roughly corresponds to the positional members of [`GPUCompilationMessage`][gcm] from
/// the WebGPU specification, except
/// - `offset` and `length` are in bytes (UTF-8 code units), instead of UTF-16 code units.
-/// - `line_position` counts entire Unicode code points, instead of UTF-16 code units.
+/// - `line_position` is in bytes (UTF-8 code units), instead of UTF-16 code units.
///
/// [gcm]: https://www.w3.org/TR/webgpu/#gpucompilationmessage
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SourceLocation {
/// 1-based line number.
pub line_number: u32,
- /// 1-based column of the start of this span, counted in Unicode code points.
+ /// 1-based column in code units (in bytes) of the start of the span.
pub line_position: u32,
/// 0-based Offset in code units (in bytes) of the start of the span.
pub offset: u32,
@@ -136,7 +136,7 @@ impl<E> fmt::Display for WithSpan<E>
where
E: fmt::Display,
{
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.inner.fmt(f)
}
}
@@ -304,7 +304,7 @@ impl<E> WithSpan<E> {
use term::termcolor::NoColor;
let files = files::SimpleFile::new(path, source);
- let config = codespan_reporting::term::Config::default();
+ let config = term::Config::default();
let mut writer = NoColor::new(Vec::new());
term::emit(&mut writer, &config, &files, &self.diagnostic()).expect("cannot write error");
String::from_utf8(writer.into_inner()).unwrap()
diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs
index 03fbc4089b..6799e5db27 100644
--- a/third_party/rust/naga/src/valid/analyzer.rs
+++ b/third_party/rust/naga/src/valid/analyzer.rs
@@ -226,7 +226,7 @@ struct Sampling {
sampler: GlobalOrArgument,
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {
@@ -574,7 +574,7 @@ impl FunctionInfo {
non_uniform_result: self.add_ref(vector),
requirements: UniformityRequirements::empty(),
},
- E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(),
+ E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
E::Compose { ref components, .. } => {
let non_uniform_result = components
.iter()
@@ -787,6 +787,14 @@ impl FunctionInfo {
non_uniform_result: self.add_ref(query),
requirements: UniformityRequirements::empty(),
},
+ E::SubgroupBallotResult => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
+ E::SubgroupOperationResult { .. } => Uniformity {
+ non_uniform_result: Some(handle),
+ requirements: UniformityRequirements::empty(),
+ },
};
let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
@@ -827,7 +835,7 @@ impl FunctionInfo {
let req = self.expressions[expr.index()].uniformity.requirements;
if self
.flags
- .contains(super::ValidationFlags::CONTROL_FLOW_UNIFORMITY)
+ .contains(ValidationFlags::CONTROL_FLOW_UNIFORMITY)
&& !req.is_empty()
{
if let Some(cause) = disruptor {
@@ -1029,6 +1037,42 @@ impl FunctionInfo {
}
FunctionUniformity::new()
}
+ S::SubgroupBallot {
+ result: _,
+ predicate,
+ } => {
+ if let Some(predicate) = predicate {
+ let _ = self.add_ref(predicate);
+ }
+ FunctionUniformity::new()
+ }
+ S::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: _,
+ argument,
+ result: _,
+ } => {
+ let _ = self.add_ref(argument);
+ FunctionUniformity::new()
+ }
+ S::SubgroupGather {
+ mode,
+ argument,
+ result: _,
+ } => {
+ let _ = self.add_ref(argument);
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ let _ = self.add_ref(index);
+ }
+ }
+ FunctionUniformity::new()
+ }
};
disruptor = disruptor.or(uniformity.exit_disruptor());
@@ -1047,7 +1091,7 @@ impl ModuleInfo {
gctx: crate::proc::GlobalCtx,
) -> Result<(), super::ConstExpressionError> {
self.const_expression_types[handle.index()] =
- resolve_context.resolve(&gctx.const_expressions[handle], |h| Ok(&self[h]))?;
+ resolve_context.resolve(&gctx.global_expressions[handle], |h| Ok(&self[h]))?;
Ok(())
}
@@ -1186,6 +1230,7 @@ fn uniform_control_flow() {
};
let resolve_context = ResolveContext {
constants: &Arena::new(),
+ overrides: &Arena::new(),
types: &type_arena,
special_types: &crate::SpecialTypes::default(),
global_vars: &global_var_arena,
diff --git a/third_party/rust/naga/src/valid/expression.rs b/third_party/rust/naga/src/valid/expression.rs
index 838ecc4e27..525bd28c17 100644
--- a/third_party/rust/naga/src/valid/expression.rs
+++ b/third_party/rust/naga/src/valid/expression.rs
@@ -90,6 +90,8 @@ pub enum ExpressionError {
sampler: bool,
has_ref: bool,
},
+ #[error("Sample offset must be a const-expression")]
+ InvalidSampleOffsetExprType,
#[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
#[error("Depth reference {0:?} is not a scalar float")]
@@ -129,9 +131,12 @@ pub enum ExpressionError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ConstExpressionError {
- #[error("The expression is not a constant expression")]
- NonConst,
+ #[error("The expression is not a constant or override expression")]
+ NonConstOrOverride,
+ #[error("The expression is not a fully evaluated constant expression")]
+ NonFullyEvaluatedConst,
#[error(transparent)]
Compose(#[from] super::ComposeError),
#[error("Splatting {0:?} can't be done")]
@@ -184,10 +189,15 @@ impl super::Validator {
handle: Handle<crate::Expression>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<(), ConstExpressionError> {
use crate::Expression as E;
- match gctx.const_expressions[handle] {
+ if !global_expr_kind.is_const_or_override(handle) {
+ return Err(ConstExpressionError::NonConstOrOverride);
+ }
+
+ match gctx.global_expressions[handle] {
E::Literal(literal) => {
self.validate_literal(literal)?;
}
@@ -201,14 +211,19 @@ impl super::Validator {
}
E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
crate::TypeInner::Scalar { .. } => {}
- _ => return Err(super::ConstExpressionError::InvalidSplatType(value)),
+ _ => return Err(ConstExpressionError::InvalidSplatType(value)),
},
- _ => return Err(super::ConstExpressionError::NonConst),
+ _ if global_expr_kind.is_const(handle) || !self.allow_overrides => {
+ return Err(ConstExpressionError::NonFullyEvaluatedConst)
+ }
+ // the constant evaluator will report errors about override-expressions
+ _ => {}
}
Ok(())
}
+ #[allow(clippy::too_many_arguments)]
pub(super) fn validate_expression(
&self,
root: Handle<crate::Expression>,
@@ -217,6 +232,7 @@ impl super::Validator {
module: &crate::Module,
info: &FunctionInfo,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<ShaderStages, ExpressionError> {
use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
@@ -252,9 +268,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidIndexType(index));
}
}
- if dynamic_indexing_restricted
- && function.expressions[index].is_dynamic_index(module)
- {
+ if dynamic_indexing_restricted && function.expressions[index].is_dynamic_index() {
return Err(ExpressionError::IndexMustBeConstant(base));
}
@@ -347,7 +361,7 @@ impl super::Validator {
self.validate_literal(literal)?;
ShaderStages::all()
}
- E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(),
+ E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
E::Compose { ref components, ty } => {
validate_compose(
ty,
@@ -464,6 +478,10 @@ impl super::Validator {
// check constant offset
if let Some(const_expr) = offset {
+ if !global_expr_kind.is_const(const_expr) {
+ return Err(ExpressionError::InvalidSampleOffsetExprType);
+ }
+
match *mod_info[const_expr].inner_with(&module.types) {
Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
Ti::Vector {
@@ -1623,6 +1641,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
+ E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
};
Ok(stages)
}
@@ -1716,7 +1735,7 @@ fn validate_with_const_expression(
use crate::span::Span;
let mut module = crate::Module::default();
- module.const_expressions.append(expr, Span::default());
+ module.global_expressions.append(expr, Span::default());
let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);
diff --git a/third_party/rust/naga/src/valid/function.rs b/third_party/rust/naga/src/valid/function.rs
index f0ca22cbda..71128fc86d 100644
--- a/third_party/rust/naga/src/valid/function.rs
+++ b/third_party/rust/naga/src/valid/function.rs
@@ -49,13 +49,26 @@ pub enum AtomicError {
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
+pub enum SubgroupError {
+ #[error("Operand {0:?} has invalid type.")]
+ InvalidOperand(Handle<crate::Expression>),
+ #[error("Result type for {0:?} doesn't match the statement")]
+ ResultTypeMismatch(Handle<crate::Expression>),
+ #[error("Support for subgroup operation {0:?} is required")]
+ UnsupportedOperation(super::SubgroupOperationSet),
+ #[error("Unknown operation")]
+ UnknownOperation,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
#[error("Local variable has a type {0:?} that can't be stored in a local variable.")]
InvalidType(Handle<crate::Type>),
#[error("Initializer doesn't match the variable type")]
InitializerType,
- #[error("Initializer is not const")]
- NonConstInitializer,
+ #[error("Initializer is not a const or override expression")]
+ NonConstOrOverrideInitializer,
}
#[derive(Clone, Debug, thiserror::Error)]
@@ -135,6 +148,8 @@ pub enum FunctionError {
InvalidRayDescriptor(Handle<crate::Expression>),
#[error("Ray Query {0:?} does not have a matching type")]
InvalidRayQueryType(Handle<crate::Type>),
+ #[error("Shader requires capability {0:?}")]
+ MissingCapability(super::Capabilities),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
@@ -155,6 +170,8 @@ pub enum FunctionError {
WorkgroupUniformLoadExpressionMismatch(Handle<crate::Expression>),
#[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")]
WorkgroupUniformLoadInvalidPointer(Handle<crate::Expression>),
+ #[error("Subgroup operation is invalid")]
+ InvalidSubgroup(#[from] SubgroupError),
}
bitflags::bitflags! {
@@ -399,6 +416,127 @@ impl super::Validator {
}
Ok(())
}
+ fn validate_subgroup_operation(
+ &mut self,
+ op: &crate::SubgroupOperation,
+ collective_op: &crate::CollectiveOperation,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
+
+ let (is_scalar, scalar) = match *argument_inner {
+ crate::TypeInner::Scalar(scalar) => (true, scalar),
+ crate::TypeInner::Vector { scalar, .. } => (false, scalar),
+ _ => {
+ log::error!("Subgroup operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+ };
+
+ use crate::ScalarKind as sk;
+ use crate::SubgroupOperation as sg;
+ match (scalar.kind, *op) {
+ (sk::Bool, sg::All | sg::Any) if is_scalar => {}
+ (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {}
+ (sk::Sint | sk::Uint, sg::And | sg::Or | sg::Xor) => {}
+
+ (_, _) => {
+ log::error!("Subgroup operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+ };
+
+ use crate::CollectiveOperation as co;
+ match (*collective_op, *op) {
+ (
+ co::Reduce,
+ sg::All
+ | sg::Any
+ | sg::Add
+ | sg::Mul
+ | sg::Min
+ | sg::Max
+ | sg::And
+ | sg::Or
+ | sg::Xor,
+ ) => {}
+ (co::InclusiveScan | co::ExclusiveScan, sg::Add | sg::Mul) => {}
+
+ (_, _) => {
+ return Err(SubgroupError::UnknownOperation.with_span().into_other());
+ }
+ };
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::SubgroupOperationResult { ty }
+ if { &context.types[ty].inner == argument_inner } => {}
+ _ => {
+ return Err(SubgroupError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
+ fn validate_subgroup_gather(
+ &mut self,
+ mode: &crate::GatherMode,
+ argument: Handle<crate::Expression>,
+ result: Handle<crate::Expression>,
+ context: &BlockContext,
+ ) -> Result<(), WithSpan<FunctionError>> {
+ match *mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => {
+ let index_ty = context.resolve_type(index, &self.valid_expression_set)?;
+ match *index_ty {
+ crate::TypeInner::Scalar(crate::Scalar::U32) => {}
+ _ => {
+ log::error!(
+ "Subgroup gather index type {:?}, expected unsigned int",
+ index_ty
+ );
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(index, context.expressions)
+ .into_other());
+ }
+ }
+ }
+ }
+ let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?;
+ if !matches!(*argument_inner,
+ crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. }
+ if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float)
+ ) {
+ log::error!("Subgroup gather operand type {:?}", argument_inner);
+ return Err(SubgroupError::InvalidOperand(argument)
+ .with_span_handle(argument, context.expressions)
+ .into_other());
+ }
+
+ self.emit_expression(result, context)?;
+ match context.expressions[result] {
+ crate::Expression::SubgroupOperationResult { ty }
+ if { &context.types[ty].inner == argument_inner } => {}
+ _ => {
+ return Err(SubgroupError::ResultTypeMismatch(result)
+ .with_span_handle(result, context.expressions)
+ .into_other())
+ }
+ }
+ Ok(())
+ }
fn validate_block_impl(
&mut self,
@@ -613,8 +751,30 @@ impl super::Validator {
stages &= super::ShaderStages::FRAGMENT;
finished = true;
}
- S::Barrier(_) => {
+ S::Barrier(barrier) => {
stages &= super::ShaderStages::COMPUTE;
+ if barrier.contains(crate::Barrier::SUB_GROUP) {
+ if !self.capabilities.contains(
+ super::Capabilities::SUBGROUP | super::Capabilities::SUBGROUP_BARRIER,
+ ) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP
+ | super::Capabilities::SUBGROUP_BARRIER,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ if !self
+ .subgroup_operations
+ .contains(super::SubgroupOperationSet::BASIC)
+ {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(
+ super::SubgroupOperationSet::BASIC,
+ ),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ }
}
S::Store { pointer, value } => {
let mut current = pointer;
@@ -904,6 +1064,86 @@ impl super::Validator {
crate::RayQueryFunction::Terminate => {}
}
}
+ S::SubgroupBallot { result, predicate } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ if !self
+ .subgroup_operations
+ .contains(super::SubgroupOperationSet::BALLOT)
+ {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(
+ super::SubgroupOperationSet::BALLOT,
+ ),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ if let Some(predicate) = predicate {
+ let predicate_inner =
+ context.resolve_type(predicate, &self.valid_expression_set)?;
+ if !matches!(
+ *predicate_inner,
+ crate::TypeInner::Scalar(crate::Scalar::BOOL,)
+ ) {
+ log::error!(
+ "Subgroup ballot predicate type {:?} expected bool",
+ predicate_inner
+ );
+ return Err(SubgroupError::InvalidOperand(predicate)
+ .with_span_handle(predicate, context.expressions)
+ .into_other());
+ }
+ }
+ self.emit_expression(result, context)?;
+ }
+ S::SubgroupCollectiveOperation {
+ ref op,
+ ref collective_op,
+ argument,
+ result,
+ } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ let operation = op.required_operations();
+ if !self.subgroup_operations.contains(operation) {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(operation),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
+ }
+ S::SubgroupGather {
+ ref mode,
+ argument,
+ result,
+ } => {
+ stages &= self.subgroup_stages;
+ if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
+ return Err(FunctionError::MissingCapability(
+ super::Capabilities::SUBGROUP,
+ )
+ .with_span_static(span, "missing capability for this operation"));
+ }
+ let operation = mode.required_operations();
+ if !self.subgroup_operations.contains(operation) {
+ return Err(FunctionError::InvalidSubgroup(
+ SubgroupError::UnsupportedOperation(operation),
+ )
+ .with_span_static(span, "support for this operation is not present"));
+ }
+ self.validate_subgroup_gather(mode, argument, result, context)?;
+ }
}
}
Ok(BlockInfo { stages, finished })
@@ -927,7 +1167,7 @@ impl super::Validator {
var: &crate::LocalVariable,
gctx: crate::proc::GlobalCtx,
fun_info: &FunctionInfo,
- expression_constness: &crate::proc::ExpressionConstnessTracker,
+ local_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<(), LocalVariableError> {
log::debug!("var {:?}", var);
let type_info = self
@@ -945,8 +1185,8 @@ impl super::Validator {
return Err(LocalVariableError::InitializerType);
}
- if !expression_constness.is_const(init) {
- return Err(LocalVariableError::NonConstInitializer);
+ if !local_expr_kind.is_const_or_override(init) {
+ return Err(LocalVariableError::NonConstOrOverrideInitializer);
}
}
@@ -959,14 +1199,14 @@ impl super::Validator {
module: &crate::Module,
mod_info: &ModuleInfo,
entry_point: bool,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<FunctionInfo, WithSpan<FunctionError>> {
let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
- let expression_constness =
- crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
+ let local_expr_kind = crate::proc::ExpressionKindTracker::from_arena(&fun.expressions);
for (var_handle, var) in fun.local_variables.iter() {
- self.validate_local_var(var, module.to_ctx(), &info, &expression_constness)
+ self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
.map_err(|source| {
FunctionError::LocalVariable {
handle: var_handle,
@@ -1032,7 +1272,15 @@ impl super::Validator {
self.valid_expression_set.insert(handle.index());
}
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
- match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
+ match self.validate_expression(
+ handle,
+ expr,
+ fun,
+ module,
+ &info,
+ mod_info,
+ global_expr_kind,
+ ) {
Ok(stages) => info.available_stages &= stages,
Err(source) => {
return Err(FunctionError::Expression { handle, source }
diff --git a/third_party/rust/naga/src/valid/handles.rs b/third_party/rust/naga/src/valid/handles.rs
index e482f293bb..8f78204055 100644
--- a/third_party/rust/naga/src/valid/handles.rs
+++ b/third_party/rust/naga/src/valid/handles.rs
@@ -31,12 +31,13 @@ impl super::Validator {
pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
let &crate::Module {
ref constants,
+ ref overrides,
ref entry_points,
ref functions,
ref global_variables,
ref types,
ref special_types,
- ref const_expressions,
+ ref global_expressions,
} = module;
// NOTE: Types being first is important. All other forms of validation depend on this.
@@ -67,23 +68,31 @@ impl super::Validator {
}
}
- for handle_and_expr in const_expressions.iter() {
- Self::validate_const_expression_handles(handle_and_expr, constants, types)?;
+ for handle_and_expr in global_expressions.iter() {
+ Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?;
}
let validate_type = |handle| Self::validate_type_handle(handle, types);
let validate_const_expr =
- |handle| Self::validate_expression_handle(handle, const_expressions);
+ |handle| Self::validate_expression_handle(handle, global_expressions);
for (_handle, constant) in constants.iter() {
- let &crate::Constant {
+ let &crate::Constant { name: _, ty, init } = constant;
+ validate_type(ty)?;
+ validate_const_expr(init)?;
+ }
+
+ for (_handle, override_) in overrides.iter() {
+ let &crate::Override {
name: _,
- r#override: _,
+ id: _,
ty,
init,
- } = constant;
+ } = override_;
validate_type(ty)?;
- validate_const_expr(init)?;
+ if let Some(init_expr) = init {
+ validate_const_expr(init_expr)?;
+ }
}
for (_handle, global_variable) in global_variables.iter() {
@@ -140,7 +149,8 @@ impl super::Validator {
Self::validate_expression_handles(
handle_and_expr,
constants,
- const_expressions,
+ overrides,
+ global_expressions,
types,
local_variables,
global_variables,
@@ -186,6 +196,13 @@ impl super::Validator {
handle.check_valid_for(constants).map(|_| ())
}
+ fn validate_override_handle(
+ handle: Handle<crate::Override>,
+ overrides: &Arena<crate::Override>,
+ ) -> Result<(), InvalidHandleError> {
+ handle.check_valid_for(overrides).map(|_| ())
+ }
+
fn validate_expression_handle(
handle: Handle<crate::Expression>,
expressions: &Arena<crate::Expression>,
@@ -203,9 +220,11 @@ impl super::Validator {
fn validate_const_expression_handles(
(handle, expression): (Handle<crate::Expression>, &crate::Expression),
constants: &Arena<crate::Constant>,
+ overrides: &Arena<crate::Override>,
types: &UniqueArena<crate::Type>,
) -> Result<(), InvalidHandleError> {
let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_override = |handle| Self::validate_override_handle(handle, overrides);
let validate_type = |handle| Self::validate_type_handle(handle, types);
match *expression {
@@ -214,6 +233,12 @@ impl super::Validator {
validate_constant(constant)?;
handle.check_dep(constants[constant].init)?;
}
+ crate::Expression::Override(override_) => {
+ validate_override(override_)?;
+ if let Some(init) = overrides[override_].init {
+ handle.check_dep(init)?;
+ }
+ }
crate::Expression::ZeroValue(ty) => {
validate_type(ty)?;
}
@@ -230,7 +255,8 @@ impl super::Validator {
fn validate_expression_handles(
(handle, expression): (Handle<crate::Expression>, &crate::Expression),
constants: &Arena<crate::Constant>,
- const_expressions: &Arena<crate::Expression>,
+ overrides: &Arena<crate::Override>,
+ global_expressions: &Arena<crate::Expression>,
types: &UniqueArena<crate::Type>,
local_variables: &Arena<crate::LocalVariable>,
global_variables: &Arena<crate::GlobalVariable>,
@@ -239,8 +265,9 @@ impl super::Validator {
current_function: Option<Handle<crate::Function>>,
) -> Result<(), InvalidHandleError> {
let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
+ let validate_override = |handle| Self::validate_override_handle(handle, overrides);
let validate_const_expr =
- |handle| Self::validate_expression_handle(handle, const_expressions);
+ |handle| Self::validate_expression_handle(handle, global_expressions);
let validate_type = |handle| Self::validate_type_handle(handle, types);
match *expression {
@@ -260,6 +287,9 @@ impl super::Validator {
crate::Expression::Constant(constant) => {
validate_constant(constant)?;
}
+ crate::Expression::Override(override_) => {
+ validate_override(override_)?;
+ }
crate::Expression::ZeroValue(ty) => {
validate_type(ty)?;
}
@@ -390,6 +420,8 @@ impl super::Validator {
}
crate::Expression::AtomicResult { .. }
| crate::Expression::RayQueryProceedResult
+ | crate::Expression::SubgroupBallotResult
+ | crate::Expression::SubgroupOperationResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. } => (),
crate::Expression::ArrayLength(array) => {
handle.check_dep(array)?;
@@ -535,6 +567,38 @@ impl super::Validator {
}
Ok(())
}
+ crate::Statement::SubgroupBallot { result, predicate } => {
+ validate_expr_opt(predicate)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::SubgroupCollectiveOperation {
+ op: _,
+ collective_op: _,
+ argument,
+ result,
+ } => {
+ validate_expr(argument)?;
+ validate_expr(result)?;
+ Ok(())
+ }
+ crate::Statement::SubgroupGather {
+ mode,
+ argument,
+ result,
+ } => {
+ validate_expr(argument)?;
+ match mode {
+ crate::GatherMode::BroadcastFirst => {}
+ crate::GatherMode::Broadcast(index)
+ | crate::GatherMode::Shuffle(index)
+ | crate::GatherMode::ShuffleDown(index)
+ | crate::GatherMode::ShuffleUp(index)
+ | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?,
+ }
+ validate_expr(result)?;
+ Ok(())
+ }
crate::Statement::Break
| crate::Statement::Continue
| crate::Statement::Kill
@@ -562,6 +626,7 @@ impl From<BadRangeError> for ValidationError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum InvalidHandleError {
#[error(transparent)]
BadHandle(#[from] BadHandle),
@@ -572,6 +637,7 @@ pub enum InvalidHandleError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
#[error(
"{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
which has not been processed yet"
@@ -664,6 +730,7 @@ fn constant_deps() {
let mut const_exprs = Arena::new();
let mut fun_exprs = Arena::new();
let mut constants = Arena::new();
+ let overrides = Arena::new();
let i32_handle = types.insert(
Type {
@@ -679,7 +746,6 @@ fn constant_deps() {
let self_referential_const = constants.append(
Constant {
name: None,
- r#override: crate::Override::None,
ty: i32_handle,
init: fun_expr,
},
@@ -692,6 +758,7 @@ fn constant_deps() {
assert!(super::Validator::validate_const_expression_handles(
handle_and_expr,
&constants,
+ &overrides,
&types,
)
.is_err());
diff --git a/third_party/rust/naga/src/valid/interface.rs b/third_party/rust/naga/src/valid/interface.rs
index 84c8b09ddb..db890ddbac 100644
--- a/third_party/rust/naga/src/valid/interface.rs
+++ b/third_party/rust/naga/src/valid/interface.rs
@@ -10,6 +10,7 @@ use bit_set::BitSet;
const MAX_WORKGROUP_SIZE: u32 = 0x4000;
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum GlobalVariableError {
#[error("Usage isn't compatible with address space {0:?}")]
InvalidUsage(crate::AddressSpace),
@@ -30,6 +31,8 @@ pub enum GlobalVariableError {
Handle<crate::Type>,
#[source] Disalignment,
),
+ #[error("Initializer must be an override-expression")]
+ InitializerExprType,
#[error("Initializer doesn't match the variable type")]
InitializerType,
#[error("Initializer can't be used with address space {0:?}")]
@@ -39,6 +42,7 @@ pub enum GlobalVariableError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum VaryingError {
#[error("The type {0:?} does not match the varying")]
InvalidType(Handle<crate::Type>),
@@ -73,9 +77,12 @@ pub enum VaryingError {
location: u32,
attribute: &'static str,
},
+ #[error("Workgroup size is multi dimensional, @builtin(subgroup_id) and @builtin(subgroup_invocation_id) are not supported.")]
+ InvalidMultiDimensionalSubgroupBuiltIn,
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum EntryPointError {
#[error("Multiple conflicting entry points")]
Conflict,
@@ -135,6 +142,7 @@ struct VaryingContext<'a> {
impl VaryingContext<'_> {
fn validate_impl(
&mut self,
+ ep: &crate::EntryPoint,
ty: Handle<crate::Type>,
binding: &crate::Binding,
) -> Result<(), VaryingError> {
@@ -162,12 +170,24 @@ impl VaryingContext<'_> {
Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX,
Bi::ViewIndex => Capabilities::MULTIVIEW,
Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING,
+ Bi::NumSubgroups
+ | Bi::SubgroupId
+ | Bi::SubgroupSize
+ | Bi::SubgroupInvocationId => Capabilities::SUBGROUP,
_ => Capabilities::empty(),
};
if !self.capabilities.contains(required) {
return Err(VaryingError::UnsupportedCapability(required));
}
+ if matches!(
+ built_in,
+ crate::BuiltIn::SubgroupId | crate::BuiltIn::SubgroupInvocationId
+ ) && ep.workgroup_size[1..].iter().any(|&s| s > 1)
+ {
+ return Err(VaryingError::InvalidMultiDimensionalSubgroupBuiltIn);
+ }
+
let (visible, type_good) = match built_in {
Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => (
self.stage == St::Vertex && !self.output,
@@ -249,6 +269,17 @@ impl VaryingContext<'_> {
scalar: crate::Scalar::U32,
},
),
+ Bi::NumSubgroups | Bi::SubgroupId => (
+ self.stage == St::Compute && !self.output,
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
+ Bi::SubgroupSize | Bi::SubgroupInvocationId => (
+ match self.stage {
+ St::Compute | St::Fragment => !self.output,
+ St::Vertex => false,
+ },
+ *ty_inner == Ti::Scalar(crate::Scalar::U32),
+ ),
};
if !visible {
@@ -349,13 +380,14 @@ impl VaryingContext<'_> {
fn validate(
&mut self,
+ ep: &crate::EntryPoint,
ty: Handle<crate::Type>,
binding: Option<&crate::Binding>,
) -> Result<(), WithSpan<VaryingError>> {
let span_context = self.types.get_span_context(ty);
match binding {
Some(binding) => self
- .validate_impl(ty, binding)
+ .validate_impl(ep, ty, binding)
.map_err(|e| e.with_span_context(span_context)),
None => {
match self.types[ty].inner {
@@ -372,7 +404,7 @@ impl VaryingContext<'_> {
}
}
Some(ref binding) => self
- .validate_impl(member.ty, binding)
+ .validate_impl(ep, member.ty, binding)
.map_err(|e| e.with_span_context(span_context))?,
}
}
@@ -395,6 +427,7 @@ impl super::Validator {
var: &crate::GlobalVariable,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<(), GlobalVariableError> {
use super::TypeFlags;
@@ -523,6 +556,10 @@ impl super::Validator {
}
}
+ if !global_expr_kind.is_const_or_override(init) {
+ return Err(GlobalVariableError::InitializerExprType);
+ }
+
let decl_ty = &gctx.types[var.ty].inner;
let init_ty = mod_info[init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
@@ -538,6 +575,7 @@ impl super::Validator {
ep: &crate::EntryPoint,
module: &crate::Module,
mod_info: &ModuleInfo,
+ global_expr_kind: &crate::proc::ExpressionKindTracker,
) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
if ep.early_depth_test.is_some() {
let required = Capabilities::EARLY_DEPTH_TEST;
@@ -566,7 +604,7 @@ impl super::Validator {
}
let mut info = self
- .validate_function(&ep.function, module, mod_info, true)
+ .validate_function(&ep.function, module, mod_info, true, global_expr_kind)
.map_err(WithSpan::into_other)?;
{
@@ -598,7 +636,7 @@ impl super::Validator {
capabilities: self.capabilities,
flags: self.flags,
};
- ctx.validate(fa.ty, fa.binding.as_ref())
+ ctx.validate(ep, fa.ty, fa.binding.as_ref())
.map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?;
}
@@ -616,7 +654,7 @@ impl super::Validator {
capabilities: self.capabilities,
flags: self.flags,
};
- ctx.validate(fr.ty, fr.binding.as_ref())
+ ctx.validate(ep, fr.ty, fr.binding.as_ref())
.map_err_inner(|e| EntryPointError::Result(e).with_span())?;
if ctx.second_blend_source {
// Only the first location may be used when dual source blending
diff --git a/third_party/rust/naga/src/valid/mod.rs b/third_party/rust/naga/src/valid/mod.rs
index 5459434f33..a0057f39ac 100644
--- a/third_party/rust/naga/src/valid/mod.rs
+++ b/third_party/rust/naga/src/valid/mod.rs
@@ -12,7 +12,7 @@ mod r#type;
use crate::{
arena::Handle,
- proc::{LayoutError, Layouter, TypeResolution},
+ proc::{ExpressionKindTracker, LayoutError, Layouter, TypeResolution},
FastHashSet,
};
use bit_set::BitSet;
@@ -77,7 +77,7 @@ bitflags::bitflags! {
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
- pub struct Capabilities: u16 {
+ pub struct Capabilities: u32 {
/// Support for [`AddressSpace:PushConstant`].
const PUSH_CONSTANT = 0x1;
/// Float values with width = 8.
@@ -110,6 +110,10 @@ bitflags::bitflags! {
const CUBE_ARRAY_TEXTURES = 0x4000;
/// Support for 64-bit signed and unsigned integers.
const SHADER_INT64 = 0x8000;
+ /// Support for subgroup operations.
+ const SUBGROUP = 0x10000;
+ /// Support for subgroup barriers.
+ const SUBGROUP_BARRIER = 0x20000;
}
}
@@ -120,6 +124,57 @@ impl Default for Capabilities {
}
bitflags::bitflags! {
+ /// Supported subgroup operations
+ #[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+ #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+ #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+ pub struct SubgroupOperationSet: u8 {
+ /// Elect, Barrier
+ const BASIC = 1 << 0;
+ /// Any, All
+ const VOTE = 1 << 1;
+ /// reductions, scans
+ const ARITHMETIC = 1 << 2;
+ /// ballot, broadcast
+ const BALLOT = 1 << 3;
+ /// shuffle, shuffle xor
+ const SHUFFLE = 1 << 4;
+ /// shuffle up, down
+ const SHUFFLE_RELATIVE = 1 << 5;
+ // We don't support these operations yet
+ // /// Clustered
+ // const CLUSTERED = 1 << 6;
+ // /// Quad supported
+ // const QUAD_FRAGMENT_COMPUTE = 1 << 7;
+ // /// Quad supported in all stages
+ // const QUAD_ALL_STAGES = 1 << 8;
+ }
+}
+
+impl super::SubgroupOperation {
+ const fn required_operations(&self) -> SubgroupOperationSet {
+ use SubgroupOperationSet as S;
+ match *self {
+ Self::All | Self::Any => S::VOTE,
+ Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
+ S::ARITHMETIC
+ }
+ }
+ }
+}
+
+impl super::GatherMode {
+ const fn required_operations(&self) -> SubgroupOperationSet {
+ use SubgroupOperationSet as S;
+ match *self {
+ Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
+ Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
+ Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
+ }
+ }
+}
+
+bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
@@ -131,7 +186,7 @@ bitflags::bitflags! {
}
}
-#[derive(Debug)]
+#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
@@ -166,6 +221,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
pub struct Validator {
flags: ValidationFlags,
capabilities: Capabilities,
+ subgroup_stages: ShaderStages,
+ subgroup_operations: SubgroupOperationSet,
types: Vec<r#type::TypeInfo>,
layouter: Layouter,
location_mask: BitSet,
@@ -174,10 +231,15 @@ pub struct Validator {
switch_values: FastHashSet<crate::SwitchValue>,
valid_expression_list: Vec<Handle<crate::Expression>>,
valid_expression_set: BitSet,
+ override_ids: FastHashSet<u16>,
+ allow_overrides: bool,
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
+ #[error("Initializer must be a const-expression")]
+ InitializerExprType,
#[error("The type doesn't match the constant")]
InvalidType,
#[error("The type is not constructible")]
@@ -185,6 +247,26 @@ pub enum ConstantError {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
+pub enum OverrideError {
+ #[error("Override name and ID are missing")]
+ MissingNameAndID,
+ #[error("Override ID must be unique")]
+ DuplicateID,
+ #[error("Initializer must be a const-expression or override-expression")]
+ InitializerExprType,
+ #[error("The type doesn't match the override")]
+ InvalidType,
+ #[error("The type is not constructible")]
+ NonConstructibleType,
+ #[error("The type is not a scalar")]
+ TypeNotScalar,
+ #[error("Override declarations are not allowed")]
+ NotAllowed,
+}
+
+#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
#[error(transparent)]
InvalidHandle(#[from] InvalidHandleError),
@@ -207,6 +289,12 @@ pub enum ValidationError {
name: String,
source: ConstantError,
},
+ #[error("Override {handle:?} '{name}' is invalid")]
+ Override {
+ handle: Handle<crate::Override>,
+ name: String,
+ source: OverrideError,
+ },
#[error("Global variable {handle:?} '{name}' is invalid")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
@@ -286,6 +374,8 @@ impl Validator {
Validator {
flags,
capabilities,
+ subgroup_stages: ShaderStages::empty(),
+ subgroup_operations: SubgroupOperationSet::empty(),
types: Vec::new(),
layouter: Layouter::default(),
location_mask: BitSet::new(),
@@ -293,9 +383,21 @@ impl Validator {
switch_values: FastHashSet::default(),
valid_expression_list: Vec::new(),
valid_expression_set: BitSet::new(),
+ override_ids: FastHashSet::default(),
+ allow_overrides: true,
}
}
+ pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
+ self.subgroup_stages = stages;
+ self
+ }
+
+ pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
+ self.subgroup_operations = operations;
+ self
+ }
+
/// Reset the validator internals
pub fn reset(&mut self) {
self.types.clear();
@@ -305,6 +407,7 @@ impl Validator {
self.switch_values.clear();
self.valid_expression_list.clear();
self.valid_expression_set.clear();
+ self.override_ids.clear();
}
fn validate_constant(
@@ -312,6 +415,7 @@ impl Validator {
handle: Handle<crate::Constant>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
+ global_expr_kind: &ExpressionKindTracker,
) -> Result<(), ConstantError> {
let con = &gctx.constants[handle];
@@ -320,6 +424,10 @@ impl Validator {
return Err(ConstantError::NonConstructibleType);
}
+ if !global_expr_kind.is_const(con.init) {
+ return Err(ConstantError::InitializerExprType);
+ }
+
let decl_ty = &gctx.types[con.ty].inner;
let init_ty = mod_info[con.init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
@@ -329,11 +437,80 @@ impl Validator {
Ok(())
}
+ fn validate_override(
+ &mut self,
+ handle: Handle<crate::Override>,
+ gctx: crate::proc::GlobalCtx,
+ mod_info: &ModuleInfo,
+ ) -> Result<(), OverrideError> {
+ if !self.allow_overrides {
+ return Err(OverrideError::NotAllowed);
+ }
+
+ let o = &gctx.overrides[handle];
+
+ if o.name.is_none() && o.id.is_none() {
+ return Err(OverrideError::MissingNameAndID);
+ }
+
+ if let Some(id) = o.id {
+ if !self.override_ids.insert(id) {
+ return Err(OverrideError::DuplicateID);
+ }
+ }
+
+ let type_info = &self.types[o.ty.index()];
+ if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
+ return Err(OverrideError::NonConstructibleType);
+ }
+
+ let decl_ty = &gctx.types[o.ty].inner;
+ match decl_ty {
+ &crate::TypeInner::Scalar(scalar) => match scalar {
+ crate::Scalar::BOOL
+ | crate::Scalar::I32
+ | crate::Scalar::U32
+ | crate::Scalar::F32
+ | crate::Scalar::F64 => {}
+ _ => return Err(OverrideError::TypeNotScalar),
+ },
+ _ => return Err(OverrideError::TypeNotScalar),
+ }
+
+ if let Some(init) = o.init {
+ let init_ty = mod_info[init].inner_with(gctx.types);
+ if !decl_ty.equivalent(init_ty, gctx.types) {
+ return Err(OverrideError::InvalidType);
+ }
+ }
+
+ Ok(())
+ }
+
/// Check the given module to be valid.
pub fn validate(
&mut self,
module: &crate::Module,
) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.allow_overrides = true;
+ self.validate_impl(module)
+ }
+
+ /// Check the given module to be valid.
+ ///
+ /// With the additional restriction that overrides are not present.
+ pub fn validate_no_overrides(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
+ self.allow_overrides = false;
+ self.validate_impl(module)
+ }
+
+ fn validate_impl(
+ &mut self,
+ module: &crate::Module,
+ ) -> Result<ModuleInfo, WithSpan<ValidationError>> {
self.reset();
self.reset_types(module.types.len());
@@ -354,7 +531,7 @@ impl Validator {
type_flags: Vec::with_capacity(module.types.len()),
functions: Vec::with_capacity(module.functions.len()),
entry_points: Vec::with_capacity(module.entry_points.len()),
- const_expression_types: vec![placeholder; module.const_expressions.len()]
+ const_expression_types: vec![placeholder; module.global_expressions.len()]
.into_boxed_slice(),
};
@@ -376,27 +553,34 @@ impl Validator {
{
let t = crate::Arena::new();
let resolve_context = crate::proc::ResolveContext::with_locals(module, &t, &[]);
- for (handle, _) in module.const_expressions.iter() {
+ for (handle, _) in module.global_expressions.iter() {
mod_info
.process_const_expression(handle, &resolve_context, module.to_ctx())
.map_err(|source| {
ValidationError::ConstExpression { handle, source }
- .with_span_handle(handle, &module.const_expressions)
+ .with_span_handle(handle, &module.global_expressions)
})?
}
}
+ let global_expr_kind = ExpressionKindTracker::from_arena(&module.global_expressions);
+
if self.flags.contains(ValidationFlags::CONSTANTS) {
- for (handle, _) in module.const_expressions.iter() {
- self.validate_const_expression(handle, module.to_ctx(), &mod_info)
- .map_err(|source| {
- ValidationError::ConstExpression { handle, source }
- .with_span_handle(handle, &module.const_expressions)
- })?
+ for (handle, _) in module.global_expressions.iter() {
+ self.validate_const_expression(
+ handle,
+ module.to_ctx(),
+ &mod_info,
+ &global_expr_kind,
+ )
+ .map_err(|source| {
+ ValidationError::ConstExpression { handle, source }
+ .with_span_handle(handle, &module.global_expressions)
+ })?
}
for (handle, constant) in module.constants.iter() {
- self.validate_constant(handle, module.to_ctx(), &mod_info)
+ self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::Constant {
handle,
@@ -406,10 +590,22 @@ impl Validator {
.with_span_handle(handle, &module.constants)
})?
}
+
+ for (handle, override_) in module.overrides.iter() {
+ self.validate_override(handle, module.to_ctx(), &mod_info)
+ .map_err(|source| {
+ ValidationError::Override {
+ handle,
+ name: override_.name.clone().unwrap_or_default(),
+ source,
+ }
+ .with_span_handle(handle, &module.overrides)
+ })?
+ }
}
for (var_handle, var) in module.global_variables.iter() {
- self.validate_global_var(var, module.to_ctx(), &mod_info)
+ self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::GlobalVariable {
handle: var_handle,
@@ -421,7 +617,7 @@ impl Validator {
}
for (handle, fun) in module.functions.iter() {
- match self.validate_function(fun, module, &mod_info, false) {
+ match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
Ok(info) => mod_info.functions.push(info),
Err(error) => {
return Err(error.and_then(|source| {
@@ -447,7 +643,7 @@ impl Validator {
.with_span()); // TODO: keep some EP span information?
}
- match self.validate_entry_point(ep, module, &mod_info) {
+ match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
Ok(info) => mod_info.entry_points.push(info),
Err(error) => {
return Err(error.and_then(|source| {
diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs
index b8eb618ed4..f5b9856074 100644
--- a/third_party/rust/naga/src/valid/type.rs
+++ b/third_party/rust/naga/src/valid/type.rs
@@ -63,6 +63,7 @@ bitflags::bitflags! {
}
#[derive(Clone, Copy, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum Disalignment {
#[error("The array stride {stride} is not a multiple of the required alignment {alignment}")]
ArrayStride { stride: u32, alignment: Alignment },
@@ -87,6 +88,7 @@ pub enum Disalignment {
}
#[derive(Clone, Debug, thiserror::Error)]
+#[cfg_attr(test, derive(PartialEq))]
pub enum TypeError {
#[error("Capability {0:?} is required")]
MissingCapability(Capabilities),
@@ -326,7 +328,6 @@ impl super::Validator {
TypeFlags::DATA
| TypeFlags::SIZED
| TypeFlags::COPY
- | TypeFlags::HOST_SHAREABLE
| TypeFlags::ARGUMENT
| TypeFlags::CONSTRUCTIBLE
| shareable,