diff options
Diffstat (limited to 'third_party/rust/naga/src/proc')
-rw-r--r-- | third_party/rust/naga/src/proc/index.rs | 437 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/layouter.rs | 257 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/mod.rs | 490 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/namer.rs | 261 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/terminator.rs | 42 | ||||
-rw-r--r-- | third_party/rust/naga/src/proc/typifier.rs | 903 |
6 files changed, 2390 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs new file mode 100644 index 0000000000..3fea79ec01 --- /dev/null +++ b/third_party/rust/naga/src/proc/index.rs @@ -0,0 +1,437 @@ +/*! +Definitions for index bounds checking. +*/ + +use crate::{valid, Handle, UniqueArena}; +use bit_set::BitSet; + +/// How should code generated by Naga do bounds checks? +/// +/// When a vector, matrix, or array index is out of bounds—either negative, or +/// greater than or equal to the number of elements in the type—WGSL requires +/// that some other index of the implementation's choice that is in bounds is +/// used instead. (There are no types with zero elements.) +/// +/// Similarly, when out-of-bounds coordinates, array indices, or sample indices +/// are presented to the WGSL `textureLoad` and `textureStore` operations, the +/// operation is redirected to do something safe. +/// +/// Different users of Naga will prefer different defaults: +/// +/// - When used as part of a WebGPU implementation, the WGSL specification +/// requires the `Restrict` behavior for array, vector, and matrix accesses, +/// and either the `Restrict` or `ReadZeroSkipWrite` behaviors for texture +/// accesses. +/// +/// - When used by the `wgpu` crate for native development, `wgpu` selects +/// `ReadZeroSkipWrite` as its default. +/// +/// - Naga's own default is `Unchecked`, so that shader translations +/// are as faithful to the original as possible. +/// +/// Sometimes the underlying hardware and drivers can perform bounds checks +/// themselves, in a way that performs better than the checks Naga would inject. +/// If you're using native checks like this, then having Naga inject its own +/// checks as well would be redundant, and the `Unchecked` policy is +/// appropriate. +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum BoundsCheckPolicy { + /// Replace out-of-bounds indexes with some arbitrary in-bounds index. + /// + /// (This does not necessarily mean clamping. For example, interpreting the + /// index as unsigned and taking the minimum with the largest valid index + /// would also be a valid implementation. That would map negative indices to + /// the last element, not the first.) + Restrict, + + /// Out-of-bounds reads return zero, and writes have no effect. + /// + /// When applied to a chain of accesses, like `a[i][j].b[k]`, all index + /// expressions are evaluated, regardless of whether prior or later index + /// expressions were in bounds. But all the accesses per se are skipped + /// if any index is out of bounds. + ReadZeroSkipWrite, + + /// Naga adds no checks to indexing operations. Generate the fastest code + /// possible. This is the default for Naga, as a translator, but consumers + /// should consider defaulting to a safer behavior. + Unchecked, +} + +/// Policies for injecting bounds checks during code generation. +#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct BoundsCheckPolicies { + /// How should the generated code handle array, vector, or matrix indices + /// that are out of range? + #[cfg_attr(feature = "deserialize", serde(default))] + pub index: BoundsCheckPolicy, + + /// How should the generated code handle array, vector, or matrix indices + /// that are out of range, when those values live in a [`GlobalVariable`] in + /// the [`Storage`] or [`Uniform`] address spaces? + /// + /// Some graphics hardware provides "robust buffer access", a feature that + /// ensures that using a pointer cannot access memory outside the 'buffer' + /// that it was derived from. In Naga terms, this means that the hardware + /// ensures that pointers computed by applying [`Access`] and + /// [`AccessIndex`] expressions to a [`GlobalVariable`] whose [`space`] is + /// [`Storage`] or [`Uniform`] will never read or write memory outside that + /// global variable. + /// + /// When hardware offers such a feature, it is probably undesirable to have + /// Naga inject bounds checking code for such accesses, since the hardware + /// can probably provide the same protection more efficiently. However, + /// bounds checks are still needed on accesses to indexable values that do + /// not live in buffers, like local variables. + /// + /// So, this option provides a separate policy that applies only to accesses + /// to storage and uniform globals. When depending on hardware bounds + /// checking, this policy can be `Unchecked` to avoid unnecessary overhead. + /// + /// When special hardware support is not available, this should probably be + /// the same as `index_bounds_check_policy`. + /// + /// [`GlobalVariable`]: crate::GlobalVariable + /// [`space`]: crate::GlobalVariable::space + /// [`Restrict`]: crate::back::BoundsCheckPolicy::Restrict + /// [`ReadZeroSkipWrite`]: crate::back::BoundsCheckPolicy::ReadZeroSkipWrite + /// [`Access`]: crate::Expression::Access + /// [`AccessIndex`]: crate::Expression::AccessIndex + /// [`Storage`]: crate::AddressSpace::Storage + /// [`Uniform`]: crate::AddressSpace::Uniform + #[cfg_attr(feature = "deserialize", serde(default))] + pub buffer: BoundsCheckPolicy, + + /// How should the generated code handle image texel references that are out + /// of range? + /// + /// This controls the behavior of [`ImageLoad`] expressions and + /// [`ImageStore`] statements when a coordinate, texture array index, level + /// of detail, or multisampled sample number is out of range. + /// + /// [`ImageLoad`]: crate::Expression::ImageLoad + /// [`ImageStore`]: crate::Statement::ImageStore + #[cfg_attr(feature = "deserialize", serde(default))] + pub image: BoundsCheckPolicy, + + /// How should the generated code handle binding array indexes that are out of bounds. + #[cfg_attr(feature = "deserialize", serde(default))] + pub binding_array: BoundsCheckPolicy, +} + +/// The default `BoundsCheckPolicy` is `Unchecked`. +impl Default for BoundsCheckPolicy { + fn default() -> Self { + BoundsCheckPolicy::Unchecked + } +} + +impl BoundsCheckPolicies { + /// Determine which policy applies to `base`. + /// + /// `base` is the "base" expression (the expression being indexed) of a `Access` + /// and `AccessIndex` expression. This is either a pointer, a value, being directly + /// indexed, or a binding array. + /// + /// See the documentation for [`BoundsCheckPolicy`] for details about + /// when each policy applies. + pub fn choose_policy( + &self, + base: Handle<crate::Expression>, + types: &UniqueArena<crate::Type>, + info: &valid::FunctionInfo, + ) -> BoundsCheckPolicy { + let ty = info[base].ty.inner_with(types); + + if let crate::TypeInner::BindingArray { .. } = *ty { + return self.binding_array; + } + + match ty.pointer_space() { + Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => { + self.buffer + } + // This covers other address spaces, but also accessing vectors and + // matrices by value, where no pointer is involved. + _ => self.index, + } + } + + /// Return `true` if any of `self`'s policies are `policy`. + pub fn contains(&self, policy: BoundsCheckPolicy) -> bool { + self.index == policy || self.buffer == policy || self.image == policy + } +} + +/// An index that may be statically known, or may need to be computed at runtime. +/// +/// This enum lets us handle both [`Access`] and [`AccessIndex`] expressions +/// with the same code. +/// +/// [`Access`]: crate::Expression::Access +/// [`AccessIndex`]: crate::Expression::AccessIndex +#[derive(Clone, Copy, Debug)] +pub enum GuardedIndex { + Known(u32), + Expression(Handle<crate::Expression>), +} + +/// Build a set of expressions used as indices, to cache in temporary variables when +/// emitted. +/// +/// Given the bounds-check policies `policies`, construct a `BitSet` containing the handle +/// indices of all the expressions in `function` that are ever used as guarded indices +/// under the [`ReadZeroSkipWrite`] policy. The `module` argument must be the module to +/// which `function` belongs, and `info` should be that function's analysis results. +/// +/// Such index expressions will be used twice in the generated code: first for the +/// comparison to see if the index is in bounds, and then for the access itself, should +/// the comparison succeed. To avoid computing the expressions twice, the generated code +/// should cache them in temporary variables. +/// +/// Why do we need to build such a set in advance, instead of just processing access +/// expressions as we encounter them? Whether an expression needs to be cached depends on +/// whether it appears as something like the [`index`] operand of an [`Access`] expression +/// or the [`level`] operand of an [`ImageLoad`] expression, and on the index bounds check +/// policies that apply to those accesses. But [`Emit`] statements just identify a range +/// of expressions by index; there's no good way to tell what an expression is used +/// for. The only way to do it is to just iterate over all the expressions looking for +/// relevant `Access` expressions --- which is what this function does. +/// +/// Simple expressions like variable loads and constants don't make sense to cache: it's +/// no better than just re-evaluating them. But constants are not covered by `Emit` +/// statements, and `Load`s are always cached to ensure they occur at the right time, so +/// we don't bother filtering them out from this set. +/// +/// Fortunately, we don't need to deal with [`ImageStore`] statements here. When we emit +/// code for a statement, the writer isn't in the middle of an expression, so we can just +/// emit declarations for temporaries, initialized appropriately. +/// +/// None of these concerns apply for SPIR-V output, since it's easy to just reuse an +/// instruction ID in two places; that has the same semantics as a temporary variable, and +/// it's inherent in the design of SPIR-V. This function is more useful for text-based +/// back ends. +/// +/// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite +/// [`index`]: crate::Expression::Access::index +/// [`Access`]: crate::Expression::Access +/// [`level`]: crate::Expression::ImageLoad::level +/// [`ImageLoad`]: crate::Expression::ImageLoad +/// [`Emit`]: crate::Statement::Emit +/// [`ImageStore`]: crate::Statement::ImageStore +pub fn find_checked_indexes( + module: &crate::Module, + function: &crate::Function, + info: &crate::valid::FunctionInfo, + policies: BoundsCheckPolicies, +) -> BitSet { + use crate::Expression as Ex; + + let mut guarded_indices = BitSet::new(); + + // Don't bother scanning if we never need `ReadZeroSkipWrite`. + if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) { + for (_handle, expr) in function.expressions.iter() { + // There's no need to handle `AccessIndex` expressions, as their + // indices never need to be cached. + match *expr { + Ex::Access { base, index } => { + if policies.choose_policy(base, &module.types, info) + == BoundsCheckPolicy::ReadZeroSkipWrite + && access_needs_check( + base, + GuardedIndex::Expression(index), + module, + function, + info, + ) + .is_some() + { + guarded_indices.insert(index.index()); + } + } + Ex::ImageLoad { + coordinate, + array_index, + sample, + level, + .. + } => { + if policies.image == BoundsCheckPolicy::ReadZeroSkipWrite { + guarded_indices.insert(coordinate.index()); + if let Some(array_index) = array_index { + guarded_indices.insert(array_index.index()); + } + if let Some(sample) = sample { + guarded_indices.insert(sample.index()); + } + if let Some(level) = level { + guarded_indices.insert(level.index()); + } + } + } + _ => {} + } + } + } + + guarded_indices +} + +/// Determine whether `index` is statically known to be in bounds for `base`. +/// +/// If we can't be sure that the index is in bounds, return the limit within +/// which valid indices must fall. +/// +/// The return value is one of the following: +/// +/// - `Some(Known(n))` indicates that `n` is the largest valid index. +/// +/// - `Some(Computed(global))` indicates that the largest valid index is one +/// less than the length of the array that is the last member of the +/// struct held in `global`. +/// +/// - `None` indicates that the index need not be checked, either because it +/// is statically known to be in bounds, or because the applicable policy +/// is `Unchecked`. +/// +/// This function only handles subscriptable types: arrays, vectors, and +/// matrices. It does not handle struct member indices; those never require +/// run-time checks, so it's best to deal with them further up the call +/// chain. +pub fn access_needs_check( + base: Handle<crate::Expression>, + mut index: GuardedIndex, + module: &crate::Module, + function: &crate::Function, + info: &crate::valid::FunctionInfo, +) -> Option<IndexableLength> { + let base_inner = info[base].ty.inner_with(&module.types); + // Unwrap safety: `Err` here indicates unindexable base types and invalid + // length constants, but `access_needs_check` is only used by back ends, so + // validation should have caught those problems. + let length = base_inner.indexable_length(module).unwrap(); + index.try_resolve_to_constant(function, module); + if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) { + if index < length { + // Index is statically known to be in bounds, no check needed. + return None; + } + }; + + Some(length) +} + +impl GuardedIndex { + /// Make A `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible. + /// + /// If the expression is a [`Constant`] whose value is a non-specialized, scalar + /// integer constant that can be converted to a `u32`, do so and return a + /// `GuardedIndex::Known`. Otherwise, return the `GuardedIndex::Expression` + /// unchanged. + /// + /// Return values that are already `Known` unchanged. + /// + /// [`Constant`]: crate::Expression::Constant + fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) { + if let GuardedIndex::Expression(expr) = *self { + if let crate::Expression::Constant(handle) = function.expressions[expr] { + if let Some(value) = module.constants[handle].to_array_length() { + *self = GuardedIndex::Known(value); + } + } + } + } +} + +#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] +pub enum IndexableLengthError { + #[error("Type is not indexable, and has no length (validation error)")] + TypeNotIndexable, + #[error("Array length constant {0:?} is invalid")] + InvalidArrayLength(Handle<crate::Constant>), +} + +impl crate::TypeInner { + /// Return the length of a subscriptable type. + /// + /// The `self` parameter should be a handle to a vector, matrix, or array + /// type, a pointer to one of those, or a value pointer. Arrays may be + /// fixed-size, dynamically sized, or sized by a specializable constant. + /// This function does not handle struct member references, as with + /// `AccessIndex`. + /// + /// The value returned is appropriate for bounds checks on subscripting. + /// + /// Return an error if `self` does not describe a subscriptable type at all. + pub fn indexable_length( + &self, + module: &crate::Module, + ) -> Result<IndexableLength, IndexableLengthError> { + use crate::TypeInner as Ti; + let known_length = match *self { + Ti::Vector { size, .. } => size as _, + Ti::Matrix { columns, .. } => columns as _, + Ti::Array { size, .. } | Ti::BindingArray { size, .. } => { + return size.to_indexable_length(module); + } + Ti::ValuePointer { + size: Some(size), .. + } => size as _, + Ti::Pointer { base, .. } => { + // When assigning types to expressions, ResolveContext::Resolve + // does a separate sub-match here instead of a full recursion, + // so we'll do the same. + let base_inner = &module.types[base].inner; + match *base_inner { + Ti::Vector { size, .. } => size as _, + Ti::Matrix { columns, .. } => columns as _, + Ti::Array { size, .. } => return size.to_indexable_length(module), + _ => return Err(IndexableLengthError::TypeNotIndexable), + } + } + _ => return Err(IndexableLengthError::TypeNotIndexable), + }; + Ok(IndexableLength::Known(known_length)) + } +} + +/// The number of elements in an indexable type. +/// +/// This summarizes the length of vectors, matrices, and arrays in a way that is +/// convenient for indexing and bounds-checking code. +#[derive(Debug)] +pub enum IndexableLength { + /// Values of this type always have the given number of elements. + Known(u32), + + /// The number of elements is determined at runtime. + Dynamic, +} + +impl crate::ArraySize { + pub fn to_indexable_length( + self, + module: &crate::Module, + ) -> Result<IndexableLength, IndexableLengthError> { + Ok(match self { + Self::Constant(k) => { + let constant = &module.constants[k]; + if constant.specialization.is_some() { + // Specializable constants are not supported as array lengths. + // See valid::TypeError::UnsupportedSpecializedArrayLength. + return Err(IndexableLengthError::InvalidArrayLength(k)); + } + let length = constant + .to_array_length() + .ok_or(IndexableLengthError::InvalidArrayLength(k))?; + IndexableLength::Known(length) + } + Self::Dynamic => IndexableLength::Dynamic, + }) + } +} diff --git a/third_party/rust/naga/src/proc/layouter.rs b/third_party/rust/naga/src/proc/layouter.rs new file mode 100644 index 0000000000..0c3a00db15 --- /dev/null +++ b/third_party/rust/naga/src/proc/layouter.rs @@ -0,0 +1,257 @@ +use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; +use std::{fmt::Display, num::NonZeroU32, ops}; + +/// A newtype struct where its only valid values are powers of 2 +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Alignment(NonZeroU32); + +impl Alignment { + pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) }); + pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) }); + pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) }); + pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) }); + pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) }); + + pub const MIN_UNIFORM: Self = Self::SIXTEEN; + + pub const fn new(n: u32) -> Option<Self> { + if n.is_power_of_two() { + // SAFETY: value can't be 0 since we just checked if it's a power of 2 + Some(Self(unsafe { NonZeroU32::new_unchecked(n) })) + } else { + None + } + } + + /// # Panics + /// If `width` is not a power of 2 + pub fn from_width(width: u8) -> Self { + Self::new(width as u32).unwrap() + } + + /// Returns whether or not `n` is a multiple of this alignment. + pub const fn is_aligned(&self, n: u32) -> bool { + // equivalent to: `n % self.0.get() == 0` but much faster + n & (self.0.get() - 1) == 0 + } + + /// Round `n` up to the nearest alignment boundary. + pub const fn round_up(&self, n: u32) -> u32 { + // equivalent to: + // match n % self.0.get() { + // 0 => n, + // rem => n + (self.0.get() - rem), + // } + let mask = self.0.get() - 1; + (n + mask) & !mask + } +} + +impl Display for Alignment { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.get().fmt(f) + } +} + +impl ops::Mul<u32> for Alignment { + type Output = u32; + + fn mul(self, rhs: u32) -> Self::Output { + self.0.get() * rhs + } +} + +impl ops::Mul for Alignment { + type Output = Alignment; + + fn mul(self, rhs: Alignment) -> Self::Output { + // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2 + Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) }) + } +} + +impl From<crate::VectorSize> for Alignment { + fn from(size: crate::VectorSize) -> Self { + match size { + crate::VectorSize::Bi => Alignment::TWO, + crate::VectorSize::Tri => Alignment::FOUR, + crate::VectorSize::Quad => Alignment::FOUR, + } + } +} + +/// Size and alignment information for a type. +#[derive(Clone, Copy, Debug, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct TypeLayout { + pub size: u32, + pub alignment: Alignment, +} + +impl TypeLayout { + /// Produce the stride as if this type is a base of an array. + pub const fn to_stride(&self) -> u32 { + self.alignment.round_up(self.size) + } +} + +/// Helper processor that derives the sizes of all types. +/// +/// `Layouter` uses the default layout algorithm/table, described in +/// [WGSL §4.3.7, "Memory Layout"] +/// +/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the +/// layout of the type whose handle is `handle`. +/// +/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts) +#[derive(Debug, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Layouter { + /// Layouts for types in an arena, indexed by `Handle` index. + layouts: Vec<TypeLayout>, +} + +impl ops::Index<Handle<crate::Type>> for Layouter { + type Output = TypeLayout; + fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout { + &self.layouts[handle.index()] + } +} + +#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] +pub enum LayoutErrorInner { + #[error("Array element type {0:?} doesn't exist")] + InvalidArrayElementType(Handle<crate::Type>), + #[error("Struct member[{0}] type {1:?} doesn't exist")] + InvalidStructMemberType(u32, Handle<crate::Type>), + #[error("Type width must be a power of two")] + NonPowerOfTwoWidth, + #[error("Array size is a bad handle")] + BadHandle(#[from] BadHandle), +} + +#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] +#[error("Error laying out type {ty:?}: {inner}")] +pub struct LayoutError { + pub ty: Handle<crate::Type>, + pub inner: LayoutErrorInner, +} + +impl LayoutErrorInner { + const fn with(self, ty: Handle<crate::Type>) -> LayoutError { + LayoutError { ty, inner: self } + } +} + +impl Layouter { + /// Remove all entries from this `Layouter`, retaining storage. + pub fn clear(&mut self) { + self.layouts.clear(); + } + + /// Extend this `Layouter` with layouts for any new entries in `types`. + /// + /// Ensure that every type in `types` has a corresponding [TypeLayout] in + /// [`self.layouts`]. + /// + /// Some front ends need to be able to compute layouts for existing types + /// while module construction is still in progress and new types are still + /// being added. This function assumes that the `TypeLayout` values already + /// present in `self.layouts` cover their corresponding entries in `types`, + /// and extends `self.layouts` as needed to cover the rest. Thus, a front + /// end can call this function at any time, passing its current type and + /// constant arenas, and then assume that layouts are available for all + /// types. + #[allow(clippy::or_fun_call)] + pub fn update( + &mut self, + types: &UniqueArena<crate::Type>, + constants: &Arena<crate::Constant>, + ) -> Result<(), LayoutError> { + use crate::TypeInner as Ti; + + for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { + let size = ty + .inner + .try_size(constants) + .map_err(|error| LayoutErrorInner::BadHandle(error).with(ty_handle))?; + let layout = match ty.inner { + Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { + let alignment = Alignment::new(width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { size, alignment } + } + Ti::Vector { + size: vec_size, + width, + .. + } => { + let alignment = Alignment::new(width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(vec_size) * alignment, + } + } + Ti::Matrix { + columns: _, + rows, + width, + } => { + let alignment = Alignment::new(width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(rows) * alignment, + } + } + Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { + size, + alignment: Alignment::ONE, + }, + Ti::Array { + base, + stride: _, + size: _, + } => TypeLayout { + size, + alignment: if base < ty_handle { + self[base].alignment + } else { + return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle)); + }, + }, + Ti::Struct { span, ref members } => { + let mut alignment = Alignment::ONE; + for (index, member) in members.iter().enumerate() { + alignment = if member.ty < ty_handle { + alignment.max(self[member.ty].alignment) + } else { + return Err(LayoutErrorInner::InvalidStructMemberType( + index as u32, + member.ty, + ) + .with(ty_handle)); + }; + } + TypeLayout { + size: span, + alignment, + } + } + Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => TypeLayout { + size, + alignment: Alignment::ONE, + }, + }; + debug_assert!(size <= layout.size); + self.layouts.push(layout); + } + + Ok(()) + } +} diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs new file mode 100644 index 0000000000..a5731de896 --- /dev/null +++ b/third_party/rust/naga/src/proc/mod.rs @@ -0,0 +1,490 @@ +/*! +[`Module`](super::Module) processing functionality. +*/ + +pub mod index; +mod layouter; +mod namer; +mod terminator; +mod typifier; + +use std::cmp::PartialEq; + +pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; +pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout}; +pub use namer::{EntryPointIndex, NameKey, Namer}; +pub use terminator::ensure_block_returns; +pub use typifier::{ResolveContext, ResolveError, TypeResolution}; + +impl From<super::StorageFormat> for super::ScalarKind { + fn from(format: super::StorageFormat) -> Self { + use super::{ScalarKind as Sk, StorageFormat as Sf}; + match format { + Sf::R8Unorm => Sk::Float, + Sf::R8Snorm => Sk::Float, + Sf::R8Uint => Sk::Uint, + Sf::R8Sint => Sk::Sint, + Sf::R16Uint => Sk::Uint, + Sf::R16Sint => Sk::Sint, + Sf::R16Float => Sk::Float, + Sf::Rg8Unorm => Sk::Float, + Sf::Rg8Snorm => Sk::Float, + Sf::Rg8Uint => Sk::Uint, + Sf::Rg8Sint => Sk::Sint, + Sf::R32Uint => Sk::Uint, + Sf::R32Sint => Sk::Sint, + Sf::R32Float => Sk::Float, + Sf::Rg16Uint => Sk::Uint, + Sf::Rg16Sint => Sk::Sint, + Sf::Rg16Float => Sk::Float, + Sf::Rgba8Unorm => Sk::Float, + Sf::Rgba8Snorm => Sk::Float, + Sf::Rgba8Uint => Sk::Uint, + Sf::Rgba8Sint => Sk::Sint, + Sf::Rgb10a2Unorm => Sk::Float, + Sf::Rg11b10Float => Sk::Float, + Sf::Rg32Uint => Sk::Uint, + Sf::Rg32Sint => Sk::Sint, + Sf::Rg32Float => Sk::Float, + Sf::Rgba16Uint => Sk::Uint, + Sf::Rgba16Sint => Sk::Sint, + Sf::Rgba16Float => Sk::Float, + Sf::Rgba32Uint => Sk::Uint, + Sf::Rgba32Sint => Sk::Sint, + Sf::Rgba32Float => Sk::Float, + } + } +} + +impl super::ScalarValue { + pub const fn scalar_kind(&self) -> super::ScalarKind { + match *self { + Self::Uint(_) => super::ScalarKind::Uint, + Self::Sint(_) => super::ScalarKind::Sint, + Self::Float(_) => super::ScalarKind::Float, + Self::Bool(_) => super::ScalarKind::Bool, + } + } +} + +impl super::ScalarKind { + pub const fn is_numeric(self) -> bool { + match self { + crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float => true, + crate::ScalarKind::Bool => false, + } + } +} + +pub const POINTER_SPAN: u32 = 4; + +impl super::TypeInner { + pub const fn scalar_kind(&self) -> Option<super::ScalarKind> { + match *self { + super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => { + Some(kind) + } + super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float), + _ => None, + } + } + + pub const fn pointer_space(&self) -> Option<crate::AddressSpace> { + match *self { + Self::Pointer { space, .. } => Some(space), + Self::ValuePointer { space, .. } => Some(space), + _ => None, + } + } + + pub fn try_size( + &self, + constants: &super::Arena<super::Constant>, + ) -> Result<u32, crate::arena::BadHandle> { + Ok(match *self { + Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, + Self::Vector { + size, + kind: _, + width, + } => size as u32 * width as u32, + // matrices are treated as arrays of aligned columns + Self::Matrix { + columns, + rows, + width, + } => Alignment::from(rows) * width as u32 * columns as u32, + Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN, + Self::Array { + base: _, + size, + stride, + } => { + let count = match size { + super::ArraySize::Constant(handle) => { + let constant = constants.try_get(handle)?; + constant.to_array_length().unwrap_or(1) + } + // A dynamically-sized array has to have at least one element + super::ArraySize::Dynamic => 1, + }; + count * stride + } + Self::Struct { span, .. } => span, + Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, + }) + } + + /// Get the size of this type. Panics if the `constants` doesn't contain + /// a referenced handle. This may not happen in a properly validated IR module. + pub fn size(&self, constants: &super::Arena<super::Constant>) -> u32 { + self.try_size(constants).unwrap() + } + + /// Return the canonical form of `self`, or `None` if it's already in + /// canonical form. + /// + /// Certain types have multiple representations in `TypeInner`. This + /// function converts all forms of equivalent types to a single + /// representative of their class, so that simply applying `Eq` to the + /// result indicates whether the types are equivalent, as far as Naga IR is + /// concerned. + pub fn canonical_form( + &self, + types: &crate::UniqueArena<crate::Type>, + ) -> Option<crate::TypeInner> { + use crate::TypeInner as Ti; + match *self { + Ti::Pointer { base, space } => match types[base].inner { + Ti::Scalar { kind, width } => Some(Ti::ValuePointer { + size: None, + kind, + width, + space, + }), + Ti::Vector { size, kind, width } => Some(Ti::ValuePointer { + size: Some(size), + kind, + width, + space, + }), + _ => None, + }, + _ => None, + } + } + + /// Compare `self` and `rhs` as types. + /// + /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats + /// `ValuePointer` and `Pointer` types as equivalent. + /// + /// When you know that one side of the comparison is never a pointer, it's + /// fine to not bother with canonicalization, and just compare `TypeInner` + /// values with `==`. + pub fn equivalent( + &self, + rhs: &crate::TypeInner, + types: &crate::UniqueArena<crate::Type>, + ) -> bool { + let left = self.canonical_form(types); + let right = rhs.canonical_form(types); + left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs) + } + + pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool { + use crate::TypeInner as Ti; + match *self { + Ti::Array { size, .. } => size == crate::ArraySize::Dynamic, + Ti::Struct { ref members, .. } => members + .last() + .map(|last| types[last.ty].inner.is_dynamically_sized(types)) + .unwrap_or(false), + _ => false, + } + } +} + +impl super::AddressSpace { + pub fn access(self) -> crate::StorageAccess { + use crate::StorageAccess as Sa; + match self { + crate::AddressSpace::Function + | crate::AddressSpace::Private + | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE, + crate::AddressSpace::Uniform => Sa::LOAD, + crate::AddressSpace::Storage { access } => access, + crate::AddressSpace::Handle => Sa::LOAD, + crate::AddressSpace::PushConstant => Sa::LOAD, + } + } +} + +impl super::MathFunction { + pub const fn argument_count(&self) -> usize { + match *self { + // comparison + Self::Abs => 1, + Self::Min => 2, + Self::Max => 2, + Self::Clamp => 3, + Self::Saturate => 1, + // trigonometry + Self::Cos => 1, + Self::Cosh => 1, + Self::Sin => 1, + Self::Sinh => 1, + Self::Tan => 1, + Self::Tanh => 1, + Self::Acos => 1, + Self::Asin => 1, + Self::Atan => 1, + Self::Atan2 => 2, + Self::Asinh => 1, + Self::Acosh => 1, + Self::Atanh => 1, + Self::Radians => 1, + Self::Degrees => 1, + // decomposition + Self::Ceil => 1, + Self::Floor => 1, + Self::Round => 1, + Self::Fract => 1, + Self::Trunc => 1, + Self::Modf => 2, + Self::Frexp => 2, + Self::Ldexp => 2, + // exponent + Self::Exp => 1, + Self::Exp2 => 1, + Self::Log => 1, + Self::Log2 => 1, + Self::Pow => 2, + // geometry + Self::Dot => 2, + Self::Outer => 2, + Self::Cross => 2, + Self::Distance => 2, + Self::Length => 1, + Self::Normalize => 1, + Self::FaceForward => 3, + Self::Reflect => 2, + Self::Refract => 3, + // computational + Self::Sign => 1, + Self::Fma => 3, + Self::Mix => 3, + Self::Step => 2, + Self::SmoothStep => 3, + Self::Sqrt => 1, + Self::InverseSqrt => 1, + Self::Inverse => 1, + Self::Transpose => 1, + Self::Determinant => 1, + // bits + Self::CountOneBits => 1, + Self::ReverseBits => 1, + Self::ExtractBits => 3, + Self::InsertBits => 4, + Self::FindLsb => 1, + Self::FindMsb => 1, + // data packing + Self::Pack4x8snorm => 1, + Self::Pack4x8unorm => 1, + Self::Pack2x16snorm => 1, + Self::Pack2x16unorm => 1, + Self::Pack2x16float => 1, + // data unpacking + Self::Unpack4x8snorm => 1, + Self::Unpack4x8unorm => 1, + Self::Unpack2x16snorm => 1, + Self::Unpack2x16unorm => 1, + Self::Unpack2x16float => 1, + } + } +} + +impl crate::Expression { + /// Returns true if the expression is considered emitted at the start of a function. + pub const fn needs_pre_emit(&self) -> bool { + match *self { + Self::Constant(_) + | Self::FunctionArgument(_) + | Self::GlobalVariable(_) + | Self::LocalVariable(_) => true, + _ => false, + } + } + + /// Return true if this expression is a dynamic array index, for [`Access`]. + /// + /// This method returns true if this expression is a dynamically computed + /// index, and as such can only be used to index matrices and arrays when + /// they appear behind a pointer. See the documentation for [`Access`] for + /// details. + /// + /// Note, this does not check the _type_ of the given expression. It's up to + /// the caller to establish that the `Access` expression is well-typed + /// through other means, like [`ResolveContext`]. + /// + /// [`Access`]: crate::Expression::Access + /// [`ResolveContext`]: crate::proc::ResolveContext + pub fn is_dynamic_index(&self, module: &crate::Module) -> bool { + if let Self::Constant(handle) = *self { + let constant = &module.constants[handle]; + constant.specialization.is_some() + } else { + true + } + } +} + +impl crate::Function { + /// Return the global variable being accessed by the expression `pointer`. + /// + /// Assuming that `pointer` is a series of `Access` and `AccessIndex` + /// expressions that ultimately access some part of a `GlobalVariable`, + /// return a handle for that global. + /// + /// If the expression does not ultimately access a global variable, return + /// `None`. + pub fn originating_global( + &self, + mut pointer: crate::Handle<crate::Expression>, + ) -> Option<crate::Handle<crate::GlobalVariable>> { + loop { + pointer = match self.expressions[pointer] { + crate::Expression::Access { base, .. } => base, + crate::Expression::AccessIndex { base, .. } => base, + crate::Expression::GlobalVariable(handle) => return Some(handle), + crate::Expression::LocalVariable(_) => return None, + crate::Expression::FunctionArgument(_) => return None, + // There are no other expressions that produce pointer values. + _ => unreachable!(), + } + } + } +} + +impl crate::SampleLevel { + pub const fn implicit_derivatives(&self) -> bool { + match *self { + Self::Auto | Self::Bias(_) => true, + Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false, + } + } +} + +impl crate::Constant { + /// Interpret this constant as an array length, and return it as a `u32`. + /// + /// Ignore any specialization available for this constant; return its + /// unspecialized value. + /// + /// If the constant has an inappropriate kind (non-scalar or non-integer) or + /// value (negative, out of range for u32), return `None`. This usually + /// indicates an error, but only the caller has enough information to report + /// the error helpfully: in back ends, it's a validation error, but in front + /// ends, it may indicate ill-formed input (for example, a SPIR-V + /// `OpArrayType` referring to an inappropriate `OpConstant`). So we return + /// `Option` and let the caller sort things out. + pub(crate) fn to_array_length(&self) -> Option<u32> { + match self.inner { + crate::ConstantInner::Scalar { value, width: _ } => match value { + crate::ScalarValue::Uint(value) => value.try_into().ok(), + // Accept a signed integer size to avoid + // requiring an explicit uint + // literal. Type inference should make + // this unnecessary. + crate::ScalarValue::Sint(value) => value.try_into().ok(), + _ => None, + }, + // caught by type validation + crate::ConstantInner::Composite { .. } => None, + } + } +} + +impl crate::Binding { + pub const fn to_built_in(&self) -> Option<crate::BuiltIn> { + match *self { + crate::Binding::BuiltIn(built_in) => Some(built_in), + Self::Location { .. } => None, + } + } +} + +//TODO: should we use an existing crate for hashable floats? +impl PartialEq for crate::ScalarValue { + fn eq(&self, other: &Self) -> bool { + match (*self, *other) { + (Self::Uint(a), Self::Uint(b)) => a == b, + (Self::Sint(a), Self::Sint(b)) => a == b, + (Self::Float(a), Self::Float(b)) => a.to_bits() == b.to_bits(), + (Self::Bool(a), Self::Bool(b)) => a == b, + _ => false, + } + } +} +impl Eq for crate::ScalarValue {} +impl std::hash::Hash for crate::ScalarValue { + fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) { + match *self { + Self::Sint(v) => v.hash(hasher), + Self::Uint(v) => v.hash(hasher), + Self::Float(v) => v.to_bits().hash(hasher), + Self::Bool(v) => v.hash(hasher), + } + } +} + +impl super::SwizzleComponent { + pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W]; + + pub const fn index(&self) -> u32 { + match *self { + Self::X => 0, + Self::Y => 1, + Self::Z => 2, + Self::W => 3, + } + } + pub const fn from_index(idx: u32) -> Self { + match idx { + 0 => Self::X, + 1 => Self::Y, + 2 => Self::Z, + _ => Self::W, + } + } +} + +impl super::ImageClass { + pub const fn is_multisampled(self) -> bool { + match self { + crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi, + crate::ImageClass::Storage { .. } => false, + } + } + + pub const fn is_mipmapped(self) -> bool { + match self { + crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi, + crate::ImageClass::Storage { .. } => false, + } + } +} + +#[test] +fn test_matrix_size() { + let constants = crate::Arena::new(); + assert_eq!( + crate::TypeInner::Matrix { + columns: crate::VectorSize::Tri, + rows: crate::VectorSize::Tri, + width: 4 + } + .size(&constants), + 48, + ); +} diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs new file mode 100644 index 0000000000..3f4af47884 --- /dev/null +++ b/third_party/rust/naga/src/proc/namer.rs @@ -0,0 +1,261 @@ +use crate::{arena::Handle, FastHashMap, FastHashSet}; +use std::borrow::Cow; + +pub type EntryPointIndex = u16; +const SEPARATOR: char = '_'; + +#[derive(Debug, Eq, Hash, PartialEq)] +pub enum NameKey { + Constant(Handle<crate::Constant>), + GlobalVariable(Handle<crate::GlobalVariable>), + Type(Handle<crate::Type>), + StructMember(Handle<crate::Type>, u32), + Function(Handle<crate::Function>), + FunctionArgument(Handle<crate::Function>, u32), + FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>), + EntryPoint(EntryPointIndex), + EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>), + EntryPointArgument(EntryPointIndex, u32), +} + +/// This processor assigns names to all the things in a module +/// that may need identifiers in a textual backend. +#[derive(Default)] +pub struct Namer { + /// The last numeric suffix used for each base name. Zero means "no suffix". + unique: FastHashMap<String, u32>, + keywords: FastHashSet<String>, + reserved_prefixes: Vec<String>, +} + +impl Namer { + /// Return a form of `string` suitable for use as the base of an identifier. + /// + /// - Drop leading digits. + /// - Retain only alphanumeric and `_` characters. + /// - Avoid prefixes in [`Namer::reserved_prefixes`]. + /// + /// The return value is a valid identifier prefix in all of Naga's output languages, + /// and it never ends with a `SEPARATOR` character. + /// It is used as a key into the unique table. + fn sanitize<'s>(&self, string: &'s str) -> Cow<'s, str> { + let string = string + .trim_start_matches(|c: char| c.is_numeric()) + .trim_end_matches(SEPARATOR); + + let base = if !string.is_empty() + && string + .chars() + .all(|c: char| c.is_ascii_alphanumeric() || c == '_') + { + Cow::Borrowed(string) + } else { + let mut filtered = string + .chars() + .filter(|&c| c.is_ascii_alphanumeric() || c == '_') + .collect::<String>(); + let stripped_len = filtered.trim_end_matches(SEPARATOR).len(); + filtered.truncate(stripped_len); + if filtered.is_empty() { + filtered.push_str("unnamed"); + } + Cow::Owned(filtered) + }; + + for prefix in &self.reserved_prefixes { + if base.starts_with(prefix) { + return format!("gen_{}", base).into(); + } + } + + base + } + + /// Return a new identifier based on `label_raw`. + /// + /// The result: + /// - is a valid identifier even if `label_raw` is not + /// - conflicts with no keywords listed in `Namer::keywords`, and + /// - is different from any identifier previously constructed by this + /// `Namer`. + /// + /// Guarantee uniqueness by applying a numeric suffix when necessary. If `label_raw` + /// itself ends with digits, separate them from the suffix with an underscore. + pub fn call(&mut self, label_raw: &str) -> String { + use std::fmt::Write as _; // for write!-ing to Strings + + let base = self.sanitize(label_raw); + debug_assert!(!base.is_empty() && !base.ends_with(SEPARATOR)); + + // This would seem to be a natural place to use `HashMap::entry`. However, `entry` + // requires an owned key, and we'd like to avoid heap-allocating strings we're + // just going to throw away. The approach below double-hashes only when we create + // a new entry, in which case the heap allocation of the owned key was more + // expensive anyway. + match self.unique.get_mut(base.as_ref()) { + Some(count) => { + *count += 1; + // Add the suffix. This may fit in base's existing allocation. + let mut suffixed = base.into_owned(); + write!(suffixed, "{}{}", SEPARATOR, *count).unwrap(); + suffixed + } + None => { + let mut suffixed = base.to_string(); + if base.ends_with(char::is_numeric) || self.keywords.contains(base.as_ref()) { + suffixed.push(SEPARATOR); + } + debug_assert!(!self.keywords.contains(&suffixed)); + // `self.unique` wants to own its keys. This allocates only if we haven't + // already done so earlier. + self.unique.insert(base.into_owned(), 0); + suffixed + } + } + } + + pub fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String { + self.call(match *label { + Some(ref name) => name, + None => fallback, + }) + } + + /// Enter a local namespace for things like structs. + /// + /// Struct member names only need to be unique amongst themselves, not + /// globally. This function temporarily establishes a fresh, empty naming + /// context for the duration of the call to `body`. + fn namespace(&mut self, capacity: usize, body: impl FnOnce(&mut Self)) { + let fresh = FastHashMap::with_capacity_and_hasher(capacity, Default::default()); + let outer = std::mem::replace(&mut self.unique, fresh); + body(self); + self.unique = outer; + } + + pub fn reset( + &mut self, + module: &crate::Module, + reserved_keywords: &[&str], + reserved_prefixes: &[&str], + output: &mut FastHashMap<NameKey, String>, + ) { + self.reserved_prefixes.clear(); + self.reserved_prefixes + .extend(reserved_prefixes.iter().map(|string| string.to_string())); + + self.unique.clear(); + self.keywords.clear(); + self.keywords + .extend(reserved_keywords.iter().map(|string| (string.to_string()))); + let mut temp = String::new(); + + for (ty_handle, ty) in module.types.iter() { + let ty_name = self.call_or(&ty.name, "type"); + output.insert(NameKey::Type(ty_handle), ty_name); + + if let crate::TypeInner::Struct { ref members, .. } = ty.inner { + // struct members have their own namespace, because access is always prefixed + self.namespace(members.len(), |namer| { + for (index, member) in members.iter().enumerate() { + let name = namer.call_or(&member.name, "member"); + output.insert(NameKey::StructMember(ty_handle, index as u32), name); + } + }) + } + } + + for (ep_index, ep) in module.entry_points.iter().enumerate() { + let ep_name = self.call(&ep.name); + output.insert(NameKey::EntryPoint(ep_index as _), ep_name); + for (index, arg) in ep.function.arguments.iter().enumerate() { + let name = self.call_or(&arg.name, "param"); + output.insert( + NameKey::EntryPointArgument(ep_index as _, index as u32), + name, + ); + } + for (handle, var) in ep.function.local_variables.iter() { + let name = self.call_or(&var.name, "local"); + output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name); + } + } + + for (fun_handle, fun) in module.functions.iter() { + let fun_name = self.call_or(&fun.name, "function"); + output.insert(NameKey::Function(fun_handle), fun_name); + for (index, arg) in fun.arguments.iter().enumerate() { + let name = self.call_or(&arg.name, "param"); + output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name); + } + for (handle, var) in fun.local_variables.iter() { + let name = self.call_or(&var.name, "local"); + output.insert(NameKey::FunctionLocal(fun_handle, handle), name); + } + } + + for (handle, var) in module.global_variables.iter() { + let name = self.call_or(&var.name, "global"); + output.insert(NameKey::GlobalVariable(handle), name); + } + + for (handle, constant) in module.constants.iter() { + let label = match constant.name { + Some(ref name) => name, + None => { + use std::fmt::Write; + // Try to be more descriptive about the constant values + temp.clear(); + match constant.inner { + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Sint(v), + } => write!(temp, "const_{}i", v), + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Uint(v), + } => write!(temp, "const_{}u", v), + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Float(v), + } => { + let abs = v.abs(); + write!( + temp, + "const_{}{}", + if v < 0.0 { "n" } else { "" }, + abs.trunc(), + ) + .unwrap(); + let fract = abs.fract(); + if fract == 0.0 { + write!(temp, "f") + } else { + write!(temp, "_{:02}f", (fract * 100.0) as i8) + } + } + crate::ConstantInner::Scalar { + width: _, + value: crate::ScalarValue::Bool(v), + } => write!(temp, "const_{}", v), + crate::ConstantInner::Composite { ty, components: _ } => { + write!(temp, "const_{}", output[&NameKey::Type(ty)]) + } + } + .unwrap(); + &temp + } + }; + let name = self.call(label); + output.insert(NameKey::Constant(handle), name); + } + } +} + +#[test] +fn test() { + let mut namer = Namer::default(); + assert_eq!(namer.call("x"), "x"); + assert_eq!(namer.call("x"), "x_1"); + assert_eq!(namer.call("x1"), "x1_"); +} diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs new file mode 100644 index 0000000000..5915616cc5 --- /dev/null +++ b/third_party/rust/naga/src/proc/terminator.rs @@ -0,0 +1,42 @@ +/// Ensure that the given block has return statements +/// at the end of its control flow. +/// +/// Note: we don't want to blindly append a return statement +/// to the end, because it may be either redundant or invalid, +/// e.g. when the user already has returns in if/else branches. +pub fn ensure_block_returns(block: &mut crate::Block) { + use crate::Statement as S; + match block.last_mut() { + Some(&mut S::Block(ref mut b)) => { + ensure_block_returns(b); + } + Some(&mut S::If { + condition: _, + ref mut accept, + ref mut reject, + }) => { + ensure_block_returns(accept); + ensure_block_returns(reject); + } + Some(&mut S::Switch { + selector: _, + ref mut cases, + }) => { + for case in cases.iter_mut() { + if !case.fall_through { + ensure_block_returns(&mut case.body); + } + } + } + Some(&mut (S::Emit(_) | S::Break | S::Continue | S::Return { .. } | S::Kill)) => (), + Some( + &mut (S::Loop { .. } + | S::Store { .. } + | S::ImageStore { .. } + | S::Call { .. } + | S::Atomic { .. } + | S::Barrier(_)), + ) + | None => block.push(S::Return { value: None }, Default::default()), + } +} diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs new file mode 100644 index 0000000000..9a5922ea76 --- /dev/null +++ b/third_party/rust/naga/src/proc/typifier.rs @@ -0,0 +1,903 @@ +use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; + +use thiserror::Error; + +/// The result of computing an expression's type. +/// +/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent +/// the (Naga) type it ascribes to some expression. +/// +/// You might expect such a function to simply return a `Handle<Type>`. However, +/// we want type resolution to be a read-only process, and that would limit the +/// possible results to types already present in the expression's associated +/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are +/// not certain to be present. +/// +/// So instead, type resolution returns a `TypeResolution` enum: either a +/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a +/// free-floating [`TypeInner`]. This extends the range to cover anything that +/// can be represented with a `TypeInner` referring to the existing arena. +/// +/// What sorts of expressions can have types not available in the arena? +/// +/// - An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or +/// [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector` +/// and `Matrix` represent their element and column types implicitly, not +/// via a handle, there may not be a suitable type in the expression's +/// associated arena. Instead, resolving such an expression returns a +/// `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or +/// `Vector`. +/// +/// - Similarly, the type of an [`Access`] or [`AccessIndex`] expression +/// applied to a *pointer to* a vector or matrix must produce a *pointer to* +/// a scalar or vector type. These cannot be represented with a +/// [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the +/// arena, and as before, we cannot assume that a suitable scalar or vector +/// type is there. So we take things one step further and provide +/// [`TypeInner::ValuePointer`], specifically for the case of pointers to +/// scalars or vectors. This type fits in a `TypeInner` and is exactly +/// equivalent to a `Pointer` to a `Vector` or `Scalar`. +/// +/// So, for example, the type of an `Access` expression applied to a value of type: +/// +/// ```ignore +/// TypeInner::Matrix { columns, rows, width } +/// ``` +/// +/// might be: +/// +/// ```ignore +/// TypeResolution::Value(TypeInner::Vector { +/// size: rows, +/// kind: ScalarKind::Float, +/// width, +/// }) +/// ``` +/// +/// and the type of an access to a pointer of address space `space` to such a +/// matrix might be: +/// +/// ```ignore +/// TypeResolution::Value(TypeInner::ValuePointer { +/// size: Some(rows), +/// kind: ScalarKind::Float, +/// width, +/// space, +/// }) +/// ``` +/// +/// [`Handle`]: TypeResolution::Handle +/// [`Value`]: TypeResolution::Value +/// +/// [`Access`]: crate::Expression::Access +/// [`AccessIndex`]: crate::Expression::AccessIndex +/// +/// [`TypeInner`]: crate::TypeInner +/// [`Matrix`]: crate::TypeInner::Matrix +/// [`Pointer`]: crate::TypeInner::Pointer +/// [`Scalar`]: crate::TypeInner::Scalar +/// [`ValuePointer`]: crate::TypeInner::ValuePointer +/// [`Vector`]: crate::TypeInner::Vector +/// +/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer +/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum TypeResolution { + /// A type stored in the associated arena. + Handle(Handle<crate::Type>), + + /// A free-floating [`TypeInner`], representing a type that may not be + /// available in the associated arena. However, the `TypeInner` itself may + /// contain `Handle<Type>` values referring to types from the arena. + /// + /// [`TypeInner`]: crate::TypeInner + Value(crate::TypeInner), +} + +impl TypeResolution { + pub const fn handle(&self) -> Option<Handle<crate::Type>> { + match *self { + Self::Handle(handle) => Some(handle), + Self::Value(_) => None, + } + } + + pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner { + match *self { + Self::Handle(handle) => &arena[handle].inner, + Self::Value(ref inner) => inner, + } + } +} + +// Clone is only implemented for numeric variants of `TypeInner`. +impl Clone for TypeResolution { + fn clone(&self) -> Self { + use crate::TypeInner as Ti; + match *self { + Self::Handle(handle) => Self::Handle(handle), + Self::Value(ref v) => Self::Value(match *v { + Ti::Scalar { kind, width } => Ti::Scalar { kind, width }, + Ti::Vector { size, kind, width } => Ti::Vector { size, kind, width }, + Ti::Matrix { + rows, + columns, + width, + } => Ti::Matrix { + rows, + columns, + width, + }, + Ti::Pointer { base, space } => Ti::Pointer { base, space }, + Ti::ValuePointer { + size, + kind, + width, + space, + } => Ti::ValuePointer { + size, + kind, + width, + space, + }, + _ => unreachable!("Unexpected clone type: {:?}", v), + }), + } + } +} + +impl crate::ConstantInner { + pub const fn resolve_type(&self) -> TypeResolution { + match *self { + Self::Scalar { width, ref value } => TypeResolution::Value(crate::TypeInner::Scalar { + kind: value.scalar_kind(), + width, + }), + Self::Composite { ty, components: _ } => TypeResolution::Handle(ty), + } + } +} + +#[derive(Clone, Debug, Error, PartialEq)] +pub enum ResolveError { + #[error(transparent)] + BadHandle(#[from] BadHandle), + #[error("Index {index} is out of bounds for expression {expr:?}")] + OutOfBoundsIndex { + expr: Handle<crate::Expression>, + index: u32, + }, + #[error("Invalid access into expression {expr:?}, indexed: {indexed}")] + InvalidAccess { + expr: Handle<crate::Expression>, + indexed: bool, + }, + #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")] + InvalidSubAccess { + ty: Handle<crate::Type>, + indexed: bool, + }, + #[error("Invalid scalar {0:?}")] + InvalidScalar(Handle<crate::Expression>), + #[error("Invalid vector {0:?}")] + InvalidVector(Handle<crate::Expression>), + #[error("Invalid pointer {0:?}")] + InvalidPointer(Handle<crate::Expression>), + #[error("Invalid image {0:?}")] + InvalidImage(Handle<crate::Expression>), + #[error("Function {name} not defined")] + FunctionNotDefined { name: String }, + #[error("Function without return type")] + FunctionReturnsVoid, + #[error("Incompatible operands: {0}")] + IncompatibleOperands(String), + #[error("Function argument {0} doesn't exist")] + FunctionArgumentNotFound(u32), + #[error("Expression {0:?} depends on expressions that follow")] + ExpressionForwardDependency(Handle<crate::Expression>), +} + +pub struct ResolveContext<'a> { + pub constants: &'a Arena<crate::Constant>, + pub types: &'a UniqueArena<crate::Type>, + pub global_vars: &'a Arena<crate::GlobalVariable>, + pub local_vars: &'a Arena<crate::LocalVariable>, + pub functions: &'a Arena<crate::Function>, + pub arguments: &'a [crate::FunctionArgument], +} + +impl<'a> ResolveContext<'a> { + /// Determine the type of `expr`. + /// + /// The `past` argument must be a closure that can resolve the types of any + /// expressions that `expr` refers to. These can be gathered by caching the + /// results of prior calls to `resolve`, perhaps as done by the + /// [`front::Typifier`] utility type. + /// + /// Type resolution is a read-only process: this method takes `self` by + /// shared reference. However, this means that we cannot add anything to + /// `self.types` that we might need to describe `expr`. To work around this, + /// this method returns a [`TypeResolution`], rather than simply returning a + /// `Handle<Type>`; see the documentation for [`TypeResolution`] for + /// details. + /// + /// [`front::Typifier`]: crate::front::Typifier + pub fn resolve( + &self, + expr: &crate::Expression, + past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>, + ) -> Result<TypeResolution, ResolveError> { + use crate::TypeInner as Ti; + let types = self.types; + Ok(match *expr { + crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) { + // Arrays and matrices can only be indexed dynamically behind a + // pointer, but that's a validation error, not a type error, so + // go ahead provide a type here. + Ti::Array { base, .. } => TypeResolution::Handle(base), + Ti::Matrix { rows, width, .. } => TypeResolution::Value(Ti::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }), + Ti::Vector { + size: _, + kind, + width, + } => TypeResolution::Value(Ti::Scalar { kind, width }), + Ti::ValuePointer { + size: Some(_), + kind, + width, + space, + } => TypeResolution::Value(Ti::ValuePointer { + size: None, + kind, + width, + space, + }), + Ti::Pointer { base, space } => { + TypeResolution::Value(match types[base].inner { + Ti::Array { base, .. } => Ti::Pointer { base, space }, + Ti::Vector { + size: _, + kind, + width, + } => Ti::ValuePointer { + size: None, + kind, + width, + space, + }, + // Matrices are only dynamically indexed behind a pointer + Ti::Matrix { + columns: _, + rows, + width, + } => Ti::ValuePointer { + kind: crate::ScalarKind::Float, + size: Some(rows), + width, + space, + }, + ref other => { + log::error!("Access sub-type {:?}", other); + return Err(ResolveError::InvalidSubAccess { + ty: base, + indexed: false, + }); + } + }) + } + Ti::BindingArray { base, .. } => TypeResolution::Handle(base), + ref other => { + log::error!("Access type {:?}", other); + return Err(ResolveError::InvalidAccess { + expr: base, + indexed: false, + }); + } + }, + crate::Expression::AccessIndex { base, index } => { + match *past(base)?.inner_with(types) { + Ti::Vector { size, kind, width } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + TypeResolution::Value(Ti::Scalar { kind, width }) + } + Ti::Matrix { + columns, + rows, + width, + } => { + if index >= columns as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + TypeResolution::Value(crate::TypeInner::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }) + } + Ti::Array { base, .. } => TypeResolution::Handle(base), + Ti::Struct { ref members, .. } => { + let member = members + .get(index as usize) + .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?; + TypeResolution::Handle(member.ty) + } + Ti::ValuePointer { + size: Some(size), + kind, + width, + space, + } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + TypeResolution::Value(Ti::ValuePointer { + size: None, + kind, + width, + space, + }) + } + Ti::Pointer { + base: ty_base, + space, + } => TypeResolution::Value(match types[ty_base].inner { + Ti::Array { base, .. } => Ti::Pointer { base, space }, + Ti::Vector { size, kind, width } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + Ti::ValuePointer { + size: None, + kind, + width, + space, + } + } + Ti::Matrix { + rows, + columns, + width, + } => { + if index >= columns as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + Ti::ValuePointer { + size: Some(rows), + kind: crate::ScalarKind::Float, + width, + space, + } + } + Ti::Struct { ref members, .. } => { + let member = members + .get(index as usize) + .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?; + Ti::Pointer { + base: member.ty, + space, + } + } + ref other => { + log::error!("Access index sub-type {:?}", other); + return Err(ResolveError::InvalidSubAccess { + ty: ty_base, + indexed: true, + }); + } + }), + Ti::BindingArray { base, .. } => TypeResolution::Handle(base), + ref other => { + log::error!("Access index type {:?}", other); + return Err(ResolveError::InvalidAccess { + expr: base, + indexed: true, + }); + } + } + } + crate::Expression::Constant(h) => { + let constant = self.constants.try_get(h)?; + match constant.inner { + crate::ConstantInner::Scalar { width, ref value } => { + TypeResolution::Value(Ti::Scalar { + kind: value.scalar_kind(), + width, + }) + } + crate::ConstantInner::Composite { ty, components: _ } => { + TypeResolution::Handle(ty) + } + } + } + crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { + Ti::Scalar { kind, width } => { + TypeResolution::Value(Ti::Vector { size, kind, width }) + } + ref other => { + log::error!("Scalar type {:?}", other); + return Err(ResolveError::InvalidScalar(value)); + } + }, + crate::Expression::Swizzle { + size, + vector, + pattern: _, + } => match *past(vector)?.inner_with(types) { + Ti::Vector { + size: _, + kind, + width, + } => TypeResolution::Value(Ti::Vector { size, kind, width }), + ref other => { + log::error!("Vector type {:?}", other); + return Err(ResolveError::InvalidVector(vector)); + } + }, + crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::FunctionArgument(index) => { + let arg = self + .arguments + .get(index as usize) + .ok_or(ResolveError::FunctionArgumentNotFound(index))?; + TypeResolution::Handle(arg.ty) + } + crate::Expression::GlobalVariable(h) => { + let var = self.global_vars.try_get(h)?; + if var.space == crate::AddressSpace::Handle { + TypeResolution::Handle(var.ty) + } else { + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: var.space, + }) + } + } + crate::Expression::LocalVariable(h) => { + let var = self.local_vars.try_get(h)?; + TypeResolution::Value(Ti::Pointer { + base: var.ty, + space: crate::AddressSpace::Function, + }) + } + crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) { + Ti::Pointer { base, space: _ } => { + if let Ti::Atomic { kind, width } = types[base].inner { + TypeResolution::Value(Ti::Scalar { kind, width }) + } else { + TypeResolution::Handle(base) + } + } + Ti::ValuePointer { + size, + kind, + width, + space: _, + } => TypeResolution::Value(match size { + Some(size) => Ti::Vector { size, kind, width }, + None => Ti::Scalar { kind, width }, + }), + ref other => { + log::error!("Pointer type {:?}", other); + return Err(ResolveError::InvalidPointer(pointer)); + } + }, + crate::Expression::ImageSample { + image, + gather: Some(_), + .. + } => match *past(image)?.inner_with(types) { + Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector { + kind: match class { + crate::ImageClass::Sampled { kind, multi: _ } => kind, + _ => crate::ScalarKind::Float, + }, + width: 4, + size: crate::VectorSize::Quad, + }), + ref other => { + log::error!("Image type {:?}", other); + return Err(ResolveError::InvalidImage(image)); + } + }, + crate::Expression::ImageSample { image, .. } + | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) { + Ti::Image { class, .. } => TypeResolution::Value(match class { + crate::ImageClass::Depth { multi: _ } => Ti::Scalar { + kind: crate::ScalarKind::Float, + width: 4, + }, + crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector { + kind, + width: 4, + size: crate::VectorSize::Quad, + }, + crate::ImageClass::Storage { format, .. } => Ti::Vector { + kind: format.into(), + width: 4, + size: crate::VectorSize::Quad, + }, + }), + ref other => { + log::error!("Image type {:?}", other); + return Err(ResolveError::InvalidImage(image)); + } + }, + crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query { + crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) { + Ti::Image { dim, .. } => match dim { + crate::ImageDimension::D1 => Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector { + size: crate::VectorSize::Bi, + kind: crate::ScalarKind::Sint, + width: 4, + }, + crate::ImageDimension::D3 => Ti::Vector { + size: crate::VectorSize::Tri, + kind: crate::ScalarKind::Sint, + width: 4, + }, + }, + ref other => { + log::error!("Image type {:?}", other); + return Err(ResolveError::InvalidImage(image)); + } + }, + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => Ti::Scalar { + kind: crate::ScalarKind::Sint, + width: 4, + }, + }), + crate::Expression::Unary { expr, .. } => past(expr)?.clone(), + crate::Expression::Binary { op, left, right } => match op { + crate::BinaryOperator::Add + | crate::BinaryOperator::Subtract + | crate::BinaryOperator::Divide + | crate::BinaryOperator::Modulo => past(left)?.clone(), + crate::BinaryOperator::Multiply => { + let (res_left, res_right) = (past(left)?, past(right)?); + match (res_left.inner_with(types), res_right.inner_with(types)) { + ( + &Ti::Matrix { + columns: _, + rows, + width, + }, + &Ti::Matrix { columns, .. }, + ) => TypeResolution::Value(Ti::Matrix { + columns, + rows, + width, + }), + ( + &Ti::Matrix { + columns: _, + rows, + width, + }, + &Ti::Vector { .. }, + ) => TypeResolution::Value(Ti::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }), + ( + &Ti::Vector { .. }, + &Ti::Matrix { + columns, + rows: _, + width, + }, + ) => TypeResolution::Value(Ti::Vector { + size: columns, + kind: crate::ScalarKind::Float, + width, + }), + (&Ti::Scalar { .. }, _) => res_right.clone(), + (_, &Ti::Scalar { .. }) => res_left.clone(), + (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(), + (tl, tr) => { + return Err(ResolveError::IncompatibleOperands(format!( + "{:?} * {:?}", + tl, tr + ))) + } + } + } + crate::BinaryOperator::Equal + | crate::BinaryOperator::NotEqual + | crate::BinaryOperator::Less + | crate::BinaryOperator::LessEqual + | crate::BinaryOperator::Greater + | crate::BinaryOperator::GreaterEqual + | crate::BinaryOperator::LogicalAnd + | crate::BinaryOperator::LogicalOr => { + let kind = crate::ScalarKind::Bool; + let width = crate::BOOL_WIDTH; + let inner = match *past(left)?.inner_with(types) { + Ti::Scalar { .. } => Ti::Scalar { kind, width }, + Ti::Vector { size, .. } => Ti::Vector { size, kind, width }, + ref other => { + return Err(ResolveError::IncompatibleOperands(format!( + "{:?}({:?}, _)", + op, other + ))) + } + }; + TypeResolution::Value(inner) + } + crate::BinaryOperator::And + | crate::BinaryOperator::ExclusiveOr + | crate::BinaryOperator::InclusiveOr + | crate::BinaryOperator::ShiftLeft + | crate::BinaryOperator::ShiftRight => past(left)?.clone(), + }, + crate::Expression::AtomicResult { + kind, + width, + comparison, + } => { + if comparison { + TypeResolution::Value(Ti::Vector { + size: crate::VectorSize::Bi, + kind, + width, + }) + } else { + TypeResolution::Value(Ti::Scalar { kind, width }) + } + } + crate::Expression::Select { accept, .. } => past(accept)?.clone(), + crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(), + crate::Expression::Relational { fun, argument } => match fun { + crate::RelationalFunction::All | crate::RelationalFunction::Any => { + TypeResolution::Value(Ti::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }) + } + crate::RelationalFunction::IsNan + | crate::RelationalFunction::IsInf + | crate::RelationalFunction::IsFinite + | crate::RelationalFunction::IsNormal => match *past(argument)?.inner_with(types) { + Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }), + Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + size, + }), + ref other => { + return Err(ResolveError::IncompatibleOperands(format!( + "{:?}({:?})", + fun, other + ))) + } + }, + }, + crate::Expression::Math { + fun, + arg, + arg1, + arg2: _, + arg3: _, + } => { + use crate::MathFunction as Mf; + let res_arg = past(arg)?; + match fun { + // comparison + Mf::Abs | + Mf::Min | + Mf::Max | + Mf::Clamp | + Mf::Saturate | + // trigonometry + Mf::Cos | + Mf::Cosh | + Mf::Sin | + Mf::Sinh | + Mf::Tan | + Mf::Tanh | + Mf::Acos | + Mf::Asin | + Mf::Atan | + Mf::Atan2 | + Mf::Asinh | + Mf::Acosh | + Mf::Atanh | + Mf::Radians | + Mf::Degrees | + // decomposition + Mf::Ceil | + Mf::Floor | + Mf::Round | + Mf::Fract | + Mf::Trunc | + Mf::Modf | + Mf::Frexp | + Mf::Ldexp | + // exponent + Mf::Exp | + Mf::Exp2 | + Mf::Log | + Mf::Log2 | + Mf::Pow => res_arg.clone(), + // geometry + Mf::Dot => match *res_arg.inner_with(types) { + Ti::Vector { + kind, + size: _, + width, + } => TypeResolution::Value(Ti::Scalar { kind, width }), + ref other => + return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?}, _)", fun, other) + )), + }, + Mf::Outer => { + let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands( + format!("{:?}(_, None)", fun) + ))?; + match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) { + (&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => TypeResolution::Value(Ti::Matrix { columns, rows, width }), + (left, right) => + return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?}, {:?})", fun, left, right) + )), + } + }, + Mf::Cross => res_arg.clone(), + Mf::Distance | + Mf::Length => match *res_arg.inner_with(types) { + Ti::Scalar {width,kind} | + Ti::Vector {width,kind,size:_} => TypeResolution::Value(Ti::Scalar { kind, width }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?})", fun, other) + )), + }, + Mf::Normalize | + Mf::FaceForward | + Mf::Reflect | + Mf::Refract => res_arg.clone(), + // computational + Mf::Sign | + Mf::Fma | + Mf::Mix | + Mf::Step | + Mf::SmoothStep | + Mf::Sqrt | + Mf::InverseSqrt => res_arg.clone(), + Mf::Transpose => match *res_arg.inner_with(types) { + Ti::Matrix { + columns, + rows, + width, + } => TypeResolution::Value(Ti::Matrix { + columns: rows, + rows: columns, + width, + }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?})", fun, other) + )), + }, + Mf::Inverse => match *res_arg.inner_with(types) { + Ti::Matrix { + columns, + rows, + width, + } if columns == rows => TypeResolution::Value(Ti::Matrix { + columns, + rows, + width, + }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?})", fun, other) + )), + }, + Mf::Determinant => match *res_arg.inner_with(types) { + Ti::Matrix { + width, + .. + } => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Float, width }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?})", fun, other) + )), + }, + // bits + Mf::CountOneBits | + Mf::ReverseBits | + Mf::ExtractBits | + Mf::InsertBits | + Mf::FindLsb | + Mf::FindMsb => match *res_arg.inner_with(types) { + Ti::Scalar { kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } => + TypeResolution::Value(Ti::Scalar { kind, width }), + Ti::Vector { size, kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } => + TypeResolution::Value(Ti::Vector { size, kind, width }), + ref other => return Err(ResolveError::IncompatibleOperands( + format!("{:?}({:?})", fun, other) + )), + }, + // data packing + Mf::Pack4x8snorm | + Mf::Pack4x8unorm | + Mf::Pack2x16snorm | + Mf::Pack2x16unorm | + Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Uint, width: 4 }), + // data unpacking + Mf::Unpack4x8snorm | + Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Quad, kind: crate::ScalarKind::Float, width: 4 }), + Mf::Unpack2x16snorm | + Mf::Unpack2x16unorm | + Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Bi, kind: crate::ScalarKind::Float, width: 4 }), + } + } + crate::Expression::As { + expr, + kind, + convert, + } => match *past(expr)?.inner_with(types) { + Ti::Scalar { kind: _, width } => TypeResolution::Value(Ti::Scalar { + kind, + width: convert.unwrap_or(width), + }), + Ti::Vector { + kind: _, + size, + width, + } => TypeResolution::Value(Ti::Vector { + kind, + size, + width: convert.unwrap_or(width), + }), + Ti::Matrix { + columns, + rows, + width, + } => TypeResolution::Value(Ti::Matrix { + columns, + rows, + width: convert.unwrap_or(width), + }), + ref other => { + return Err(ResolveError::IncompatibleOperands(format!( + "{:?} as {:?}", + other, kind + ))) + } + }, + crate::Expression::CallResult(function) => { + let result = self.functions[function] + .result + .as_ref() + .ok_or(ResolveError::FunctionReturnsVoid)?; + TypeResolution::Handle(result.ty) + } + crate::Expression::ArrayLength(_) => TypeResolution::Value(Ti::Scalar { + kind: crate::ScalarKind::Uint, + width: 4, + }), + }) + } +} + +#[test] +fn test_error_size() { + use std::mem::size_of; + assert_eq!(size_of::<ResolveError>(), 32); +} |