summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/front/spv/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/naga/src/front/spv/mod.rs')
-rw-r--r--third_party/rust/naga/src/front/spv/mod.rs455
1 files changed, 349 insertions, 106 deletions
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs
index b793448597..7ac5a18cd6 100644
--- a/third_party/rust/naga/src/front/spv/mod.rs
+++ b/third_party/rust/naga/src/front/spv/mod.rs
@@ -196,7 +196,7 @@ struct Decoration {
location: Option<spirv::Word>,
desc_set: Option<spirv::Word>,
desc_index: Option<spirv::Word>,
- specialization: Option<spirv::Word>,
+ specialization_constant_id: Option<spirv::Word>,
storage_buffer: bool,
offset: Option<spirv::Word>,
array_stride: Option<NonZeroU32>,
@@ -216,11 +216,6 @@ impl Decoration {
}
}
- fn specialization(&self) -> crate::Override {
- self.specialization
- .map_or(crate::Override::None, crate::Override::ByNameOrId)
- }
-
const fn resource_binding(&self) -> Option<crate::ResourceBinding> {
match *self {
Decoration {
@@ -284,8 +279,23 @@ struct LookupType {
}
#[derive(Debug)]
+enum Constant {
+ Constant(Handle<crate::Constant>),
+ Override(Handle<crate::Override>),
+}
+
+impl Constant {
+ const fn to_expr(&self) -> crate::Expression {
+ match *self {
+ Self::Constant(c) => crate::Expression::Constant(c),
+ Self::Override(o) => crate::Expression::Override(o),
+ }
+ }
+}
+
+#[derive(Debug)]
struct LookupConstant {
- handle: Handle<crate::Constant>,
+ inner: Constant,
type_id: spirv::Word,
}
@@ -537,7 +547,8 @@ struct BlockContext<'function> {
local_arena: &'function mut Arena<crate::LocalVariable>,
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
- const_expressions: &'function mut Arena<crate::Expression>,
+ overrides: &'function mut Arena<crate::Override>,
+ global_expressions: &'function mut Arena<crate::Expression>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
/// Global arena of the module being processed
@@ -757,7 +768,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
dec.matrix_major = Some(Majority::Row);
}
spirv::Decoration::SpecId => {
- dec.specialization = Some(self.next()?);
+ dec.specialization_constant_id = Some(self.next()?);
}
other => {
log::warn!("Unknown decoration {:?}", other);
@@ -1393,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
inst.expect(5)?;
let init_id = self.next()?;
let lconst = self.lookup_constant.lookup(init_id)?;
- Some(
- ctx.expressions
- .append(crate::Expression::Constant(lconst.handle), span),
- )
+ Some(ctx.expressions.append(lconst.inner.to_expr(), span))
} else {
None
};
@@ -3650,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let semantics_const = self.lookup_constant.lookup(semantics_id)?;
- let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle)
+ let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
- let semantics = resolve_constant(ctx.gctx(), semantics_const.handle)
+ let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;
if exec_scope == spirv::Scope::Workgroup as u32 {
@@ -3692,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
},
);
}
+ Op::GroupNonUniformBallot => {
+ inst.expect(5)?;
+ block.extend(emitter.finish(ctx.expressions));
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let predicate_id = self.next()?;
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let predicate = if self
+ .lookup_constant
+ .lookup(predicate_id)
+ .ok()
+ .filter(|predicate_const| match predicate_const.inner {
+ Constant::Constant(constant) => matches!(
+ ctx.gctx().global_expressions[ctx.gctx().constants[constant].init],
+ crate::Expression::Literal(crate::Literal::Bool(true)),
+ ),
+ Constant::Override(_) => false,
+ })
+ .is_some()
+ {
+ None
+ } else {
+ let predicate_lookup = self.lookup_expression.lookup(predicate_id)?;
+ let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup);
+ Some(predicate_handle)
+ };
+
+ let result_handle = ctx
+ .expressions
+ .append(crate::Expression::SubgroupBallotResult, span);
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupBallot {
+ result: result_handle,
+ predicate,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ spirv::Op::GroupNonUniformAll
+ | spirv::Op::GroupNonUniformAny
+ | spirv::Op::GroupNonUniformIAdd
+ | spirv::Op::GroupNonUniformFAdd
+ | spirv::Op::GroupNonUniformIMul
+ | spirv::Op::GroupNonUniformFMul
+ | spirv::Op::GroupNonUniformSMax
+ | spirv::Op::GroupNonUniformUMax
+ | spirv::Op::GroupNonUniformFMax
+ | spirv::Op::GroupNonUniformSMin
+ | spirv::Op::GroupNonUniformUMin
+ | spirv::Op::GroupNonUniformFMin
+ | spirv::Op::GroupNonUniformBitwiseAnd
+ | spirv::Op::GroupNonUniformBitwiseOr
+ | spirv::Op::GroupNonUniformBitwiseXor
+ | spirv::Op::GroupNonUniformLogicalAnd
+ | spirv::Op::GroupNonUniformLogicalOr
+ | spirv::Op::GroupNonUniformLogicalXor => {
+ block.extend(emitter.finish(ctx.expressions));
+ inst.expect(
+ if matches!(
+ inst.op,
+ spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny
+ ) {
+ 5
+ } else {
+ 6
+ },
+ )?;
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let collective_op_id = match inst.op {
+ spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => {
+ crate::CollectiveOperation::Reduce
+ }
+ _ => {
+ let group_op_id = self.next()?;
+ match spirv::GroupOperation::from_u32(group_op_id) {
+ Some(spirv::GroupOperation::Reduce) => {
+ crate::CollectiveOperation::Reduce
+ }
+ Some(spirv::GroupOperation::InclusiveScan) => {
+ crate::CollectiveOperation::InclusiveScan
+ }
+ Some(spirv::GroupOperation::ExclusiveScan) => {
+ crate::CollectiveOperation::ExclusiveScan
+ }
+ _ => return Err(Error::UnsupportedGroupOperation(group_op_id)),
+ }
+ }
+ };
+ let argument_id = self.next()?;
+
+ let argument_lookup = self.lookup_expression.lookup(argument_id)?;
+ let argument_handle = get_expr_handle!(argument_id, argument_lookup);
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let op_id = match inst.op {
+ spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All,
+ spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any,
+ spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => {
+ crate::SubgroupOperation::Add
+ }
+ spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => {
+ crate::SubgroupOperation::Mul
+ }
+ spirv::Op::GroupNonUniformSMax
+ | spirv::Op::GroupNonUniformUMax
+ | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max,
+ spirv::Op::GroupNonUniformSMin
+ | spirv::Op::GroupNonUniformUMin
+ | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min,
+ spirv::Op::GroupNonUniformBitwiseAnd
+ | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And,
+ spirv::Op::GroupNonUniformBitwiseOr
+ | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or,
+ spirv::Op::GroupNonUniformBitwiseXor
+ | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor,
+ _ => unreachable!(),
+ };
+
+ let result_type = self.lookup_type.lookup(result_type_id)?;
+
+ let result_handle = ctx.expressions.append(
+ crate::Expression::SubgroupOperationResult {
+ ty: result_type.handle,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupCollectiveOperation {
+ result: result_handle,
+ op: op_id,
+ collective_op: collective_op_id,
+ argument: argument_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
+ Op::GroupNonUniformBroadcastFirst
+ | Op::GroupNonUniformBroadcast
+ | Op::GroupNonUniformShuffle
+ | Op::GroupNonUniformShuffleDown
+ | Op::GroupNonUniformShuffleUp
+ | Op::GroupNonUniformShuffleXor => {
+ inst.expect(
+ if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
+ 5
+ } else {
+ 6
+ },
+ )?;
+ block.extend(emitter.finish(ctx.expressions));
+ let result_type_id = self.next()?;
+ let result_id = self.next()?;
+ let exec_scope_id = self.next()?;
+ let argument_id = self.next()?;
+
+ let argument_lookup = self.lookup_expression.lookup(argument_id)?;
+ let argument_handle = get_expr_handle!(argument_id, argument_lookup);
+
+ let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
+ let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
+ .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32)
+ .ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
+
+ let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) {
+ crate::GatherMode::BroadcastFirst
+ } else {
+ let index_id = self.next()?;
+ let index_lookup = self.lookup_expression.lookup(index_id)?;
+ let index_handle = get_expr_handle!(index_id, index_lookup);
+ match inst.op {
+ spirv::Op::GroupNonUniformBroadcast => {
+ crate::GatherMode::Broadcast(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffle => {
+ crate::GatherMode::Shuffle(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleDown => {
+ crate::GatherMode::ShuffleDown(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleUp => {
+ crate::GatherMode::ShuffleUp(index_handle)
+ }
+ spirv::Op::GroupNonUniformShuffleXor => {
+ crate::GatherMode::ShuffleXor(index_handle)
+ }
+ _ => unreachable!(),
+ }
+ };
+
+ let result_type = self.lookup_type.lookup(result_type_id)?;
+
+ let result_handle = ctx.expressions.append(
+ crate::Expression::SubgroupOperationResult {
+ ty: result_type.handle,
+ },
+ span,
+ );
+ self.lookup_expression.insert(
+ result_id,
+ LookupExpression {
+ handle: result_handle,
+ type_id: result_type_id,
+ block_id,
+ },
+ );
+
+ block.push(
+ crate::Statement::SubgroupGather {
+ result: result_handle,
+ mode,
+ argument: argument_handle,
+ },
+ span,
+ );
+ emitter.start(ctx.expressions);
+ }
_ => return Err(Error::UnsupportedInstruction(self.state, inst.op)),
}
};
@@ -3713,6 +3969,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
&mut self,
globals: &Arena<crate::GlobalVariable>,
constants: &Arena<crate::Constant>,
+ overrides: &Arena<crate::Override>,
) -> Arena<crate::Expression> {
let mut expressions = Arena::new();
#[allow(clippy::panic)]
@@ -3737,8 +3994,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
// register constants
for (&id, con) in self.lookup_constant.iter() {
- let span = constants.get_span(con.handle);
- let handle = expressions.append(crate::Expression::Constant(con.handle), span);
+ let (expr, span) = match con.inner {
+ Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)),
+ Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)),
+ };
+ let handle = expressions.append(expr, span);
self.lookup_expression.insert(
id,
LookupExpression {
@@ -3812,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
| S::Store { .. }
| S::ImageStore { .. }
| S::Atomic { .. }
- | S::RayQuery { .. } => {}
+ | S::RayQuery { .. }
+ | S::SubgroupBallot { .. }
+ | S::SubgroupCollectiveOperation { .. }
+ | S::SubgroupGather { .. } => {}
S::Call {
function: ref mut callee,
ref arguments,
@@ -3944,10 +4207,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
- Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
+ Op::ConstantComposite | Op::SpecConstantComposite => {
+ self.parse_composite_constant(inst, &mut module)
+ }
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
- Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
- Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
+ Op::ConstantTrue | Op::SpecConstantTrue => {
+ self.parse_bool_constant(inst, true, &mut module)
+ }
+ Op::ConstantFalse | Op::SpecConstantFalse => {
+ self.parse_bool_constant(inst, false, &mut module)
+ }
Op::Variable => self.parse_global_variable(inst, &mut module),
Op::Function => {
self.switch(ModuleState::Function, inst.op)?;
@@ -4504,9 +4773,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let length_id = self.next()?;
let length_const = self.lookup_constant.lookup(length_id)?;
- let size = resolve_constant(module.to_ctx(), length_const.handle)
+ let size = resolve_constant(module.to_ctx(), &length_const.inner)
.and_then(NonZeroU32::new)
- .ok_or(Error::InvalidArraySize(length_const.handle))?;
+ .ok_or(Error::InvalidArraySize(length_id))?;
let decor = self.future_decor.remove(&id).unwrap_or_default();
let base = self.lookup_type.lookup(type_id)?.handle;
@@ -4919,29 +5188,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
_ => return Err(Error::UnsupportedType(type_lookup.handle)),
};
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let span = self.span_from_with_op(start);
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Literal(literal), span);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_composite_constant(
@@ -4965,34 +5218,18 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let span = self.span_from_with_op(start);
let constant = self.lookup_constant.lookup(component_id)?;
let expr = module
- .const_expressions
- .append(crate::Expression::Constant(constant.handle), span);
+ .global_expressions
+ .append(constant.inner.to_expr(), span);
components.push(expr);
}
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let span = self.span_from_with_op(start);
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Compose { ty, components }, span);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_null_constant(
@@ -5010,23 +5247,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
let init = module
- .const_expressions
+ .global_expressions
.append(crate::Expression::ZeroValue(ty), span);
- let handle = module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- );
- self.lookup_constant
- .insert(id, LookupConstant { handle, type_id });
- Ok(())
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
}
fn parse_bool_constant(
@@ -5045,27 +5270,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;
- let decor = self.future_decor.remove(&id).unwrap_or_default();
-
- let init = module.const_expressions.append(
+ let init = module.global_expressions.append(
crate::Expression::Literal(crate::Literal::Bool(value)),
span,
);
- self.lookup_constant.insert(
- id,
- LookupConstant {
- handle: module.constants.append(
- crate::Constant {
- r#override: decor.specialization(),
- name: decor.name,
- ty,
- init,
- },
- span,
- ),
- type_id,
- },
- );
+
+ self.insert_parsed_constant(module, id, type_id, ty, init, span)
+ }
+
+ fn insert_parsed_constant(
+ &mut self,
+ module: &mut crate::Module,
+ id: u32,
+ type_id: u32,
+ ty: Handle<crate::Type>,
+ init: Handle<crate::Expression>,
+ span: crate::Span,
+ ) -> Result<(), Error> {
+ let decor = self.future_decor.remove(&id).unwrap_or_default();
+
+ let inner = if let Some(id) = decor.specialization_constant_id {
+ let o = crate::Override {
+ name: decor.name,
+ id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?),
+ ty,
+ init: Some(init),
+ };
+ Constant::Override(module.overrides.append(o, span))
+ } else {
+ let c = crate::Constant {
+ name: decor.name,
+ ty,
+ init,
+ };
+ Constant::Constant(module.constants.append(c, span))
+ };
+
+ self.lookup_constant
+ .insert(id, LookupConstant { inner, type_id });
Ok(())
}
@@ -5087,8 +5329,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let span = self.span_from_with_op(start);
let lconst = self.lookup_constant.lookup(init_id)?;
let expr = module
- .const_expressions
- .append(crate::Expression::Constant(lconst.handle), span);
+ .global_expressions
+ .append(lconst.inner.to_expr(), span);
Some(expr)
} else {
None
@@ -5209,7 +5451,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
match null::generate_default_built_in(
Some(built_in),
ty,
- &mut module.const_expressions,
+ &mut module.global_expressions,
span,
) {
Ok(handle) => Some(handle),
@@ -5231,14 +5473,14 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let handle = null::generate_default_built_in(
built_in,
member.ty,
- &mut module.const_expressions,
+ &mut module.global_expressions,
span,
)?;
components.push(handle);
}
Some(
module
- .const_expressions
+ .global_expressions
.append(crate::Expression::Compose { ty, components }, span),
)
}
@@ -5303,11 +5545,12 @@ fn make_index_literal(
Ok(expr)
}
-fn resolve_constant(
- gctx: crate::proc::GlobalCtx,
- constant: Handle<crate::Constant>,
-) -> Option<u32> {
- match gctx.const_expressions[gctx.constants[constant].init] {
+fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> {
+ let constant = match *constant {
+ Constant::Constant(constant) => constant,
+ Constant::Override(_) => return None,
+ };
+ match gctx.global_expressions[gctx.constants[constant].init] {
crate::Expression::Literal(crate::Literal::U32(id)) => Some(id),
crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32),
_ => None,