diff options
Diffstat (limited to 'third_party/rust/naga/src/back/spv/writer.rs')
-rw-r--r-- | third_party/rust/naga/src/back/spv/writer.rs | 68 |
1 files changed, 56 insertions, 12 deletions
diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs index a5065e0623..73a16c273e 100644 --- a/third_party/rust/naga/src/back/spv/writer.rs +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -615,7 +615,7 @@ impl Writer { // Steal the Writer's temp list for a bit. temp_list: std::mem::take(&mut self.temp_list), writer: self, - expression_constness: crate::proc::ExpressionConstnessTracker::from_arena( + expression_constness: super::ExpressionConstnessTracker::from_arena( &ir_function.expressions, ), }; @@ -970,6 +970,11 @@ impl Writer { handle: Handle<crate::Type>, ) -> Result<Word, Error> { let ty = &arena[handle]; + // If it's a type that needs SPIR-V capabilities, request them now. + // This needs to happen regardless of the LocalType lookup succeeding, + // because some types which map to the same LocalType have different + // capability requirements. See https://github.com/gfx-rs/wgpu/issues/5569 + self.request_type_capabilities(&ty.inner)?; let id = if let Some(local) = make_local(&ty.inner) { // This type can be represented as a `LocalType`, so check if we've // already written an instruction for it. If not, do so now, with @@ -985,10 +990,6 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's a type that needs SPIR-V capabilities, request them now, - // so write_type_declaration_local can stay infallible. - self.request_type_capabilities(&ty.inner)?; - id } } @@ -1150,7 +1151,7 @@ impl Writer { } pub(super) fn get_constant_scalar(&mut self, value: crate::Literal) -> Word { - let scalar = CachedConstant::Literal(value); + let scalar = CachedConstant::Literal(value.into()); if let Some(&id) = self.cached_constants.get(&scalar) { return id; } @@ -1258,7 +1259,7 @@ impl Writer { ir_module: &crate::Module, mod_info: &ModuleInfo, ) -> Result<Word, Error> { - let id = match ir_module.const_expressions[handle] { + let id = match ir_module.global_expressions[handle] { crate::Expression::Literal(literal) => self.get_constant_scalar(literal), crate::Expression::Constant(constant) => { let constant = &ir_module.constants[constant]; @@ -1272,7 +1273,7 @@ impl Writer { let component_ids: Vec<_> = crate::proc::flatten_compose( ty, components, - &ir_module.const_expressions, + &ir_module.global_expressions, &ir_module.types, ) .map(|component| self.constant_ids[component.index()]) @@ -1310,7 +1311,11 @@ impl Writer { spirv::MemorySemantics::WORKGROUP_MEMORY, flags.contains(crate::Barrier::WORK_GROUP), ); - let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) { + self.get_index_constant(spirv::Scope::Subgroup as u32) + } else { + self.get_index_constant(spirv::Scope::Workgroup as u32) + }; let mem_scope_id = self.get_index_constant(memory_scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); block.body.push(Instruction::control_barrier( @@ -1585,6 +1590,41 @@ impl Writer { Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + // Subgroup + Bi::NumSubgroups => { + self.require_any( + "`num_subgroups` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId + } + Bi::SubgroupSize => { + self.require_any( + "`subgroup_size` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupSize + } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); @@ -1899,7 +1939,7 @@ impl Writer { source_code: debug_info.source_code, source_file_id, }); - self.debugs.push(Instruction::source( + self.debugs.append(&mut Instruction::source_auto_continued( spirv::SourceLanguage::Unknown, 0, &debug_info_inner, @@ -1914,8 +1954,8 @@ impl Writer { // write all const-expressions as constants self.constant_ids - .resize(ir_module.const_expressions.len(), 0); - for (handle, _) in ir_module.const_expressions.iter() { + .resize(ir_module.global_expressions.len(), 0); + for (handle, _) in ir_module.global_expressions.iter() { self.write_constant_expr(handle, ir_module, mod_info)?; } debug_assert!(self.constant_ids.iter().all(|&id| id != 0)); @@ -2029,6 +2069,10 @@ impl Writer { debug_info: &Option<DebugInfo>, words: &mut Vec<Word>, ) -> Result<(), Error> { + if !ir_module.overrides.is_empty() { + return Err(Error::Override); + } + self.reset(); // Try to find the entry point and corresponding index |