diff options
Diffstat (limited to 'third_party/rust/naga/src/back/spv/block.rs')
-rw-r--r-- | third_party/rust/naga/src/back/spv/block.rs | 2121 |
1 files changed, 2121 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/back/spv/block.rs b/third_party/rust/naga/src/back/spv/block.rs new file mode 100644 index 0000000000..10fd5d72aa --- /dev/null +++ b/third_party/rust/naga/src/back/spv/block.rs @@ -0,0 +1,2121 @@ +/*! +Implementations for `BlockContext` methods. +*/ + +use super::{ + index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension, + Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags, +}; +use crate::{arena::Handle, proc::TypeResolution}; +use spirv::Word; + +fn get_dimension(type_inner: &crate::TypeInner) -> Dimension { + match *type_inner { + crate::TypeInner::Scalar { .. } => Dimension::Scalar, + crate::TypeInner::Vector { .. } => Dimension::Vector, + crate::TypeInner::Matrix { .. } => Dimension::Matrix, + _ => unreachable!(), + } +} + +/// The results of emitting code for a left-hand-side expression. +/// +/// On success, `write_expression_pointer` returns one of these. +enum ExpressionPointer { + /// The pointer to the expression's value is available, as the value of the + /// expression with the given id. + Ready { pointer_id: Word }, + + /// The access expression must be conditional on the value of `condition`, a boolean + /// expression that is true if all indices are in bounds. If `condition` is true, then + /// `access` is an `OpAccessChain` instruction that will compute a pointer to the + /// expression's value. If `condition` is false, then executing `access` would be + /// undefined behavior. + Conditional { + condition: Word, + access: Instruction, + }, +} + +/// The termination statement to be added to the end of the block +pub enum BlockExit { + /// Generates an OpReturn (void return) + Return, + /// Generates an OpBranch to the specified block + Branch { + /// The branch target block + target: Word, + }, + /// Translates a loop `break if` into an `OpBranchConditional` to the + /// merge block if true (the merge block is passed through [`LoopContext::break_id`] + /// or else to the loop header (passed through [`preamble_id`]) + /// + /// [`preamble_id`]: Self::BreakIf::preamble_id + BreakIf { + /// The condition of the `break if` + condition: Handle<crate::Expression>, + /// The loop header block id + preamble_id: Word, + }, +} + +impl Writer { + // Flip Y coordinate to adjust for coordinate space difference + // between SPIR-V and our IR. + // The `position_id` argument is a pointer to a `vecN<f32>`, + // whose `y` component we will negate. + fn write_epilogue_position_y_flip( + &mut self, + position_id: Word, + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width: 4, + pointer_space: Some(spirv::StorageClass::Output), + })); + let index_y_id = self.get_index_constant(1); + let access_id = self.id_gen.next(); + body.push(Instruction::access_chain( + float_ptr_type_id, + access_id, + position_id, + &[index_y_id], + )); + + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width: 4, + pointer_space: None, + })); + let load_id = self.id_gen.next(); + body.push(Instruction::load(float_type_id, load_id, access_id, None)); + + let neg_id = self.id_gen.next(); + body.push(Instruction::unary( + spirv::Op::FNegate, + float_type_id, + neg_id, + load_id, + )); + + body.push(Instruction::store(access_id, neg_id, None)); + Ok(()) + } + + // Clamp fragment depth between 0 and 1. + fn write_epilogue_frag_depth_clamp( + &mut self, + frag_depth_id: Word, + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Float, + width: 4, + pointer_space: None, + })); + let value0_id = self.get_constant_scalar(crate::ScalarValue::Float(0.0), 4); + let value1_id = self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4); + + let original_id = self.id_gen.next(); + body.push(Instruction::load( + float_type_id, + original_id, + frag_depth_id, + None, + )); + + let clamp_id = self.id_gen.next(); + body.push(Instruction::ext_inst( + self.gl450_ext_inst_id, + spirv::GLOp::FClamp, + float_type_id, + clamp_id, + &[original_id, value0_id, value1_id], + )); + + body.push(Instruction::store(frag_depth_id, clamp_id, None)); + Ok(()) + } + + fn write_entry_point_return( + &mut self, + value_id: Word, + ir_result: &crate::FunctionResult, + result_members: &[ResultMember], + body: &mut Vec<Instruction>, + ) -> Result<(), Error> { + for (index, res_member) in result_members.iter().enumerate() { + let member_value_id = match ir_result.binding { + Some(_) => value_id, + None => { + let member_value_id = self.id_gen.next(); + body.push(Instruction::composite_extract( + res_member.type_id, + member_value_id, + value_id, + &[index as u32], + )); + member_value_id + } + }; + + body.push(Instruction::store(res_member.id, member_value_id, None)); + + match res_member.built_in { + Some(crate::BuiltIn::Position { .. }) + if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) => + { + self.write_epilogue_position_y_flip(res_member.id, body)?; + } + Some(crate::BuiltIn::FragDepth) + if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) => + { + self.write_epilogue_frag_depth_clamp(res_member.id, body)?; + } + _ => {} + } + } + Ok(()) + } +} + +impl<'w> BlockContext<'w> { + /// Decide whether to put off emitting instructions for `expr_handle`. + /// + /// We would like to gather together chains of `Access` and `AccessIndex` + /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do + /// this, we don't generate instructions for these exprs when we first + /// encounter them. Their ids in `self.writer.cached.ids` are left as zero. Then, + /// once we encounter a `Load` or `Store` expression that actually needs the + /// chain's value, we call `write_expression_pointer` to handle the whole + /// thing in one fell swoop. + fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool { + match self.ir_function.expressions[expr_handle] { + crate::Expression::GlobalVariable(handle) => { + let ty = self.ir_module.global_variables[handle].ty; + match self.ir_module.types[ty].inner { + crate::TypeInner::BindingArray { .. } => false, + _ => true, + } + } + crate::Expression::LocalVariable(_) => true, + crate::Expression::FunctionArgument(index) => { + let arg = &self.ir_function.arguments[index as usize]; + self.ir_module.types[arg.ty].inner.pointer_space().is_some() + } + + // The chain rule: if this `Access...`'s `base` operand was + // previously omitted, then omit this one, too. + _ => self.cached.ids[expr_handle.index()] == 0, + } + } + + /// Cache an expression for a value. + pub(super) fn cache_expression_value( + &mut self, + expr_handle: Handle<crate::Expression>, + block: &mut Block, + ) -> Result<(), Error> { + let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty); + + let id = match self.ir_function.expressions[expr_handle] { + crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => { + // See `is_intermediate`; we'll handle this later in + // `write_expression_pointer`. + 0 + } + crate::Expression::Access { base, index } => { + let base_ty_inner = self.fun_info[base].ty.inner_with(&self.ir_module.types); + match *base_ty_inner { + crate::TypeInner::Vector { .. } => { + self.write_vector_access(expr_handle, base, index, block)? + } + crate::TypeInner::BindingArray { + base: binding_type, .. + } => { + let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { + base: binding_type, + class: spirv::StorageClass::UniformConstant, + }); + + let result_id = match self.write_expression_pointer( + expr_handle, + block, + Some(binding_array_false_pointer), + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Texture array out-of-bounds handling", + )); + } + }; + + let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + + let load_id = self.gen_id(); + block.body.push(Instruction::load( + binding_type_id, + load_id, + result_id, + None, + )); + + if self.fun_info[index].uniformity.non_uniform_result.is_some() { + self.writer.require_any( + "NonUniformEXT", + &[spirv::Capability::ShaderNonUniform], + )?; + self.writer.use_extension("SPV_EXT_descriptor_indexing"); + self.writer + .decorate(load_id, spirv::Decoration::NonUniform, &[]); + } + load_id + } + ref other => { + log::error!( + "Unable to access base {:?} of type {:?}", + self.ir_function.expressions[base], + other + ); + return Err(Error::Validation( + "only vectors may be dynamically indexed by value", + )); + } + } + } + crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => { + // See `is_intermediate`; we'll handle this later in + // `write_expression_pointer`. + 0 + } + crate::Expression::AccessIndex { base, index } => { + match *self.fun_info[base].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::Array { .. } + | crate::TypeInner::Struct { .. } => { + // We never need bounds checks here: dynamically sized arrays can + // only appear behind pointers, and are thus handled by the + // `is_intermediate` case above. Everything else's size is + // statically known and checked in validation. + let id = self.gen_id(); + let base_id = self.cached[base]; + block.body.push(Instruction::composite_extract( + result_type_id, + id, + base_id, + &[index], + )); + id + } + crate::TypeInner::BindingArray { + base: binding_type, .. + } => { + let binding_array_false_pointer = LookupType::Local(LocalType::Pointer { + base: binding_type, + class: spirv::StorageClass::UniformConstant, + }); + + let result_id = match self.write_expression_pointer( + expr_handle, + block, + Some(binding_array_false_pointer), + )? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Texture array out-of-bounds handling", + )); + } + }; + + let binding_type_id = self.get_type_id(LookupType::Handle(binding_type)); + + let load_id = self.gen_id(); + block.body.push(Instruction::load( + binding_type_id, + load_id, + result_id, + None, + )); + + load_id + } + ref other => { + log::error!("Unable to access index of {:?}", other); + return Err(Error::FeatureNotImplemented("access index for type")); + } + } + } + crate::Expression::GlobalVariable(handle) => { + self.writer.global_variables[handle.index()].access_id + } + crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()], + crate::Expression::Splat { size, value } => { + let value_id = self.cached[value]; + let components = [value_id; 4]; + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + &components[..size as usize], + )); + id + } + crate::Expression::Swizzle { + size, + vector, + pattern, + } => { + let vector_id = self.cached[vector]; + self.temp_list.clear(); + for &sc in pattern[..size as usize].iter() { + self.temp_list.push(sc as Word); + } + let id = self.gen_id(); + block.body.push(Instruction::vector_shuffle( + result_type_id, + id, + vector_id, + vector_id, + &self.temp_list, + )); + id + } + crate::Expression::Compose { + ty: _, + ref components, + } => { + self.temp_list.clear(); + for &component in components { + self.temp_list.push(self.cached[component]); + } + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + &self.temp_list, + )); + id + } + crate::Expression::Unary { op, expr } => { + let id = self.gen_id(); + let expr_id = self.cached[expr]; + let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types); + + let spirv_op = match op { + crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Float) => spirv::Op::FNegate, + Some(crate::ScalarKind::Sint) => spirv::Op::SNegate, + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot, + Some(crate::ScalarKind::Uint) | None => { + log::error!("Unable to negate {:?}", expr_ty_inner); + return Err(Error::FeatureNotImplemented("negation")); + } + }, + crate::UnaryOperator::Not => match expr_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot, + _ => spirv::Op::Not, + }, + }; + + block + .body + .push(Instruction::unary(spirv_op, result_type_id, id, expr_id)); + id + } + crate::Expression::Binary { op, left, right } => { + let id = self.gen_id(); + let left_id = self.cached[left]; + let right_id = self.cached[right]; + + let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types); + let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types); + + let left_dimension = get_dimension(left_ty_inner); + let right_dimension = get_dimension(right_ty_inner); + + let mut reverse_operands = false; + + let spirv_op = match op { + crate::BinaryOperator::Add => match *left_ty_inner { + crate::TypeInner::Scalar { kind, .. } + | crate::TypeInner::Vector { kind, .. } => match kind { + crate::ScalarKind::Float => spirv::Op::FAdd, + _ => spirv::Op::IAdd, + }, + crate::TypeInner::Matrix { + columns, + rows, + width, + } => { + self.write_matrix_matrix_column_op( + block, + id, + result_type_id, + left_id, + right_id, + columns, + rows, + width, + spirv::Op::FAdd, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unimplemented!(), + }, + crate::BinaryOperator::Subtract => match *left_ty_inner { + crate::TypeInner::Scalar { kind, .. } + | crate::TypeInner::Vector { kind, .. } => match kind { + crate::ScalarKind::Float => spirv::Op::FSub, + _ => spirv::Op::ISub, + }, + crate::TypeInner::Matrix { + columns, + rows, + width, + } => { + self.write_matrix_matrix_column_op( + block, + id, + result_type_id, + left_id, + right_id, + columns, + rows, + width, + spirv::Op::FSub, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unimplemented!(), + }, + crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) { + (Dimension::Scalar, Dimension::Vector) => { + self.write_vector_scalar_mult( + block, + id, + result_type_id, + right_id, + left_id, + right_ty_inner, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + (Dimension::Vector, Dimension::Scalar) => { + self.write_vector_scalar_mult( + block, + id, + result_type_id, + left_id, + right_id, + left_ty_inner, + ); + + self.cached[expr_handle] = id; + return Ok(()); + } + (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix, + (Dimension::Matrix, Dimension::Scalar) => spirv::Op::MatrixTimesScalar, + (Dimension::Scalar, Dimension::Matrix) => { + reverse_operands = true; + spirv::Op::MatrixTimesScalar + } + (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector, + (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix, + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) + if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) => + { + spirv::Op::FMul + } + (Dimension::Vector, Dimension::Vector) + | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul, + }, + crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SDiv, + Some(crate::ScalarKind::Uint) => spirv::Op::UDiv, + Some(crate::ScalarKind::Float) => spirv::Op::FDiv, + _ => unimplemented!(), + }, + crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() { + // TODO: handle undefined behavior + // if right == 0 return 0 + // if left == min(type_of(left)) && right == -1 return 0 + Some(crate::ScalarKind::Sint) => spirv::Op::SRem, + // TODO: handle undefined behavior + // if right == 0 return 0 + Some(crate::ScalarKind::Uint) => spirv::Op::UMod, + // TODO: handle undefined behavior + // if right == 0 return ? see https://github.com/gpuweb/gpuweb/issues/2798 + Some(crate::ScalarKind::Float) => spirv::Op::FRem, + _ => unimplemented!(), + }, + crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + spirv::Op::IEqual + } + Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual, + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint | crate::ScalarKind::Uint) => { + spirv::Op::INotEqual + } + Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual, + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan, + Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan, + _ => unimplemented!(), + }, + crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual, + Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan, + Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan, + _ => unimplemented!(), + }, + crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual, + Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual, + Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual, + _ => unimplemented!(), + }, + crate::BinaryOperator::And => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd, + _ => spirv::Op::BitwiseAnd, + }, + crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor, + crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr, + _ => spirv::Op::BitwiseOr, + }, + crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd, + crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr, + crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical, + crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() { + Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic, + Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical, + _ => unimplemented!(), + }, + }; + + block.body.push(Instruction::binary( + spirv_op, + result_type_id, + id, + if reverse_operands { right_id } else { left_id }, + if reverse_operands { left_id } else { right_id }, + )); + id + } + crate::Expression::Math { + fun, + arg, + arg1, + arg2, + arg3, + } => { + use crate::MathFunction as Mf; + enum MathOp { + Ext(spirv::GLOp), + Custom(Instruction), + } + + let arg0_id = self.cached[arg]; + let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types); + let arg_scalar_kind = arg_ty.scalar_kind(); + let arg1_id = match arg1 { + Some(handle) => self.cached[handle], + None => 0, + }; + let arg2_id = match arg2 { + Some(handle) => self.cached[handle], + None => 0, + }; + let arg3_id = match arg3 { + Some(handle) => self.cached[handle], + None => 0, + }; + + let id = self.gen_id(); + let math_op = match fun { + // comparison + Mf::Abs => { + match arg_scalar_kind { + Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs), + Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs), + Some(crate::ScalarKind::Uint) => { + MathOp::Custom(Instruction::unary( + spirv::Op::CopyObject, // do nothing + result_type_id, + id, + arg0_id, + )) + } + other => unimplemented!("Unexpected abs({:?})", other), + } + } + Mf::Min => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FMin, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin, + other => unimplemented!("Unexpected min({:?})", other), + }), + Mf::Max => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FMax, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax, + other => unimplemented!("Unexpected max({:?})", other), + }), + Mf::Clamp => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp, + Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp, + other => unimplemented!("Unexpected max({:?})", other), + }), + Mf::Saturate => { + let (maybe_size, width) = match *arg_ty { + crate::TypeInner::Vector { size, width, .. } => (Some(size), width), + crate::TypeInner::Scalar { width, .. } => (None, width), + ref other => unimplemented!("Unexpected saturate({:?})", other), + }; + + let mut arg1_id = self + .writer + .get_constant_scalar(crate::ScalarValue::Float(0.0), width); + let mut arg2_id = self + .writer + .get_constant_scalar(crate::ScalarValue::Float(1.0), width); + + if let Some(size) = maybe_size { + let value = LocalType::Value { + vector_size: Some(size), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + }; + + let result_type_id = self.get_type_id(LookupType::Local(value)); + + self.temp_list.clear(); + self.temp_list.resize(size as _, arg1_id); + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + &self.temp_list, + )); + arg1_id = id; + + self.temp_list.clear(); + self.temp_list.resize(size as _, arg2_id); + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + result_type_id, + id, + &self.temp_list, + )); + arg2_id = id; + } + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FClamp, + result_type_id, + id, + &[arg0_id, arg1_id, arg2_id], + )) + } + // trigonometry + Mf::Sin => MathOp::Ext(spirv::GLOp::Sin), + Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh), + Mf::Asin => MathOp::Ext(spirv::GLOp::Asin), + Mf::Cos => MathOp::Ext(spirv::GLOp::Cos), + Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh), + Mf::Acos => MathOp::Ext(spirv::GLOp::Acos), + Mf::Tan => MathOp::Ext(spirv::GLOp::Tan), + Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh), + Mf::Atan => MathOp::Ext(spirv::GLOp::Atan), + Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2), + Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh), + Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh), + Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh), + Mf::Radians => MathOp::Ext(spirv::GLOp::Radians), + Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees), + // decomposition + Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil), + Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven), + Mf::Floor => MathOp::Ext(spirv::GLOp::Floor), + Mf::Fract => MathOp::Ext(spirv::GLOp::Fract), + Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc), + Mf::Modf => MathOp::Ext(spirv::GLOp::Modf), + Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp), + Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp), + // geometry + Mf::Dot => match *self.fun_info[arg].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Vector { + kind: crate::ScalarKind::Float, + .. + } => MathOp::Custom(Instruction::binary( + spirv::Op::Dot, + result_type_id, + id, + arg0_id, + arg1_id, + )), + // TODO: consider using integer dot product if VK_KHR_shader_integer_dot_product is available + crate::TypeInner::Vector { size, .. } => { + self.write_dot_product( + id, + result_type_id, + arg0_id, + arg1_id, + size as u32, + block, + ); + self.cached[expr_handle] = id; + return Ok(()); + } + _ => unreachable!( + "Correct TypeInner for dot product should be already validated" + ), + }, + Mf::Outer => MathOp::Custom(Instruction::binary( + spirv::Op::OuterProduct, + result_type_id, + id, + arg0_id, + arg1_id, + )), + Mf::Cross => MathOp::Ext(spirv::GLOp::Cross), + Mf::Distance => MathOp::Ext(spirv::GLOp::Distance), + Mf::Length => MathOp::Ext(spirv::GLOp::Length), + Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize), + Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward), + Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect), + Mf::Refract => MathOp::Ext(spirv::GLOp::Refract), + // exponent + Mf::Exp => MathOp::Ext(spirv::GLOp::Exp), + Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2), + Mf::Log => MathOp::Ext(spirv::GLOp::Log), + Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2), + Mf::Pow => MathOp::Ext(spirv::GLOp::Pow), + // computational + Mf::Sign => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Float) => spirv::GLOp::FSign, + Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign, + other => unimplemented!("Unexpected sign({:?})", other), + }), + Mf::Fma => MathOp::Ext(spirv::GLOp::Fma), + Mf::Mix => { + let selector = arg2.unwrap(); + let selector_ty = + self.fun_info[selector].ty.inner_with(&self.ir_module.types); + match (arg_ty, selector_ty) { + // if the selector is a scalar, we need to splat it + ( + &crate::TypeInner::Vector { size, .. }, + &crate::TypeInner::Scalar { kind, width }, + ) => { + let selector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + kind, + width, + pointer_space: None, + })); + self.temp_list.clear(); + self.temp_list.resize(size as usize, arg2_id); + + let selector_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + selector_type_id, + selector_id, + &self.temp_list, + )); + + MathOp::Custom(Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + spirv::GLOp::FMix, + result_type_id, + id, + &[arg0_id, arg1_id, selector_id], + )) + } + _ => MathOp::Ext(spirv::GLOp::FMix), + } + } + Mf::Step => MathOp::Ext(spirv::GLOp::Step), + Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep), + Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt), + Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt), + Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse), + Mf::Transpose => MathOp::Custom(Instruction::unary( + spirv::Op::Transpose, + result_type_id, + id, + arg0_id, + )), + Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant), + Mf::ReverseBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitReverse, + result_type_id, + id, + arg0_id, + )), + Mf::CountOneBits => MathOp::Custom(Instruction::unary( + spirv::Op::BitCount, + result_type_id, + id, + arg0_id, + )), + Mf::ExtractBits => { + let op = match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract, + Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract, + other => unimplemented!("Unexpected sign({:?})", other), + }; + MathOp::Custom(Instruction::ternary( + op, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + )) + } + Mf::InsertBits => MathOp::Custom(Instruction::quaternary( + spirv::Op::BitFieldInsert, + result_type_id, + id, + arg0_id, + arg1_id, + arg2_id, + arg3_id, + )), + Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb), + Mf::FindMsb => MathOp::Ext(match arg_scalar_kind { + Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb, + Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb, + other => unimplemented!("Unexpected findMSB({:?})", other), + }), + Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8), + Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8), + Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16), + Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16), + Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16), + Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8), + Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8), + Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16), + Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16), + Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16), + }; + + block.body.push(match math_op { + MathOp::Ext(op) => Instruction::ext_inst( + self.writer.gl450_ext_inst_id, + op, + result_type_id, + id, + &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()], + ), + MathOp::Custom(inst) => inst, + }); + id + } + crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id, + crate::Expression::Load { pointer } => { + match self.write_expression_pointer(pointer, block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let id = self.gen_id(); + let atomic_space = + match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Pointer { base, space } => { + match self.ir_module.types[base].inner { + crate::TypeInner::Atomic { .. } => Some(space), + _ => None, + } + } + _ => None, + }; + let instruction = if let Some(space) = atomic_space { + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + Instruction::atomic_load( + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + ) + } else { + Instruction::load(result_type_id, id, pointer_id, None) + }; + block.body.push(instruction); + id + } + ExpressionPointer::Conditional { condition, access } => { + //TODO: support atomics? + self.write_conditional_indexed_load( + result_type_id, + condition, + block, + move |id_gen, block| { + // The in-bounds path. Perform the access and the load. + let pointer_id = access.result_id.unwrap(); + let value_id = id_gen.next(); + block.body.push(access); + block.body.push(Instruction::load( + result_type_id, + value_id, + pointer_id, + None, + )); + value_id + }, + ) + } + } + } + crate::Expression::FunctionArgument(index) => self.function.parameter_id(index), + crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => { + self.cached[expr_handle] + } + crate::Expression::As { + expr, + kind, + convert, + } => { + use crate::ScalarKind as Sk; + + let expr_id = self.cached[expr]; + let (src_kind, src_size, src_width, is_matrix) = + match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) { + crate::TypeInner::Scalar { kind, width } => (kind, None, width, false), + crate::TypeInner::Vector { kind, width, size } => { + (kind, Some(size), width, false) + } + crate::TypeInner::Matrix { width, .. } => (kind, None, width, true), + ref other => { + log::error!("As source {:?}", other); + return Err(Error::Validation("Unexpected Expression::As source")); + } + }; + + enum Cast { + Identity, + Unary(spirv::Op), + Binary(spirv::Op, Word), + Ternary(spirv::Op, Word, Word), + } + + let cast = if is_matrix { + // we only support identity casts for matrices + Cast::Unary(spirv::Op::CopyObject) + } else { + match (src_kind, kind, convert) { + // Filter out identity casts. Some Adreno drivers are + // confused by no-op OpBitCast instructions. + (src_kind, kind, convert) + if src_kind == kind && convert.unwrap_or(src_width) == src_width => + { + Cast::Identity + } + (Sk::Bool, Sk::Bool, _) => Cast::Unary(spirv::Op::CopyObject), + (_, _, None) => Cast::Unary(spirv::Op::Bitcast), + // casting to a bool - generate `OpXxxNotEqual` + (_, Sk::Bool, Some(_)) => { + let (op, value) = match src_kind { + Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)), + Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)), + Sk::Float => { + (spirv::Op::FUnordNotEqual, crate::ScalarValue::Float(0.0)) + } + Sk::Bool => unreachable!(), + }; + let zero_scalar_id = self.writer.get_constant_scalar(value, src_width); + let zero_id = match src_size { + Some(size) => { + let vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + kind: src_kind, + width: src_width, + pointer_space: None, + })); + let components = [zero_scalar_id; 4]; + + let zero_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + vector_type_id, + zero_id, + &components[..size as usize], + )); + zero_id + } + None => zero_scalar_id, + }; + + Cast::Binary(op, zero_id) + } + // casting from a bool - generate `OpSelect` + (Sk::Bool, _, Some(dst_width)) => { + let (val0, val1) = match kind { + Sk::Sint => { + (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1)) + } + Sk::Uint => { + (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1)) + } + Sk::Float => ( + crate::ScalarValue::Float(0.0), + crate::ScalarValue::Float(1.0), + ), + Sk::Bool => unreachable!(), + }; + let scalar0_id = self.writer.get_constant_scalar(val0, dst_width); + let scalar1_id = self.writer.get_constant_scalar(val1, dst_width); + let (accept_id, reject_id) = match src_size { + Some(size) => { + let vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + kind, + width: dst_width, + pointer_space: None, + })); + let components0 = [scalar0_id; 4]; + let components1 = [scalar1_id; 4]; + + let vec0_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + vector_type_id, + vec0_id, + &components0[..size as usize], + )); + let vec1_id = self.gen_id(); + block.body.push(Instruction::composite_construct( + vector_type_id, + vec1_id, + &components1[..size as usize], + )); + (vec1_id, vec0_id) + } + None => (scalar1_id, scalar0_id), + }; + + Cast::Ternary(spirv::Op::Select, accept_id, reject_id) + } + (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU), + (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS), + (Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => { + Cast::Unary(spirv::Op::FConvert) + } + (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF), + (Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => { + Cast::Unary(spirv::Op::SConvert) + } + (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF), + (Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => { + Cast::Unary(spirv::Op::UConvert) + } + // We assume it's either an identity cast, or int-uint. + _ => Cast::Unary(spirv::Op::Bitcast), + } + }; + + let id = self.gen_id(); + let instruction = match cast { + Cast::Identity => None, + Cast::Unary(op) => Some(Instruction::unary(op, result_type_id, id, expr_id)), + Cast::Binary(op, operand) => Some(Instruction::binary( + op, + result_type_id, + id, + expr_id, + operand, + )), + Cast::Ternary(op, op1, op2) => Some(Instruction::ternary( + op, + result_type_id, + id, + expr_id, + op1, + op2, + )), + }; + if let Some(instruction) = instruction { + block.body.push(instruction); + id + } else { + expr_id + } + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => self.write_image_load( + result_type_id, + image, + coordinate, + array_index, + level, + sample, + block, + )?, + crate::Expression::ImageSample { + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + } => self.write_image_sample( + result_type_id, + image, + sampler, + gather, + coordinate, + array_index, + offset, + level, + depth_ref, + block, + )?, + crate::Expression::Select { + condition, + accept, + reject, + } => { + let id = self.gen_id(); + let mut condition_id = self.cached[condition]; + let accept_id = self.cached[accept]; + let reject_id = self.cached[reject]; + + let condition_ty = self.fun_info[condition] + .ty + .inner_with(&self.ir_module.types); + let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types); + + if let ( + &crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width, + }, + &crate::TypeInner::Vector { size, .. }, + ) = (condition_ty, object_ty) + { + self.temp_list.clear(); + self.temp_list.resize(size as usize, condition_id); + + let bool_vector_type_id = + self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(size), + kind: crate::ScalarKind::Bool, + width, + pointer_space: None, + })); + + let id = self.gen_id(); + block.body.push(Instruction::composite_construct( + bool_vector_type_id, + id, + &self.temp_list, + )); + condition_id = id + } + + let instruction = + Instruction::select(result_type_id, id, condition_id, accept_id, reject_id); + block.body.push(instruction); + id + } + crate::Expression::Derivative { axis, expr } => { + use crate::DerivativeAxis as Da; + + let id = self.gen_id(); + let expr_id = self.cached[expr]; + let op = match axis { + Da::X => spirv::Op::DPdx, + Da::Y => spirv::Op::DPdy, + Da::Width => spirv::Op::Fwidth, + }; + block + .body + .push(Instruction::derivative(op, result_type_id, id, expr_id)); + id + } + crate::Expression::ImageQuery { image, query } => { + self.write_image_query(result_type_id, image, query, block)? + } + crate::Expression::Relational { fun, argument } => { + use crate::RelationalFunction as Rf; + let arg_id = self.cached[argument]; + let op = match fun { + Rf::All => spirv::Op::All, + Rf::Any => spirv::Op::Any, + Rf::IsNan => spirv::Op::IsNan, + Rf::IsInf => spirv::Op::IsInf, + //TODO: these require Kernel capability + Rf::IsFinite | Rf::IsNormal => { + return Err(Error::FeatureNotImplemented("is finite/normal")) + } + }; + let id = self.gen_id(); + block + .body + .push(Instruction::relational(op, result_type_id, id, arg_id)); + id + } + crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?, + }; + + self.cached[expr_handle] = id; + Ok(()) + } + + /// Build an `OpAccessChain` instruction. + /// + /// Emit any needed bounds-checking expressions to `block`. + /// + /// Some cases we need to generate a different return type than what the IR gives us. + /// This is because pointers to binding arrays don't exist in the IR, but we need to + /// create them to create an access chain in SPIRV. + /// + /// On success, the return value is an [`ExpressionPointer`] value; see the + /// documentation for that type. + fn write_expression_pointer( + &mut self, + mut expr_handle: Handle<crate::Expression>, + block: &mut Block, + return_type_override: Option<LookupType>, + ) -> Result<ExpressionPointer, Error> { + let result_lookup_ty = match self.fun_info[expr_handle].ty { + TypeResolution::Handle(ty_handle) => match return_type_override { + // We use the return type override as a special case for binding arrays as the OpAccessChain + // needs to return a pointer, but indexing into a binding array just gives you the type of + // the binding in the IR. + Some(ty) => ty, + None => LookupType::Handle(ty_handle), + }, + TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()), + }; + let result_type_id = self.get_type_id(result_lookup_ty); + + // The id of the boolean `and` of all dynamic bounds checks up to this point. If + // `None`, then we haven't done any dynamic bounds checks yet. + // + // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not + // a short-circuit branch. This means we might do comparisons we don't need to, + // but we expect these checks to almost always succeed, and keeping branches to a + // minimum is essential. + let mut accumulated_checks = None; + + self.temp_list.clear(); + let root_id = loop { + expr_handle = match self.ir_function.expressions[expr_handle] { + crate::Expression::Access { base, index } => { + let index_id = match self.write_bounds_check(base, index, block)? { + BoundsCheckResult::KnownInBounds(known_index) => { + // Even if the index is known, `OpAccessIndex` + // requires expression operands, not literals. + let scalar = crate::ScalarValue::Uint(known_index as u64); + self.writer.get_constant_scalar(scalar, 4) + } + BoundsCheckResult::Computed(computed_index_id) => computed_index_id, + BoundsCheckResult::Conditional(comparison_id) => { + match accumulated_checks { + Some(prior_checks) => { + let combined = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + self.writer.get_bool_type_id(), + combined, + prior_checks, + comparison_id, + )); + accumulated_checks = Some(combined); + } + None => { + // Start a fresh chain of checks. + accumulated_checks = Some(comparison_id); + } + } + + // Either way, the index to use is unchanged. + self.cached[index] + } + }; + self.temp_list.push(index_id); + + base + } + crate::Expression::AccessIndex { base, index } => { + let const_id = self.get_index_constant(index); + self.temp_list.push(const_id); + base + } + crate::Expression::GlobalVariable(handle) => { + let gv = &self.writer.global_variables[handle.index()]; + break gv.access_id; + } + crate::Expression::LocalVariable(variable) => { + let local_var = &self.function.variables[&variable]; + break local_var.id; + } + crate::Expression::FunctionArgument(index) => { + break self.function.parameter_id(index); + } + ref other => unimplemented!("Unexpected pointer expression {:?}", other), + } + }; + + let pointer = if self.temp_list.is_empty() { + ExpressionPointer::Ready { + pointer_id: root_id, + } + } else { + self.temp_list.reverse(); + let pointer_id = self.gen_id(); + let access = + Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list); + + // If we generated some bounds checks, we need to leave it to our + // caller to generate the branch, the access, the load or store, and + // the zero value (for loads). Otherwise, we can emit the access + // ourselves, and just hand them the id of the pointer. + match accumulated_checks { + Some(condition) => ExpressionPointer::Conditional { condition, access }, + None => { + block.body.push(access); + ExpressionPointer::Ready { pointer_id } + } + } + }; + + Ok(pointer) + } + + /// Build the instructions for matrix - matrix column operations + #[allow(clippy::too_many_arguments)] + fn write_matrix_matrix_column_op( + &mut self, + block: &mut Block, + result_id: Word, + result_type_id: Word, + left_id: Word, + right_id: Word, + columns: crate::VectorSize, + rows: crate::VectorSize, + width: u8, + op: spirv::Op, + ) { + self.temp_list.clear(); + + let vector_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(rows), + kind: crate::ScalarKind::Float, + width, + pointer_space: None, + })); + + for index in 0..columns as u32 { + let column_id_left = self.gen_id(); + let column_id_right = self.gen_id(); + let column_id_res = self.gen_id(); + + block.body.push(Instruction::composite_extract( + vector_type_id, + column_id_left, + left_id, + &[index], + )); + block.body.push(Instruction::composite_extract( + vector_type_id, + column_id_right, + right_id, + &[index], + )); + block.body.push(Instruction::binary( + op, + vector_type_id, + column_id_res, + column_id_left, + column_id_right, + )); + + self.temp_list.push(column_id_res); + } + + block.body.push(Instruction::composite_construct( + result_type_id, + result_id, + &self.temp_list, + )); + } + + /// Build the instructions for vector - scalar multiplication + fn write_vector_scalar_mult( + &mut self, + block: &mut Block, + result_id: Word, + result_type_id: Word, + vector_id: Word, + scalar_id: Word, + vector: &crate::TypeInner, + ) { + let (size, kind) = match *vector { + crate::TypeInner::Vector { size, kind, .. } => (size, kind), + _ => unreachable!(), + }; + + let (op, operand_id) = match kind { + crate::ScalarKind::Float => (spirv::Op::VectorTimesScalar, scalar_id), + _ => { + let operand_id = self.gen_id(); + self.temp_list.clear(); + self.temp_list.resize(size as usize, scalar_id); + block.body.push(Instruction::composite_construct( + result_type_id, + operand_id, + &self.temp_list, + )); + (spirv::Op::IMul, operand_id) + } + }; + + block.body.push(Instruction::binary( + op, + result_type_id, + result_id, + vector_id, + operand_id, + )); + } + + /// Build the instructions for the arithmetic expression of a dot product + fn write_dot_product( + &mut self, + result_id: Word, + result_type_id: Word, + arg0_id: Word, + arg1_id: Word, + size: u32, + block: &mut Block, + ) { + let const_null = self.gen_id(); + block + .body + .push(Instruction::constant_null(result_type_id, const_null)); + + let mut partial_sum = const_null; + let last_component = size - 1; + for index in 0..=last_component { + // compute the product of the current components + let a_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + a_id, + arg0_id, + &[index], + )); + let b_id = self.gen_id(); + block.body.push(Instruction::composite_extract( + result_type_id, + b_id, + arg1_id, + &[index], + )); + let prod_id = self.gen_id(); + block.body.push(Instruction::binary( + spirv::Op::IMul, + result_type_id, + prod_id, + a_id, + b_id, + )); + + // choose the id for the next sum, depending on current index + let id = if index == last_component { + result_id + } else { + self.gen_id() + }; + + // sum the computed product with the partial sum + block.body.push(Instruction::binary( + spirv::Op::IAdd, + result_type_id, + id, + partial_sum, + prod_id, + )); + // set the id of the result as the previous partial sum + partial_sum = id; + } + } + + pub(super) fn write_block( + &mut self, + label_id: Word, + statements: &[crate::Statement], + exit: BlockExit, + loop_context: LoopContext, + ) -> Result<(), Error> { + let mut block = Block::new(label_id); + + for statement in statements { + match *statement { + crate::Statement::Emit(ref range) => { + for handle in range.clone() { + self.cache_expression_value(handle, &mut block)?; + } + } + crate::Statement::Block(ref block_statements) => { + let scope_id = self.gen_id(); + self.function.consume(block, Instruction::branch(scope_id)); + + let merge_id = self.gen_id(); + self.write_block( + scope_id, + block_statements, + BlockExit::Branch { target: merge_id }, + loop_context, + )?; + + block = Block::new(merge_id); + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + let condition_id = self.cached[condition]; + + let merge_id = self.gen_id(); + block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let accept_id = if accept.is_empty() { + None + } else { + Some(self.gen_id()) + }; + let reject_id = if reject.is_empty() { + None + } else { + Some(self.gen_id()) + }; + + self.function.consume( + block, + Instruction::branch_conditional( + condition_id, + accept_id.unwrap_or(merge_id), + reject_id.unwrap_or(merge_id), + ), + ); + + if let Some(block_id) = accept_id { + self.write_block( + block_id, + accept, + BlockExit::Branch { target: merge_id }, + loop_context, + )?; + } + if let Some(block_id) = reject_id { + self.write_block( + block_id, + reject, + BlockExit::Branch { target: merge_id }, + loop_context, + )?; + } + + block = Block::new(merge_id); + } + crate::Statement::Switch { + selector, + ref cases, + } => { + let selector_id = self.cached[selector]; + + let merge_id = self.gen_id(); + block.body.push(Instruction::selection_merge( + merge_id, + spirv::SelectionControl::NONE, + )); + + let default_id = self.gen_id(); + + let mut reached_default = false; + let mut raw_cases = Vec::with_capacity(cases.len()); + let mut case_ids = Vec::with_capacity(cases.len()); + for case in cases.iter() { + match case.value { + crate::SwitchValue::Integer(value) => { + let label_id = self.gen_id(); + // No cases should be added after the default case is encountered + // since the default case catches all + if !reached_default { + raw_cases.push(super::instructions::Case { + value: value as Word, + label_id, + }); + } + case_ids.push(label_id); + } + crate::SwitchValue::Default => { + case_ids.push(default_id); + reached_default = true; + } + } + } + + self.function.consume( + block, + Instruction::switch(selector_id, default_id, &raw_cases), + ); + + let inner_context = LoopContext { + break_id: Some(merge_id), + ..loop_context + }; + + for (i, (case, label_id)) in cases.iter().zip(case_ids.iter()).enumerate() { + let case_finish_id = if case.fall_through { + case_ids[i + 1] + } else { + merge_id + }; + self.write_block( + *label_id, + &case.body, + BlockExit::Branch { + target: case_finish_id, + }, + inner_context, + )?; + } + + // If no default was encountered write a empty block to satisfy the presence of + // a block the default label + if !reached_default { + self.write_block( + default_id, + &[], + BlockExit::Branch { target: merge_id }, + inner_context, + )?; + } + + block = Block::new(merge_id); + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + let preamble_id = self.gen_id(); + self.function + .consume(block, Instruction::branch(preamble_id)); + + let merge_id = self.gen_id(); + let body_id = self.gen_id(); + let continuing_id = self.gen_id(); + + // SPIR-V requires the continuing to the `OpLoopMerge`, + // so we have to start a new block with it. + block = Block::new(preamble_id); + block.body.push(Instruction::loop_merge( + merge_id, + continuing_id, + spirv::SelectionControl::NONE, + )); + self.function.consume(block, Instruction::branch(body_id)); + + self.write_block( + body_id, + body, + BlockExit::Branch { + target: continuing_id, + }, + LoopContext { + continuing_id: Some(continuing_id), + break_id: Some(merge_id), + }, + )?; + + let exit = match break_if { + Some(condition) => BlockExit::BreakIf { + condition, + preamble_id, + }, + None => BlockExit::Branch { + target: preamble_id, + }, + }; + + self.write_block( + continuing_id, + continuing, + exit, + LoopContext { + continuing_id: None, + break_id: Some(merge_id), + }, + )?; + + block = Block::new(merge_id); + } + crate::Statement::Break => { + self.function + .consume(block, Instruction::branch(loop_context.break_id.unwrap())); + return Ok(()); + } + crate::Statement::Continue => { + self.function.consume( + block, + Instruction::branch(loop_context.continuing_id.unwrap()), + ); + return Ok(()); + } + crate::Statement::Return { value: Some(value) } => { + let value_id = self.cached[value]; + let instruction = match self.function.entry_point_context { + // If this is an entry point, and we need to return anything, + // let's instead store the output variables and return `void`. + Some(ref context) => { + self.writer.write_entry_point_return( + value_id, + self.ir_function.result.as_ref().unwrap(), + &context.results, + &mut block.body, + )?; + Instruction::return_void() + } + None => Instruction::return_value(value_id), + }; + self.function.consume(block, instruction); + return Ok(()); + } + crate::Statement::Return { value: None } => { + self.function.consume(block, Instruction::return_void()); + return Ok(()); + } + crate::Statement::Kill => { + self.function.consume(block, Instruction::kill()); + return Ok(()); + } + crate::Statement::Barrier(flags) => { + let memory_scope = if flags.contains(crate::Barrier::STORAGE) { + spirv::Scope::Device + } else { + spirv::Scope::Workgroup + }; + let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE; + semantics.set( + spirv::MemorySemantics::UNIFORM_MEMORY, + flags.contains(crate::Barrier::STORAGE), + ); + semantics.set( + spirv::MemorySemantics::WORKGROUP_MEMORY, + flags.contains(crate::Barrier::WORK_GROUP), + ); + let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let mem_scope_id = self.get_index_constant(memory_scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + block.body.push(Instruction::control_barrier( + exec_scope_id, + mem_scope_id, + semantics_id, + )); + } + crate::Statement::Store { pointer, value } => { + let value_id = self.cached[value]; + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => { + let atomic_space = match *self.fun_info[pointer] + .ty + .inner_with(&self.ir_module.types) + { + crate::TypeInner::Pointer { base, space } => { + match self.ir_module.types[base].inner { + crate::TypeInner::Atomic { .. } => Some(space), + _ => None, + } + } + _ => None, + }; + let instruction = if let Some(space) = atomic_space { + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + Instruction::atomic_store( + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } else { + Instruction::store(pointer_id, value_id, None) + }; + block.body.push(instruction); + } + ExpressionPointer::Conditional { condition, access } => { + let mut selection = Selection::start(&mut block, ()); + selection.if_true(self, condition, ()); + + // The in-bounds path. Perform the access and the store. + let pointer_id = access.result_id.unwrap(); + selection.block().body.push(access); + selection + .block() + .body + .push(Instruction::store(pointer_id, value_id, None)); + + // Finish the in-bounds block and start the merge block. This + // is the block we'll leave current on return. + selection.finish(self, ()); + } + }; + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => self.write_image_store(image, coordinate, array_index, value, &mut block)?, + crate::Statement::Call { + function: local_function, + ref arguments, + result, + } => { + let id = self.gen_id(); + self.temp_list.clear(); + for &argument in arguments { + self.temp_list.push(self.cached[argument]); + } + + let type_id = match result { + Some(expr) => { + self.cached[expr] = id; + self.get_expression_type_id(&self.fun_info[expr].ty) + } + None => self.writer.void_type, + }; + + block.body.push(Instruction::function_call( + type_id, + id, + self.writer.lookup_function[&local_function], + &self.temp_list, + )); + } + crate::Statement::Atomic { + pointer, + ref fun, + value, + result, + } => { + let id = self.gen_id(); + let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty); + + self.cached[result] = id; + + let pointer_id = + match self.write_expression_pointer(pointer, &mut block, None)? { + ExpressionPointer::Ready { pointer_id } => pointer_id, + ExpressionPointer::Conditional { .. } => { + return Err(Error::FeatureNotImplemented( + "Atomics out-of-bounds handling", + )); + } + }; + + let space = self.fun_info[pointer] + .ty + .inner_with(&self.ir_module.types) + .pointer_space() + .unwrap(); + let (semantics, scope) = space.to_spirv_semantics_and_scope(); + let scope_constant_id = self.get_scope_constant(scope as u32); + let semantics_id = self.get_index_constant(semantics.bits()); + let value_id = self.cached[value]; + let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types); + + let instruction = match *fun { + crate::AtomicFunction::Add => Instruction::atomic_binary( + spirv::Op::AtomicIAdd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::Subtract => Instruction::atomic_binary( + spirv::Op::AtomicISub, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::And => Instruction::atomic_binary( + spirv::Op::AtomicAnd, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicOr, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary( + spirv::Op::AtomicXor, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ), + crate::AtomicFunction::Min => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => spirv::Op::AtomicSMin, + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + } => spirv::Op::AtomicUMin, + _ => unimplemented!(), + }; + Instruction::atomic_binary( + spirv_op, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Max => { + let spirv_op = match *value_inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint, + width: _, + } => spirv::Op::AtomicSMax, + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + width: _, + } => spirv::Op::AtomicUMax, + _ => unimplemented!(), + }; + Instruction::atomic_binary( + spirv_op, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Exchange { compare: None } => { + Instruction::atomic_binary( + spirv::Op::AtomicExchange, + result_type_id, + id, + pointer_id, + scope_constant_id, + semantics_id, + value_id, + ) + } + crate::AtomicFunction::Exchange { compare: Some(_) } => { + return Err(Error::FeatureNotImplemented("atomic CompareExchange")); + } + }; + + block.body.push(instruction); + } + } + } + + let termination = match exit { + // We're generating code for the top-level Block of the function, so we + // need to end it with some kind of return instruction. + BlockExit::Return => match self.ir_function.result { + Some(ref result) if self.function.entry_point_context.is_none() => { + let type_id = self.get_type_id(LookupType::Handle(result.ty)); + let null_id = self.writer.write_constant_null(type_id); + Instruction::return_value(null_id) + } + _ => Instruction::return_void(), + }, + BlockExit::Branch { target } => Instruction::branch(target), + BlockExit::BreakIf { + condition, + preamble_id, + } => { + let condition_id = self.cached[condition]; + + Instruction::branch_conditional( + condition_id, + loop_context.break_id.unwrap(), + preamble_id, + ) + } + }; + + self.function.consume(block, termination); + Ok(()) + } +} |