diff options
Diffstat (limited to 'third_party/rust/naga/src/front/spv/mod.rs')
-rw-r--r-- | third_party/rust/naga/src/front/spv/mod.rs | 455 |
1 files changed, 349 insertions, 106 deletions
diff --git a/third_party/rust/naga/src/front/spv/mod.rs b/third_party/rust/naga/src/front/spv/mod.rs index b793448597..7ac5a18cd6 100644 --- a/third_party/rust/naga/src/front/spv/mod.rs +++ b/third_party/rust/naga/src/front/spv/mod.rs @@ -196,7 +196,7 @@ struct Decoration { location: Option<spirv::Word>, desc_set: Option<spirv::Word>, desc_index: Option<spirv::Word>, - specialization: Option<spirv::Word>, + specialization_constant_id: Option<spirv::Word>, storage_buffer: bool, offset: Option<spirv::Word>, array_stride: Option<NonZeroU32>, @@ -216,11 +216,6 @@ impl Decoration { } } - fn specialization(&self) -> crate::Override { - self.specialization - .map_or(crate::Override::None, crate::Override::ByNameOrId) - } - const fn resource_binding(&self) -> Option<crate::ResourceBinding> { match *self { Decoration { @@ -284,8 +279,23 @@ struct LookupType { } #[derive(Debug)] +enum Constant { + Constant(Handle<crate::Constant>), + Override(Handle<crate::Override>), +} + +impl Constant { + const fn to_expr(&self) -> crate::Expression { + match *self { + Self::Constant(c) => crate::Expression::Constant(c), + Self::Override(o) => crate::Expression::Override(o), + } + } +} + +#[derive(Debug)] struct LookupConstant { - handle: Handle<crate::Constant>, + inner: Constant, type_id: spirv::Word, } @@ -537,7 +547,8 @@ struct BlockContext<'function> { local_arena: &'function mut Arena<crate::LocalVariable>, /// Constants arena of the module being processed const_arena: &'function mut Arena<crate::Constant>, - const_expressions: &'function mut Arena<crate::Expression>, + overrides: &'function mut Arena<crate::Override>, + global_expressions: &'function mut Arena<crate::Expression>, /// Type arena of the module being processed type_arena: &'function UniqueArena<crate::Type>, /// Global arena of the module being processed @@ -757,7 +768,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { dec.matrix_major = Some(Majority::Row); } spirv::Decoration::SpecId => { - dec.specialization = Some(self.next()?); + dec.specialization_constant_id = Some(self.next()?); } other => { log::warn!("Unknown decoration {:?}", other); @@ -1393,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { inst.expect(5)?; let init_id = self.next()?; let lconst = self.lookup_constant.lookup(init_id)?; - Some( - ctx.expressions - .append(crate::Expression::Constant(lconst.handle), span), - ) + Some(ctx.expressions.append(lconst.inner.to_expr(), span)) } else { None }; @@ -3650,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let semantics_const = self.lookup_constant.lookup(semantics_id)?; - let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; - let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) + let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; if exec_scope == spirv::Scope::Workgroup as u32 { @@ -3692,6 +3700,254 @@ impl<I: Iterator<Item = u32>> Frontend<I> { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(5)?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| match predicate_const.inner { + Constant::Constant(constant) => matches!( + ctx.gctx().global_expressions[ctx.gctx().constants[constant].init], + crate::Expression::Literal(crate::Literal::Bool(true)), + ), + Constant::Override(_) => false, + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + let result_handle = ctx + .expressions + .append(crate::Expression::SubgroupBallotResult, span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + emitter.start(ctx.expressions); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + block.extend(emitter.finish(ctx.expressions)); + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 5 + } else { + 6 + }, + )?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 5 + } else { + 6 + }, + )?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3713,6 +3969,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { &mut self, globals: &Arena<crate::GlobalVariable>, constants: &Arena<crate::Constant>, + overrides: &Arena<crate::Override>, ) -> Arena<crate::Expression> { let mut expressions = Arena::new(); #[allow(clippy::panic)] @@ -3737,8 +3994,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> { } // register constants for (&id, con) in self.lookup_constant.iter() { - let span = constants.get_span(con.handle); - let handle = expressions.append(crate::Expression::Constant(con.handle), span); + let (expr, span) = match con.inner { + Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)), + Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)), + }; + let handle = expressions.append(expr, span); self.lookup_expression.insert( id, LookupExpression { @@ -3812,7 +4072,10 @@ impl<I: Iterator<Item = u32>> Frontend<I> { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, @@ -3944,10 +4207,16 @@ impl<I: Iterator<Item = u32>> Frontend<I> { Op::TypeSampledImage => self.parse_type_sampled_image(inst), Op::TypeSampler => self.parse_type_sampler(inst, &mut module), Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), - Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), + Op::ConstantComposite | Op::SpecConstantComposite => { + self.parse_composite_constant(inst, &mut module) + } Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module), - Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module), - Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module), + Op::ConstantTrue | Op::SpecConstantTrue => { + self.parse_bool_constant(inst, true, &mut module) + } + Op::ConstantFalse | Op::SpecConstantFalse => { + self.parse_bool_constant(inst, false, &mut module) + } Op::Variable => self.parse_global_variable(inst, &mut module), Op::Function => { self.switch(ModuleState::Function, inst.op)?; @@ -4504,9 +4773,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let length_id = self.next()?; let length_const = self.lookup_constant.lookup(length_id)?; - let size = resolve_constant(module.to_ctx(), length_const.handle) + let size = resolve_constant(module.to_ctx(), &length_const.inner) .and_then(NonZeroU32::new) - .ok_or(Error::InvalidArraySize(length_const.handle))?; + .ok_or(Error::InvalidArraySize(length_id))?; let decor = self.future_decor.remove(&id).unwrap_or_default(); let base = self.lookup_type.lookup(type_id)?.handle; @@ -4919,29 +5188,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> { _ => return Err(Error::UnsupportedType(type_lookup.handle)), }; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Literal(literal), span); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_composite_constant( @@ -4965,34 +5218,18 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let span = self.span_from_with_op(start); let constant = self.lookup_constant.lookup(component_id)?; let expr = module - .const_expressions - .append(crate::Expression::Constant(constant.handle), span); + .global_expressions + .append(constant.inner.to_expr(), span); components.push(expr); } - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let span = self.span_from_with_op(start); let init = module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_null_constant( @@ -5010,23 +5247,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - let init = module - .const_expressions + .global_expressions .append(crate::Expression::ZeroValue(ty), span); - let handle = module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ); - self.lookup_constant - .insert(id, LookupConstant { handle, type_id }); - Ok(()) + + self.insert_parsed_constant(module, id, type_id, ty, init, span) } fn parse_bool_constant( @@ -5045,27 +5270,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let type_lookup = self.lookup_type.lookup(type_id)?; let ty = type_lookup.handle; - let decor = self.future_decor.remove(&id).unwrap_or_default(); - - let init = module.const_expressions.append( + let init = module.global_expressions.append( crate::Expression::Literal(crate::Literal::Bool(value)), span, ); - self.lookup_constant.insert( - id, - LookupConstant { - handle: module.constants.append( - crate::Constant { - r#override: decor.specialization(), - name: decor.name, - ty, - init, - }, - span, - ), - type_id, - }, - ); + + self.insert_parsed_constant(module, id, type_id, ty, init, span) + } + + fn insert_parsed_constant( + &mut self, + module: &mut crate::Module, + id: u32, + type_id: u32, + ty: Handle<crate::Type>, + init: Handle<crate::Expression>, + span: crate::Span, + ) -> Result<(), Error> { + let decor = self.future_decor.remove(&id).unwrap_or_default(); + + let inner = if let Some(id) = decor.specialization_constant_id { + let o = crate::Override { + name: decor.name, + id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?), + ty, + init: Some(init), + }; + Constant::Override(module.overrides.append(o, span)) + } else { + let c = crate::Constant { + name: decor.name, + ty, + init, + }; + Constant::Constant(module.constants.append(c, span)) + }; + + self.lookup_constant + .insert(id, LookupConstant { inner, type_id }); Ok(()) } @@ -5087,8 +5329,8 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let span = self.span_from_with_op(start); let lconst = self.lookup_constant.lookup(init_id)?; let expr = module - .const_expressions - .append(crate::Expression::Constant(lconst.handle), span); + .global_expressions + .append(lconst.inner.to_expr(), span); Some(expr) } else { None @@ -5209,7 +5451,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> { match null::generate_default_built_in( Some(built_in), ty, - &mut module.const_expressions, + &mut module.global_expressions, span, ) { Ok(handle) => Some(handle), @@ -5231,14 +5473,14 @@ impl<I: Iterator<Item = u32>> Frontend<I> { let handle = null::generate_default_built_in( built_in, member.ty, - &mut module.const_expressions, + &mut module.global_expressions, span, )?; components.push(handle); } Some( module - .const_expressions + .global_expressions .append(crate::Expression::Compose { ty, components }, span), ) } @@ -5303,11 +5545,12 @@ fn make_index_literal( Ok(expr) } -fn resolve_constant( - gctx: crate::proc::GlobalCtx, - constant: Handle<crate::Constant>, -) -> Option<u32> { - match gctx.const_expressions[gctx.constants[constant].init] { +fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> { + let constant = match *constant { + Constant::Constant(constant) => constant, + Constant::Override(_) => return None, + }; + match gctx.global_expressions[gctx.constants[constant].init] { crate::Expression::Literal(crate::Literal::U32(id)) => Some(id), crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32), _ => None, |