diff options
Diffstat (limited to 'third_party/rust/naga/src')
-rw-r--r-- | third_party/rust/naga/src/back/mod.rs | 58 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/spv/index.rs | 116 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/spv/mod.rs | 9 | ||||
-rw-r--r-- | third_party/rust/naga/src/back/spv/writer.rs | 70 | ||||
-rw-r--r-- | third_party/rust/naga/src/front/glsl/variables.rs | 2 | ||||
-rw-r--r-- | third_party/rust/naga/src/span.rs | 9 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/analyzer.rs | 69 | ||||
-rw-r--r-- | third_party/rust/naga/src/valid/type.rs | 2 |
8 files changed, 245 insertions, 90 deletions
diff --git a/third_party/rust/naga/src/back/mod.rs b/third_party/rust/naga/src/back/mod.rs index 8100b930e9..c8f091decb 100644 --- a/third_party/rust/naga/src/back/mod.rs +++ b/third_party/rust/naga/src/back/mod.rs @@ -16,14 +16,19 @@ pub mod spv; #[cfg(feature = "wgsl-out")] pub mod wgsl; -const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; -const INDENT: &str = " "; -const BAKE_PREFIX: &str = "_e"; +/// Names of vector components. +pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; +/// Indent for backends. +pub const INDENT: &str = " "; +/// Prefix used for baking. +pub const BAKE_PREFIX: &str = "_e"; -type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>; +/// Expressions that need baking. +pub type NeedBakeExpressions = crate::FastHashSet<crate::Handle<crate::Expression>>; +/// Indentation level. #[derive(Clone, Copy)] -struct Level(usize); +pub struct Level(pub usize); impl Level { const fn next(&self) -> Self { @@ -52,7 +57,7 @@ impl std::fmt::Display for Level { /// [`EntryPoint`]: crate::EntryPoint /// [`Module`]: crate::Module /// [`Module::entry_points`]: crate::Module::entry_points -enum FunctionType { +pub enum FunctionType { /// A regular function. Function(crate::Handle<crate::Function>), /// An [`EntryPoint`], and its index in [`Module::entry_points`]. @@ -63,7 +68,8 @@ enum FunctionType { } impl FunctionType { - fn is_compute_entry_point(&self, module: &crate::Module) -> bool { + /// Returns true if the function is an entry point for a compute shader. + pub fn is_compute_entry_point(&self, module: &crate::Module) -> bool { match *self { FunctionType::EntryPoint(index) => { module.entry_points[index as usize].stage == crate::ShaderStage::Compute @@ -74,19 +80,20 @@ impl FunctionType { } /// Helper structure that stores data needed when writing the function -struct FunctionCtx<'a> { +pub struct FunctionCtx<'a> { /// The current function being written - ty: FunctionType, + pub ty: FunctionType, /// Analysis about the function - info: &'a crate::valid::FunctionInfo, + pub info: &'a crate::valid::FunctionInfo, /// The expression arena of the current function being written - expressions: &'a crate::Arena<crate::Expression>, + pub expressions: &'a crate::Arena<crate::Expression>, /// Map of expressions that have associated variable names - named_expressions: &'a crate::NamedExpressions, + pub named_expressions: &'a crate::NamedExpressions, } impl FunctionCtx<'_> { - fn resolve_type<'a>( + /// Helper method that resolves a type of a given expression. + pub fn resolve_type<'a>( &'a self, handle: crate::Handle<crate::Expression>, types: &'a crate::UniqueArena<crate::Type>, @@ -95,7 +102,10 @@ impl FunctionCtx<'_> { } /// Helper method that generates a [`NameKey`](crate::proc::NameKey) for a local in the current function - const fn name_key(&self, local: crate::Handle<crate::LocalVariable>) -> crate::proc::NameKey { + pub const fn name_key( + &self, + local: crate::Handle<crate::LocalVariable>, + ) -> crate::proc::NameKey { match self.ty { FunctionType::Function(handle) => crate::proc::NameKey::FunctionLocal(handle, local), FunctionType::EntryPoint(idx) => crate::proc::NameKey::EntryPointLocal(idx, local), @@ -106,7 +116,7 @@ impl FunctionCtx<'_> { /// /// # Panics /// - If the function arguments are less or equal to `arg` - const fn argument_key(&self, arg: u32) -> crate::proc::NameKey { + pub const fn argument_key(&self, arg: u32) -> crate::proc::NameKey { match self.ty { FunctionType::Function(handle) => crate::proc::NameKey::FunctionArgument(handle, arg), FunctionType::EntryPoint(ep_index) => { @@ -115,8 +125,8 @@ impl FunctionCtx<'_> { } } - // Returns true if the given expression points to a fixed-function pipeline input. - fn is_fixed_function_input( + /// Returns true if the given expression points to a fixed-function pipeline input. + pub fn is_fixed_function_input( &self, mut expression: crate::Handle<crate::Expression>, module: &crate::Module, @@ -162,7 +172,7 @@ impl crate::Expression { /// See the [module-level documentation][emit] for details. /// /// [emit]: index.html#expression-evaluation-time - const fn bake_ref_count(&self) -> usize { + pub const fn bake_ref_count(&self) -> usize { match *self { // accesses are never cached, only loads are crate::Expression::Access { .. } | crate::Expression::AccessIndex { .. } => usize::MAX, @@ -181,9 +191,7 @@ impl crate::Expression { } /// Helper function that returns the string corresponding to the [`BinaryOperator`](crate::BinaryOperator) -/// # Notes -/// Used by `glsl-out`, `msl-out`, `wgsl-out`, `hlsl-out`. -const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { +pub const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { use crate::BinaryOperator as Bo; match op { Bo::Add => "+", @@ -208,8 +216,6 @@ const fn binary_operation_str(op: crate::BinaryOperator) -> &'static str { } /// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize) -/// # Notes -/// Used by `msl-out`, `wgsl-out`, `hlsl-out`. const fn vector_size_str(size: crate::VectorSize) -> &'static str { match size { crate::VectorSize::Bi => "2", @@ -219,7 +225,8 @@ const fn vector_size_str(size: crate::VectorSize) -> &'static str { } impl crate::TypeInner { - const fn is_handle(&self) -> bool { + /// Returns true if this is a handle to a type rather than the type directly. + pub const fn is_handle(&self) -> bool { match *self { crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } => true, _ => false, @@ -266,8 +273,9 @@ bitflags::bitflags! { } } +/// The intersection test to use for ray queries. #[repr(u32)] -enum RayIntersectionType { +pub enum RayIntersectionType { Triangle = 1, BoundingBox = 4, } diff --git a/third_party/rust/naga/src/back/spv/index.rs b/third_party/rust/naga/src/back/spv/index.rs index 92e0f88d9a..0effb568be 100644 --- a/third_party/rust/naga/src/back/spv/index.rs +++ b/third_party/rust/naga/src/back/spv/index.rs @@ -3,8 +3,9 @@ Bounds-checking for SPIR-V output. */ use super::{ - helpers::global_needs_wrapper, selection::Selection, Block, BlockContext, Error, IdGenerator, - Instruction, Word, + helpers::{global_needs_wrapper, map_storage_class}, + selection::Selection, + Block, BlockContext, Error, IdGenerator, Instruction, Word, }; use crate::{arena::Handle, proc::BoundsCheckPolicy}; @@ -42,32 +43,113 @@ impl<'w> BlockContext<'w> { array: Handle<crate::Expression>, block: &mut Block, ) -> Result<Word, Error> { - // Naga IR permits runtime-sized arrays as global variables or as the - // final member of a struct that is a global variable. SPIR-V permits - // only the latter, so this back end wraps bare runtime-sized arrays - // in a made-up struct; see `helpers::global_needs_wrapper` and its uses. - // This code must handle both cases. - let (structure_id, last_member_index) = match self.ir_function.expressions[array] { + // Naga IR permits runtime-sized arrays as global variables, or as the + // final member of a struct that is a global variable, or one of these + // inside a buffer that is itself an element in a buffer bindings array. + // SPIR-V requires that runtime-sized arrays are wrapped in structs. + // See `helpers::global_needs_wrapper` and its uses. + let (opt_array_index_id, global_handle, opt_last_member_index) = match self + .ir_function + .expressions[array] + { crate::Expression::AccessIndex { base, index } => { match self.ir_function.expressions[base] { - crate::Expression::GlobalVariable(handle) => ( - self.writer.global_variables[handle.index()].access_id, - index, - ), - _ => return Err(Error::Validation("array length expression")), + // The global variable is an array of buffer bindings of structs, + // we are accessing one of them with a static index, + // and the last member of it. + crate::Expression::AccessIndex { + base: base_outer, + index: index_outer, + } => match self.ir_function.expressions[base_outer] { + crate::Expression::GlobalVariable(handle) => { + let index_id = self.get_index_constant(index_outer); + (Some(index_id), handle, Some(index)) + } + _ => return Err(Error::Validation("array length expression case-1a")), + }, + // The global variable is an array of buffer bindings of structs, + // we are accessing one of them with a dynamic index, + // and the last member of it. + crate::Expression::Access { + base: base_outer, + index: index_outer, + } => match self.ir_function.expressions[base_outer] { + crate::Expression::GlobalVariable(handle) => { + let index_id = self.cached[index_outer]; + (Some(index_id), handle, Some(index)) + } + _ => return Err(Error::Validation("array length expression case-1b")), + }, + // The global variable is a buffer, and we are accessing the last member. + crate::Expression::GlobalVariable(handle) => { + let global = &self.ir_module.global_variables[handle]; + match self.ir_module.types[global.ty].inner { + // The global variable is an array of buffer bindings of run-time arrays. + crate::TypeInner::BindingArray { .. } => (Some(index), handle, None), + // The global variable is a struct, and we are accessing the last member + _ => (None, handle, Some(index)), + } + } + _ => return Err(Error::Validation("array length expression case-1c")), } } + // The global variable is an array of buffer bindings of arrays. + crate::Expression::Access { base, index } => match self.ir_function.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + let index_id = self.cached[index]; + let global = &self.ir_module.global_variables[handle]; + match self.ir_module.types[global.ty].inner { + crate::TypeInner::BindingArray { .. } => (Some(index_id), handle, None), + _ => return Err(Error::Validation("array length expression case-2a")), + } + } + _ => return Err(Error::Validation("array length expression case-2b")), + }, + // The global variable is a run-time array. crate::Expression::GlobalVariable(handle) => { let global = &self.ir_module.global_variables[handle]; if !global_needs_wrapper(self.ir_module, global) { - return Err(Error::Validation("array length expression")); + return Err(Error::Validation("array length expression case-3")); } - - (self.writer.global_variables[handle.index()].var_id, 0) + (None, handle, None) } - _ => return Err(Error::Validation("array length expression")), + _ => return Err(Error::Validation("array length expression case-4")), }; + let gvar = self.writer.global_variables[global_handle.index()].clone(); + let global = &self.ir_module.global_variables[global_handle]; + let (last_member_index, gvar_id) = match opt_last_member_index { + Some(index) => (index, gvar.access_id), + None => { + if !global_needs_wrapper(self.ir_module, global) { + return Err(Error::Validation( + "pointer to a global that is not a wrapped array", + )); + } + (0, gvar.var_id) + } + }; + let structure_id = match opt_array_index_id { + // We are indexing inside a binding array, generate the access op. + Some(index_id) => { + let element_type_id = match self.ir_module.types[global.ty].inner { + crate::TypeInner::BindingArray { base, size: _ } => { + let class = map_storage_class(global.space); + self.get_pointer_id(base, class)? + } + _ => return Err(Error::Validation("array length expression case-5")), + }; + let structure_id = self.gen_id(); + block.body.push(Instruction::access_chain( + element_type_id, + structure_id, + gvar_id, + &[index_id], + )); + structure_id + } + None => gvar_id, + }; let length_id = self.gen_id(); block.body.push(Instruction::array_length( self.writer.get_uint_type_id(), diff --git a/third_party/rust/naga/src/back/spv/mod.rs b/third_party/rust/naga/src/back/spv/mod.rs index b7d57be0d4..eb29e3cd8b 100644 --- a/third_party/rust/naga/src/back/spv/mod.rs +++ b/third_party/rust/naga/src/back/spv/mod.rs @@ -576,6 +576,15 @@ impl BlockContext<'_> { self.writer .get_constant_scalar(crate::Literal::I32(scope as _)) } + + fn get_pointer_id( + &mut self, + handle: Handle<crate::Type>, + class: spirv::StorageClass, + ) -> Result<Word, Error> { + self.writer + .get_pointer_id(&self.ir_module.types, handle, class) + } } #[derive(Clone, Copy, Default)] diff --git a/third_party/rust/naga/src/back/spv/writer.rs b/third_party/rust/naga/src/back/spv/writer.rs index de3220bbda..a5065e0623 100644 --- a/third_party/rust/naga/src/back/spv/writer.rs +++ b/third_party/rust/naga/src/back/spv/writer.rs @@ -565,36 +565,38 @@ impl Writer { // Handle globals are pre-emitted and should be loaded automatically. // // Any that are binding arrays we skip as we cannot load the array, we must load the result after indexing. - let is_binding_array = match ir_module.types[var.ty].inner { - crate::TypeInner::BindingArray { .. } => true, - _ => false, - }; - - if var.space == crate::AddressSpace::Handle && !is_binding_array { - let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); - let id = self.id_gen.next(); - prelude - .body - .push(Instruction::load(var_type_id, id, gv.var_id, None)); - gv.access_id = gv.var_id; - gv.handle_id = id; - } else if global_needs_wrapper(ir_module, var) { - let class = map_storage_class(var.space); - let pointer_type_id = self.get_pointer_id(&ir_module.types, var.ty, class)?; - let index_id = self.get_index_constant(0); - - let id = self.id_gen.next(); - prelude.body.push(Instruction::access_chain( - pointer_type_id, - id, - gv.var_id, - &[index_id], - )); - gv.access_id = id; - } else { - // by default, the variable ID is accessed as is - gv.access_id = gv.var_id; - }; + match ir_module.types[var.ty].inner { + crate::TypeInner::BindingArray { .. } => { + gv.access_id = gv.var_id; + } + _ => { + if var.space == crate::AddressSpace::Handle { + let var_type_id = self.get_type_id(LookupType::Handle(var.ty)); + let id = self.id_gen.next(); + prelude + .body + .push(Instruction::load(var_type_id, id, gv.var_id, None)); + gv.access_id = gv.var_id; + gv.handle_id = id; + } else if global_needs_wrapper(ir_module, var) { + let class = map_storage_class(var.space); + let pointer_type_id = + self.get_pointer_id(&ir_module.types, var.ty, class)?; + let index_id = self.get_index_constant(0); + let id = self.id_gen.next(); + prelude.body.push(Instruction::access_chain( + pointer_type_id, + id, + gv.var_id, + &[index_id], + )); + gv.access_id = id; + } else { + // by default, the variable ID is accessed as is + gv.access_id = gv.var_id; + }; + } + } // work around borrow checking in the presence of `self.xxx()` calls self.global_variables[handle.index()] = gv; @@ -1858,9 +1860,15 @@ impl Writer { .iter() .flat_map(|entry| entry.function.arguments.iter()) .any(|arg| has_view_index_check(ir_module, arg.binding.as_ref(), arg.ty)); - let has_ray_query = ir_module.special_types.ray_desc.is_some() + let mut has_ray_query = ir_module.special_types.ray_desc.is_some() | ir_module.special_types.ray_intersection.is_some(); + for (_, &crate::Type { ref inner, .. }) in ir_module.types.iter() { + if let &crate::TypeInner::AccelerationStructure | &crate::TypeInner::RayQuery = inner { + has_ray_query = true + } + } + if self.physical_layout.version < 0x10300 && has_storage_buffers { // enable the storage buffer class on < SPV-1.3 Instruction::extension("SPV_KHR_storage_buffer_storage_class") diff --git a/third_party/rust/naga/src/front/glsl/variables.rs b/third_party/rust/naga/src/front/glsl/variables.rs index 5af2b228f0..9d2e7a0e7b 100644 --- a/third_party/rust/naga/src/front/glsl/variables.rs +++ b/third_party/rust/naga/src/front/glsl/variables.rs @@ -65,7 +65,7 @@ impl Frontend { let idx = self.entry_args.len(); self.entry_args.push(EntryArg { - name: None, + name: Some(name.into()), binding: Binding::BuiltIn(data.builtin), handle, storage: data.storage, diff --git a/third_party/rust/naga/src/span.rs b/third_party/rust/naga/src/span.rs index 53246b25d6..10744647e9 100644 --- a/third_party/rust/naga/src/span.rs +++ b/third_party/rust/naga/src/span.rs @@ -104,16 +104,17 @@ impl std::ops::Index<Span> for str { /// A human-readable representation for a span, tailored for text source. /// -/// Corresponds to the positional members of [`GPUCompilationMessage`][gcm] from -/// the WebGPU specification, except that `offset` and `length` are in bytes -/// (UTF-8 code units), instead of UTF-16 code units. +/// Roughly corresponds to the positional members of [`GPUCompilationMessage`][gcm] from +/// the WebGPU specification, except +/// - `offset` and `length` are in bytes (UTF-8 code units), instead of UTF-16 code units. +/// - `line_position` counts entire Unicode code points, instead of UTF-16 code units. /// /// [gcm]: https://www.w3.org/TR/webgpu/#gpucompilationmessage #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub struct SourceLocation { /// 1-based line number. pub line_number: u32, - /// 1-based column of the start of this span + /// 1-based column of the start of this span, counted in Unicode code points. pub line_position: u32, /// 0-based Offset in code units (in bytes) of the start of the span. pub offset: u32, diff --git a/third_party/rust/naga/src/valid/analyzer.rs b/third_party/rust/naga/src/valid/analyzer.rs index df6fc5e9b0..03fbc4089b 100644 --- a/third_party/rust/naga/src/valid/analyzer.rs +++ b/third_party/rust/naga/src/valid/analyzer.rs @@ -145,10 +145,35 @@ pub struct SamplingKey { #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +/// Information about an expression in a function body. pub struct ExpressionInfo { + /// Whether this expression is uniform, and why. + /// + /// If this expression's value is not uniform, this is the handle + /// of the expression from which this one's non-uniformity + /// originates. Otherwise, this is `None`. pub uniformity: Uniformity, + + /// The number of statements and other expressions using this + /// expression's value. pub ref_count: usize, + + /// The global variable into which this expression produces a pointer. + /// + /// This is `None` unless this expression is either a + /// [`GlobalVariable`], or an [`Access`] or [`AccessIndex`] that + /// ultimately refers to some part of a global. + /// + /// [`Load`] expressions applied to pointer-typed arguments could + /// refer to globals, but we leave this as `None` for them. + /// + /// [`GlobalVariable`]: crate::Expression::GlobalVariable + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + /// [`Load`]: crate::Expression::Load assignable_global: Option<Handle<crate::GlobalVariable>>, + + /// The type of this expression. pub ty: TypeResolution, } @@ -311,14 +336,20 @@ pub enum UniformityDisruptor { } impl FunctionInfo { - /// Adds a value-type reference to an expression. + /// Record a use of `expr` of the sort given by `global_use`. + /// + /// Bump `expr`'s reference count, and return its uniformity. + /// + /// If `expr` is a pointer to a global variable, or some part of + /// a global variable, add `global_use` to that global's set of + /// uses. #[must_use] fn add_ref_impl( &mut self, - handle: Handle<crate::Expression>, + expr: Handle<crate::Expression>, global_use: GlobalUse, ) -> NonUniformResult { - let info = &mut self.expressions[handle.index()]; + let info = &mut self.expressions[expr.index()]; info.ref_count += 1; // mark the used global as read if let Some(global) = info.assignable_global { @@ -327,22 +358,38 @@ impl FunctionInfo { info.uniformity.non_uniform_result } - /// Adds a value-type reference to an expression. + /// Record a use of `expr` for its value. + /// + /// This is used for almost all expression references. Anything + /// that writes to the value `expr` points to, or otherwise wants + /// contribute flags other than `GlobalUse::READ`, should use + /// `add_ref_impl` directly. #[must_use] - fn add_ref(&mut self, handle: Handle<crate::Expression>) -> NonUniformResult { - self.add_ref_impl(handle, GlobalUse::READ) + fn add_ref(&mut self, expr: Handle<crate::Expression>) -> NonUniformResult { + self.add_ref_impl(expr, GlobalUse::READ) } - /// Adds a potentially assignable reference to an expression. - /// These are destinations for `Store` and `ImageStore` statements, - /// which can transit through `Access` and `AccessIndex`. + /// Record a use of `expr`, and indicate which global variable it + /// refers to, if any. + /// + /// Bump `expr`'s reference count, and return its uniformity. + /// + /// If `expr` is a pointer to a global variable, or some part + /// thereof, store that global in `*assignable_global`. Leave the + /// global's uses unchanged. + /// + /// This is used to determine the [`assignable_global`] for + /// [`Access`] and [`AccessIndex`] expressions that ultimately + /// refer to a global variable. Those expressions don't contribute + /// any usage to the global themselves; that depends on how other + /// expressions use them. #[must_use] fn add_assignable_ref( &mut self, - handle: Handle<crate::Expression>, + expr: Handle<crate::Expression>, assignable_global: &mut Option<Handle<crate::GlobalVariable>>, ) -> NonUniformResult { - let info = &mut self.expressions[handle.index()]; + let info = &mut self.expressions[expr.index()]; info.ref_count += 1; // propagate the assignable global up the chain, till it either hits // a value-type expression, or the assignment statement. diff --git a/third_party/rust/naga/src/valid/type.rs b/third_party/rust/naga/src/valid/type.rs index d44a295b1a..b8eb618ed4 100644 --- a/third_party/rust/naga/src/valid/type.rs +++ b/third_party/rust/naga/src/valid/type.rs @@ -510,7 +510,6 @@ impl super::Validator { ti.uniform_layout = Ok(Alignment::MIN_UNIFORM); let mut min_offset = 0; - let mut prev_struct_data: Option<(u32, u32)> = None; for (i, member) in members.iter().enumerate() { @@ -662,6 +661,7 @@ impl super::Validator { // Currently Naga only supports binding arrays of structs for non-handle types. match gctx.types[base].inner { crate::TypeInner::Struct { .. } => {} + crate::TypeInner::Array { .. } => {} _ => return Err(TypeError::BindingArrayBaseTypeNotStruct(base)), }; } |