Diffstat (limited to 'third_party/rust/naga/src/back/spv/subgroup.rs')
-rw-r--r--  third_party/rust/naga/src/back/spv/subgroup.rs  207
1 file changed, 207 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/spv/subgroup.rs b/third_party/rust/naga/src/back/spv/subgroup.rs
new file mode 100644
index 0000000000..c952cb11a7
--- /dev/null
+++ b/third_party/rust/naga/src/back/spv/subgroup.rs
@@ -0,0 +1,207 @@
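+//! Writing SPIR-V `OpGroupNonUniform*` instructions for subgroup
+//! operations: ballot, vote, arithmetic reductions and scans, and gathers
+//! (broadcasts and shuffles).
+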
+use super::{Block, BlockContext, Error, Instruction};
+use crate::{
+    arena::Handle,
+    back::spv::{LocalType, LookupType},
+    TypeInner,
+};
+
+impl<'w> BlockContext<'w> {
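+    /// Emits an `OpGroupNonUniformBallot`, producing a `vec4<u32>` bitmask of
+    /// the invocations for which `predicate` holds; a missing predicate is
+    /// treated as `true` for every invocation.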
+    pub(super) fn write_subgroup_ballot(
+        &mut self,
+        predicate: &Option<Handle<crate::Expression>>,
+        result: Handle<crate::Expression>,
+        block: &mut Block,
+    ) -> Result<(), Error> {
+        self.writer.require_any(
+            "GroupNonUniformBallot",
+            &[spirv::Capability::GroupNonUniformBallot],
+        )?;
+        let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
+            vector_size: Some(crate::VectorSize::Quad),
+            scalar: crate::Scalar::U32,
+            pointer_space: None,
+        }));
+        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
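+        // A ballot without an explicit predicate counts every active invocation.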
+        let predicate = if let Some(predicate) = *predicate {
+            self.cached[predicate]
+        } else {
+            self.writer.get_constant_scalar(crate::Literal::Bool(true))
+        };
+        let id = self.gen_id();
+        block.body.push(Instruction::group_non_uniform_ballot(
+            vec4_u32_type_id,
+            id,
+            exec_scope_id,
+            predicate,
+        ));
+        self.cached[result] = id;
+        Ok(())
+    }
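+    /// Lowers a subgroup vote (`all`/`any`) or a collective arithmetic
+    /// operation (a reduction or scan) to the matching `OpGroupNonUniform*`
+    /// instruction, requiring the vote or arithmetic capability as needed.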
+    pub(super) fn write_subgroup_operation(
+        &mut self,
+        op: &crate::SubgroupOperation,
+        collective_op: &crate::CollectiveOperation,
+        argument: Handle<crate::Expression>,
+        result: Handle<crate::Expression>,
+        block: &mut Block,
+    ) -> Result<(), Error> {
+        use crate::SubgroupOperation as sg;
+        match *op {
+            sg::All | sg::Any => {
+                self.writer.require_any(
+                    "GroupNonUniformVote",
+                    &[spirv::Capability::GroupNonUniformVote],
+                )?;
+            }
+            _ => {
+                self.writer.require_any(
+                    "GroupNonUniformArithmetic",
+                    &[spirv::Capability::GroupNonUniformArithmetic],
+                )?;
+            }
+        }
+
+        let id = self.gen_id();
+        let result_ty = &self.fun_info[result].ty;
+        let result_type_id = self.get_expression_type_id(result_ty);
+        let result_ty_inner = result_ty.inner_with(&self.ir_module.types);
+
+        let (is_scalar, scalar) = match *result_ty_inner {
+            TypeInner::Scalar(scalar) => (true, scalar),
+            TypeInner::Vector { scalar, .. } => (false, scalar),
+            _ => unimplemented!(),
+        };
+
+        use crate::ScalarKind as sk;
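+        // Pick the opcode by scalar kind: booleans use the vote and logical
+        // forms, integers the I*/bitwise forms, and floats the F* forms.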
+        let spirv_op = match (scalar.kind, *op) {
+            (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll,
+            (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny,
+            (_, sg::All | sg::Any) => unimplemented!(),
+
+            (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd,
+            (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd,
+            (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul,
+            (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul,
+            (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax,
+            (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax,
+            (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax,
+            (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin,
+            (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin,
+            (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin,
+            (_, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(),
+
+            (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd,
+            (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr,
+            (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor,
+            (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd,
+            (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr,
+            (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor,
+            (_, sg::And | sg::Or | sg::Xor) => unimplemented!(),
+        };
+
+        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+
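+        // Vote operations carry no group operation; arithmetic ones encode
+        // whether this is a reduction or an inclusive/exclusive scan.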
+        use crate::CollectiveOperation as c;
+        let group_op = match *op {
+            sg::All | sg::Any => None,
+            _ => Some(match *collective_op {
+                c::Reduce => spirv::GroupOperation::Reduce,
+                c::InclusiveScan => spirv::GroupOperation::InclusiveScan,
+                c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan,
+            }),
+        };
+
+        let arg_id = self.cached[argument];
+        block.body.push(Instruction::group_non_uniform_arithmetic(
+            spirv_op,
+            result_type_id,
+            id,
+            exec_scope_id,
+            group_op,
+            arg_id,
+        ));
+        self.cached[result] = id;
+        Ok(())
+    }
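+    /// Lowers a subgroup gather (broadcast or shuffle) to the matching
+    /// `OpGroupNonUniform*` instruction, requiring only the capability that
+    /// the particular gather mode needs.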
+    pub(super) fn write_subgroup_gather(
+        &mut self,
+        mode: &crate::GatherMode,
+        argument: Handle<crate::Expression>,
+        result: Handle<crate::Expression>,
+        block: &mut Block,
+    ) -> Result<(), Error> {
+        match *mode {
+            crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => {
+                self.writer.require_any(
+                    "GroupNonUniformBallot",
+                    &[spirv::Capability::GroupNonUniformBallot],
+                )?;
+            }
+            crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => {
+                self.writer.require_any(
+                    "GroupNonUniformShuffle",
+                    &[spirv::Capability::GroupNonUniformShuffle],
+                )?;
+            }
+            crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => {
+                self.writer.require_any(
+                    "GroupNonUniformShuffleRelative",
+                    &[spirv::Capability::GroupNonUniformShuffleRelative],
+                )?;
+            }
+        }
+
+        let id = self.gen_id();
+        let result_ty = &self.fun_info[result].ty;
+        let result_type_id = self.get_expression_type_id(result_ty);
+
+        let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
+
+        let arg_id = self.cached[argument];
+        match *mode {
+            crate::GatherMode::BroadcastFirst => {
+                block
+                    .body
+                    .push(Instruction::group_non_uniform_broadcast_first(
+                        result_type_id,
+                        id,
+                        exec_scope_id,
+                        arg_id,
+                    ));
+            }
+            crate::GatherMode::Broadcast(index)
+            | crate::GatherMode::Shuffle(index)
+            | crate::GatherMode::ShuffleDown(index)
+            | crate::GatherMode::ShuffleUp(index)
+            | crate::GatherMode::ShuffleXor(index) => {
+                let index_id = self.cached[index];
+                let op = match *mode {
+                    crate::GatherMode::BroadcastFirst => unreachable!(),
+                    // Emit broadcast as a shuffle so that the index may be
+                    // dynamically uniform on Vulkan 1.1: before SPIR-V 1.5
+                    // (Vulkan 1.2), the index operand of
+                    // OpGroupNonUniformBroadcast must be a constant.
+                    crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle,
+                    crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle,
+                    crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
+                    crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
+                    crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
+                };
+                block.body.push(Instruction::group_non_uniform_gather(
+                    op,
+                    result_type_id,
+                    id,
+                    exec_scope_id,
+                    arg_id,
+                    index_id,
+                ));
+            }
+        }
+        self.cached[result] = id;
+        Ok(())
+    }
+}