summaryrefslogtreecommitdiffstats
path: root/third_party/rust/naga/src/proc
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 19:33:14 +0000
commit36d22d82aa202bb199967e9512281e9a53db42c9 (patch)
tree105e8c98ddea1c1e4784a60a5a6410fa416be2de /third_party/rust/naga/src/proc
parentInitial commit. (diff)
downloadfirefox-esr-upstream.tar.xz
firefox-esr-upstream.zip
Adding upstream version 115.7.0esr.upstream/115.7.0esrupstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/rust/naga/src/proc')
-rw-r--r--third_party/rust/naga/src/proc/index.rs437
-rw-r--r--third_party/rust/naga/src/proc/layouter.rs256
-rw-r--r--third_party/rust/naga/src/proc/mod.rs493
-rw-r--r--third_party/rust/naga/src/proc/namer.rs261
-rw-r--r--third_party/rust/naga/src/proc/terminator.rs43
-rw-r--r--third_party/rust/naga/src/proc/typifier.rs909
6 files changed, 2399 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/proc/index.rs b/third_party/rust/naga/src/proc/index.rs
new file mode 100644
index 0000000000..3fea79ec01
--- /dev/null
+++ b/third_party/rust/naga/src/proc/index.rs
@@ -0,0 +1,437 @@
+/*!
+Definitions for index bounds checking.
+*/
+
+use crate::{valid, Handle, UniqueArena};
+use bit_set::BitSet;
+
+/// How should code generated by Naga do bounds checks?
+///
+/// When a vector, matrix, or array index is out of bounds—either negative, or
+/// greater than or equal to the number of elements in the type—WGSL requires
+/// that some other index of the implementation's choice that is in bounds is
+/// used instead. (There are no types with zero elements.)
+///
+/// Similarly, when out-of-bounds coordinates, array indices, or sample indices
+/// are presented to the WGSL `textureLoad` and `textureStore` operations, the
+/// operation is redirected to do something safe.
+///
+/// Different users of Naga will prefer different defaults:
+///
+/// - When used as part of a WebGPU implementation, the WGSL specification
+/// requires the `Restrict` behavior for array, vector, and matrix accesses,
+/// and either the `Restrict` or `ReadZeroSkipWrite` behaviors for texture
+/// accesses.
+///
+/// - When used by the `wgpu` crate for native development, `wgpu` selects
+/// `ReadZeroSkipWrite` as its default.
+///
+/// - Naga's own default is `Unchecked`, so that shader translations
+/// are as faithful to the original as possible.
+///
+/// Sometimes the underlying hardware and drivers can perform bounds checks
+/// themselves, in a way that performs better than the checks Naga would inject.
+/// If you're using native checks like this, then having Naga inject its own
+/// checks as well would be redundant, and the `Unchecked` policy is
+/// appropriate.
+#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum BoundsCheckPolicy {
+ /// Replace out-of-bounds indexes with some arbitrary in-bounds index.
+ ///
+ /// (This does not necessarily mean clamping. For example, interpreting the
+ /// index as unsigned and taking the minimum with the largest valid index
+ /// would also be a valid implementation. That would map negative indices to
+ /// the last element, not the first.)
+ Restrict,
+
+ /// Out-of-bounds reads return zero, and writes have no effect.
+ ///
+ /// When applied to a chain of accesses, like `a[i][j].b[k]`, all index
+ /// expressions are evaluated, regardless of whether prior or later index
+ /// expressions were in bounds. But all the accesses per se are skipped
+ /// if any index is out of bounds.
+ ReadZeroSkipWrite,
+
+ /// Naga adds no checks to indexing operations. Generate the fastest code
+ /// possible. This is the default for Naga, as a translator, but consumers
+ /// should consider defaulting to a safer behavior.
+ Unchecked,
+}
+
+/// Policies for injecting bounds checks during code generation.
+#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct BoundsCheckPolicies {
+ /// How should the generated code handle array, vector, or matrix indices
+ /// that are out of range?
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub index: BoundsCheckPolicy,
+
+ /// How should the generated code handle array, vector, or matrix indices
+ /// that are out of range, when those values live in a [`GlobalVariable`] in
+ /// the [`Storage`] or [`Uniform`] address spaces?
+ ///
+ /// Some graphics hardware provides "robust buffer access", a feature that
+ /// ensures that using a pointer cannot access memory outside the 'buffer'
+ /// that it was derived from. In Naga terms, this means that the hardware
+ /// ensures that pointers computed by applying [`Access`] and
+ /// [`AccessIndex`] expressions to a [`GlobalVariable`] whose [`space`] is
+ /// [`Storage`] or [`Uniform`] will never read or write memory outside that
+ /// global variable.
+ ///
+ /// When hardware offers such a feature, it is probably undesirable to have
+ /// Naga inject bounds checking code for such accesses, since the hardware
+ /// can probably provide the same protection more efficiently. However,
+ /// bounds checks are still needed on accesses to indexable values that do
+ /// not live in buffers, like local variables.
+ ///
+ /// So, this option provides a separate policy that applies only to accesses
+ /// to storage and uniform globals. When depending on hardware bounds
+ /// checking, this policy can be `Unchecked` to avoid unnecessary overhead.
+ ///
+ /// When special hardware support is not available, this should probably be
+ /// the same as `index_bounds_check_policy`.
+ ///
+ /// [`GlobalVariable`]: crate::GlobalVariable
+ /// [`space`]: crate::GlobalVariable::space
+ /// [`Restrict`]: crate::back::BoundsCheckPolicy::Restrict
+ /// [`ReadZeroSkipWrite`]: crate::back::BoundsCheckPolicy::ReadZeroSkipWrite
+ /// [`Access`]: crate::Expression::Access
+ /// [`AccessIndex`]: crate::Expression::AccessIndex
+ /// [`Storage`]: crate::AddressSpace::Storage
+ /// [`Uniform`]: crate::AddressSpace::Uniform
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub buffer: BoundsCheckPolicy,
+
+ /// How should the generated code handle image texel references that are out
+ /// of range?
+ ///
+ /// This controls the behavior of [`ImageLoad`] expressions and
+ /// [`ImageStore`] statements when a coordinate, texture array index, level
+ /// of detail, or multisampled sample number is out of range.
+ ///
+ /// [`ImageLoad`]: crate::Expression::ImageLoad
+ /// [`ImageStore`]: crate::Statement::ImageStore
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub image: BoundsCheckPolicy,
+
+ /// How should the generated code handle binding array indexes that are out of bounds.
+ #[cfg_attr(feature = "deserialize", serde(default))]
+ pub binding_array: BoundsCheckPolicy,
+}
+
+/// The default `BoundsCheckPolicy` is `Unchecked`.
+impl Default for BoundsCheckPolicy {
+ fn default() -> Self {
+ BoundsCheckPolicy::Unchecked
+ }
+}
+
+impl BoundsCheckPolicies {
+ /// Determine which policy applies to `base`.
+ ///
+ /// `base` is the "base" expression (the expression being indexed) of a `Access`
+ /// and `AccessIndex` expression. This is either a pointer, a value, being directly
+ /// indexed, or a binding array.
+ ///
+ /// See the documentation for [`BoundsCheckPolicy`] for details about
+ /// when each policy applies.
+ pub fn choose_policy(
+ &self,
+ base: Handle<crate::Expression>,
+ types: &UniqueArena<crate::Type>,
+ info: &valid::FunctionInfo,
+ ) -> BoundsCheckPolicy {
+ let ty = info[base].ty.inner_with(types);
+
+ if let crate::TypeInner::BindingArray { .. } = *ty {
+ return self.binding_array;
+ }
+
+ match ty.pointer_space() {
+ Some(crate::AddressSpace::Storage { access: _ } | crate::AddressSpace::Uniform) => {
+ self.buffer
+ }
+ // This covers other address spaces, but also accessing vectors and
+ // matrices by value, where no pointer is involved.
+ _ => self.index,
+ }
+ }
+
+ /// Return `true` if any of `self`'s policies are `policy`.
+ pub fn contains(&self, policy: BoundsCheckPolicy) -> bool {
+ self.index == policy || self.buffer == policy || self.image == policy
+ }
+}
+
+/// An index that may be statically known, or may need to be computed at runtime.
+///
+/// This enum lets us handle both [`Access`] and [`AccessIndex`] expressions
+/// with the same code.
+///
+/// [`Access`]: crate::Expression::Access
+/// [`AccessIndex`]: crate::Expression::AccessIndex
+#[derive(Clone, Copy, Debug)]
+pub enum GuardedIndex {
+ Known(u32),
+ Expression(Handle<crate::Expression>),
+}
+
+/// Build a set of expressions used as indices, to cache in temporary variables when
+/// emitted.
+///
+/// Given the bounds-check policies `policies`, construct a `BitSet` containing the handle
+/// indices of all the expressions in `function` that are ever used as guarded indices
+/// under the [`ReadZeroSkipWrite`] policy. The `module` argument must be the module to
+/// which `function` belongs, and `info` should be that function's analysis results.
+///
+/// Such index expressions will be used twice in the generated code: first for the
+/// comparison to see if the index is in bounds, and then for the access itself, should
+/// the comparison succeed. To avoid computing the expressions twice, the generated code
+/// should cache them in temporary variables.
+///
+/// Why do we need to build such a set in advance, instead of just processing access
+/// expressions as we encounter them? Whether an expression needs to be cached depends on
+/// whether it appears as something like the [`index`] operand of an [`Access`] expression
+/// or the [`level`] operand of an [`ImageLoad`] expression, and on the index bounds check
+/// policies that apply to those accesses. But [`Emit`] statements just identify a range
+/// of expressions by index; there's no good way to tell what an expression is used
+/// for. The only way to do it is to just iterate over all the expressions looking for
+/// relevant `Access` expressions --- which is what this function does.
+///
+/// Simple expressions like variable loads and constants don't make sense to cache: it's
+/// no better than just re-evaluating them. But constants are not covered by `Emit`
+/// statements, and `Load`s are always cached to ensure they occur at the right time, so
+/// we don't bother filtering them out from this set.
+///
+/// Fortunately, we don't need to deal with [`ImageStore`] statements here. When we emit
+/// code for a statement, the writer isn't in the middle of an expression, so we can just
+/// emit declarations for temporaries, initialized appropriately.
+///
+/// None of these concerns apply for SPIR-V output, since it's easy to just reuse an
+/// instruction ID in two places; that has the same semantics as a temporary variable, and
+/// it's inherent in the design of SPIR-V. This function is more useful for text-based
+/// back ends.
+///
+/// [`ReadZeroSkipWrite`]: BoundsCheckPolicy::ReadZeroSkipWrite
+/// [`index`]: crate::Expression::Access::index
+/// [`Access`]: crate::Expression::Access
+/// [`level`]: crate::Expression::ImageLoad::level
+/// [`ImageLoad`]: crate::Expression::ImageLoad
+/// [`Emit`]: crate::Statement::Emit
+/// [`ImageStore`]: crate::Statement::ImageStore
+pub fn find_checked_indexes(
+ module: &crate::Module,
+ function: &crate::Function,
+ info: &crate::valid::FunctionInfo,
+ policies: BoundsCheckPolicies,
+) -> BitSet {
+ use crate::Expression as Ex;
+
+ let mut guarded_indices = BitSet::new();
+
+ // Don't bother scanning if we never need `ReadZeroSkipWrite`.
+ if policies.contains(BoundsCheckPolicy::ReadZeroSkipWrite) {
+ for (_handle, expr) in function.expressions.iter() {
+ // There's no need to handle `AccessIndex` expressions, as their
+ // indices never need to be cached.
+ match *expr {
+ Ex::Access { base, index } => {
+ if policies.choose_policy(base, &module.types, info)
+ == BoundsCheckPolicy::ReadZeroSkipWrite
+ && access_needs_check(
+ base,
+ GuardedIndex::Expression(index),
+ module,
+ function,
+ info,
+ )
+ .is_some()
+ {
+ guarded_indices.insert(index.index());
+ }
+ }
+ Ex::ImageLoad {
+ coordinate,
+ array_index,
+ sample,
+ level,
+ ..
+ } => {
+ if policies.image == BoundsCheckPolicy::ReadZeroSkipWrite {
+ guarded_indices.insert(coordinate.index());
+ if let Some(array_index) = array_index {
+ guarded_indices.insert(array_index.index());
+ }
+ if let Some(sample) = sample {
+ guarded_indices.insert(sample.index());
+ }
+ if let Some(level) = level {
+ guarded_indices.insert(level.index());
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ }
+
+ guarded_indices
+}
+
+/// Determine whether `index` is statically known to be in bounds for `base`.
+///
+/// If we can't be sure that the index is in bounds, return the limit within
+/// which valid indices must fall.
+///
+/// The return value is one of the following:
+///
+/// - `Some(Known(n))` indicates that `n` is the largest valid index.
+///
+/// - `Some(Computed(global))` indicates that the largest valid index is one
+/// less than the length of the array that is the last member of the
+/// struct held in `global`.
+///
+/// - `None` indicates that the index need not be checked, either because it
+/// is statically known to be in bounds, or because the applicable policy
+/// is `Unchecked`.
+///
+/// This function only handles subscriptable types: arrays, vectors, and
+/// matrices. It does not handle struct member indices; those never require
+/// run-time checks, so it's best to deal with them further up the call
+/// chain.
+pub fn access_needs_check(
+ base: Handle<crate::Expression>,
+ mut index: GuardedIndex,
+ module: &crate::Module,
+ function: &crate::Function,
+ info: &crate::valid::FunctionInfo,
+) -> Option<IndexableLength> {
+ let base_inner = info[base].ty.inner_with(&module.types);
+ // Unwrap safety: `Err` here indicates unindexable base types and invalid
+ // length constants, but `access_needs_check` is only used by back ends, so
+ // validation should have caught those problems.
+ let length = base_inner.indexable_length(module).unwrap();
+ index.try_resolve_to_constant(function, module);
+ if let (&GuardedIndex::Known(index), &IndexableLength::Known(length)) = (&index, &length) {
+ if index < length {
+ // Index is statically known to be in bounds, no check needed.
+ return None;
+ }
+ };
+
+ Some(length)
+}
+
+impl GuardedIndex {
+ /// Make A `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible.
+ ///
+ /// If the expression is a [`Constant`] whose value is a non-specialized, scalar
+ /// integer constant that can be converted to a `u32`, do so and return a
+ /// `GuardedIndex::Known`. Otherwise, return the `GuardedIndex::Expression`
+ /// unchanged.
+ ///
+ /// Return values that are already `Known` unchanged.
+ ///
+ /// [`Constant`]: crate::Expression::Constant
+ fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) {
+ if let GuardedIndex::Expression(expr) = *self {
+ if let crate::Expression::Constant(handle) = function.expressions[expr] {
+ if let Some(value) = module.constants[handle].to_array_length() {
+ *self = GuardedIndex::Known(value);
+ }
+ }
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)]
+pub enum IndexableLengthError {
+ #[error("Type is not indexable, and has no length (validation error)")]
+ TypeNotIndexable,
+ #[error("Array length constant {0:?} is invalid")]
+ InvalidArrayLength(Handle<crate::Constant>),
+}
+
+impl crate::TypeInner {
+ /// Return the length of a subscriptable type.
+ ///
+ /// The `self` parameter should be a handle to a vector, matrix, or array
+ /// type, a pointer to one of those, or a value pointer. Arrays may be
+ /// fixed-size, dynamically sized, or sized by a specializable constant.
+ /// This function does not handle struct member references, as with
+ /// `AccessIndex`.
+ ///
+ /// The value returned is appropriate for bounds checks on subscripting.
+ ///
+ /// Return an error if `self` does not describe a subscriptable type at all.
+ pub fn indexable_length(
+ &self,
+ module: &crate::Module,
+ ) -> Result<IndexableLength, IndexableLengthError> {
+ use crate::TypeInner as Ti;
+ let known_length = match *self {
+ Ti::Vector { size, .. } => size as _,
+ Ti::Matrix { columns, .. } => columns as _,
+ Ti::Array { size, .. } | Ti::BindingArray { size, .. } => {
+ return size.to_indexable_length(module);
+ }
+ Ti::ValuePointer {
+ size: Some(size), ..
+ } => size as _,
+ Ti::Pointer { base, .. } => {
+ // When assigning types to expressions, ResolveContext::Resolve
+ // does a separate sub-match here instead of a full recursion,
+ // so we'll do the same.
+ let base_inner = &module.types[base].inner;
+ match *base_inner {
+ Ti::Vector { size, .. } => size as _,
+ Ti::Matrix { columns, .. } => columns as _,
+ Ti::Array { size, .. } => return size.to_indexable_length(module),
+ _ => return Err(IndexableLengthError::TypeNotIndexable),
+ }
+ }
+ _ => return Err(IndexableLengthError::TypeNotIndexable),
+ };
+ Ok(IndexableLength::Known(known_length))
+ }
+}
+
+/// The number of elements in an indexable type.
+///
+/// This summarizes the length of vectors, matrices, and arrays in a way that is
+/// convenient for indexing and bounds-checking code.
+#[derive(Debug)]
+pub enum IndexableLength {
+ /// Values of this type always have the given number of elements.
+ Known(u32),
+
+ /// The number of elements is determined at runtime.
+ Dynamic,
+}
+
+impl crate::ArraySize {
+ pub fn to_indexable_length(
+ self,
+ module: &crate::Module,
+ ) -> Result<IndexableLength, IndexableLengthError> {
+ Ok(match self {
+ Self::Constant(k) => {
+ let constant = &module.constants[k];
+ if constant.specialization.is_some() {
+ // Specializable constants are not supported as array lengths.
+ // See valid::TypeError::UnsupportedSpecializedArrayLength.
+ return Err(IndexableLengthError::InvalidArrayLength(k));
+ }
+ let length = constant
+ .to_array_length()
+ .ok_or(IndexableLengthError::InvalidArrayLength(k))?;
+ IndexableLength::Known(length)
+ }
+ Self::Dynamic => IndexableLength::Dynamic,
+ })
+ }
+}
diff --git a/third_party/rust/naga/src/proc/layouter.rs b/third_party/rust/naga/src/proc/layouter.rs
new file mode 100644
index 0000000000..65369d1cc8
--- /dev/null
+++ b/third_party/rust/naga/src/proc/layouter.rs
@@ -0,0 +1,256 @@
+use crate::arena::{Arena, Handle, UniqueArena};
+use std::{fmt::Display, num::NonZeroU32, ops};
+
+/// A newtype struct where its only valid values are powers of 2
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Alignment(NonZeroU32);
+
+impl Alignment {
+ pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
+ pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
+ pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
+ pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
+ pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });
+
+ pub const MIN_UNIFORM: Self = Self::SIXTEEN;
+
+ pub const fn new(n: u32) -> Option<Self> {
+ if n.is_power_of_two() {
+ // SAFETY: value can't be 0 since we just checked if it's a power of 2
+ Some(Self(unsafe { NonZeroU32::new_unchecked(n) }))
+ } else {
+ None
+ }
+ }
+
+ /// # Panics
+ /// If `width` is not a power of 2
+ pub fn from_width(width: u8) -> Self {
+ Self::new(width as u32).unwrap()
+ }
+
+ /// Returns whether or not `n` is a multiple of this alignment.
+ pub const fn is_aligned(&self, n: u32) -> bool {
+ // equivalent to: `n % self.0.get() == 0` but much faster
+ n & (self.0.get() - 1) == 0
+ }
+
+ /// Round `n` up to the nearest alignment boundary.
+ pub const fn round_up(&self, n: u32) -> u32 {
+ // equivalent to:
+ // match n % self.0.get() {
+ // 0 => n,
+ // rem => n + (self.0.get() - rem),
+ // }
+ let mask = self.0.get() - 1;
+ (n + mask) & !mask
+ }
+}
+
+impl Display for Alignment {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ self.0.get().fmt(f)
+ }
+}
+
+impl ops::Mul<u32> for Alignment {
+ type Output = u32;
+
+ fn mul(self, rhs: u32) -> Self::Output {
+ self.0.get() * rhs
+ }
+}
+
+impl ops::Mul for Alignment {
+ type Output = Alignment;
+
+ fn mul(self, rhs: Alignment) -> Self::Output {
+ // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2
+ Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
+ }
+}
+
+impl From<crate::VectorSize> for Alignment {
+ fn from(size: crate::VectorSize) -> Self {
+ match size {
+ crate::VectorSize::Bi => Alignment::TWO,
+ crate::VectorSize::Tri => Alignment::FOUR,
+ crate::VectorSize::Quad => Alignment::FOUR,
+ }
+ }
+}
+
+/// Size and alignment information for a type.
+#[derive(Clone, Copy, Debug, Hash, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct TypeLayout {
+ pub size: u32,
+ pub alignment: Alignment,
+}
+
+impl TypeLayout {
+ /// Produce the stride as if this type is a base of an array.
+ pub const fn to_stride(&self) -> u32 {
+ self.alignment.round_up(self.size)
+ }
+}
+
+/// Helper processor that derives the sizes of all types.
+///
+/// `Layouter` uses the default layout algorithm/table, described in
+/// [WGSL §4.3.7, "Memory Layout"]
+///
+/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the
+/// layout of the type whose handle is `handle`.
+///
+/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts)
+#[derive(Debug, Default)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub struct Layouter {
+ /// Layouts for types in an arena, indexed by `Handle` index.
+ layouts: Vec<TypeLayout>,
+}
+
+impl ops::Index<Handle<crate::Type>> for Layouter {
+ type Output = TypeLayout;
+ fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
+ &self.layouts[handle.index()]
+ }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
+pub enum LayoutErrorInner {
+ #[error("Array element type {0:?} doesn't exist")]
+ InvalidArrayElementType(Handle<crate::Type>),
+ #[error("Struct member[{0}] type {1:?} doesn't exist")]
+ InvalidStructMemberType(u32, Handle<crate::Type>),
+ #[error("Type width must be a power of two")]
+ NonPowerOfTwoWidth,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
+#[error("Error laying out type {ty:?}: {inner}")]
+pub struct LayoutError {
+ pub ty: Handle<crate::Type>,
+ pub inner: LayoutErrorInner,
+}
+
+impl LayoutErrorInner {
+ const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
+ LayoutError { ty, inner: self }
+ }
+}
+
+impl Layouter {
+ /// Remove all entries from this `Layouter`, retaining storage.
+ pub fn clear(&mut self) {
+ self.layouts.clear();
+ }
+
+ /// Extend this `Layouter` with layouts for any new entries in `types`.
+ ///
+ /// Ensure that every type in `types` has a corresponding [TypeLayout] in
+ /// [`self.layouts`].
+ ///
+ /// Some front ends need to be able to compute layouts for existing types
+ /// while module construction is still in progress and new types are still
+ /// being added. This function assumes that the `TypeLayout` values already
+ /// present in `self.layouts` cover their corresponding entries in `types`,
+ /// and extends `self.layouts` as needed to cover the rest. Thus, a front
+ /// end can call this function at any time, passing its current type and
+ /// constant arenas, and then assume that layouts are available for all
+ /// types.
+ #[allow(clippy::or_fun_call)]
+ pub fn update(
+ &mut self,
+ types: &UniqueArena<crate::Type>,
+ constants: &Arena<crate::Constant>,
+ ) -> Result<(), LayoutError> {
+ use crate::TypeInner as Ti;
+
+ for (ty_handle, ty) in types.iter().skip(self.layouts.len()) {
+ let size = ty.inner.size(constants);
+ let layout = match ty.inner {
+ Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => {
+ let alignment = Alignment::new(width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout { size, alignment }
+ }
+ Ti::Vector {
+ size: vec_size,
+ width,
+ ..
+ } => {
+ let alignment = Alignment::new(width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout {
+ size,
+ alignment: Alignment::from(vec_size) * alignment,
+ }
+ }
+ Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ } => {
+ let alignment = Alignment::new(width as u32)
+ .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
+ TypeLayout {
+ size,
+ alignment: Alignment::from(rows) * alignment,
+ }
+ }
+ Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
+ size,
+ alignment: Alignment::ONE,
+ },
+ Ti::Array {
+ base,
+ stride: _,
+ size: _,
+ } => TypeLayout {
+ size,
+ alignment: if base < ty_handle {
+ self[base].alignment
+ } else {
+ return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
+ },
+ },
+ Ti::Struct { span, ref members } => {
+ let mut alignment = Alignment::ONE;
+ for (index, member) in members.iter().enumerate() {
+ alignment = if member.ty < ty_handle {
+ alignment.max(self[member.ty].alignment)
+ } else {
+ return Err(LayoutErrorInner::InvalidStructMemberType(
+ index as u32,
+ member.ty,
+ )
+ .with(ty_handle));
+ };
+ }
+ TypeLayout {
+ size: span,
+ alignment,
+ }
+ }
+ Ti::Image { .. }
+ | Ti::Sampler { .. }
+ | Ti::AccelerationStructure
+ | Ti::RayQuery
+ | Ti::BindingArray { .. } => TypeLayout {
+ size,
+ alignment: Alignment::ONE,
+ },
+ };
+ debug_assert!(size <= layout.size);
+ self.layouts.push(layout);
+ }
+
+ Ok(())
+ }
+}
diff --git a/third_party/rust/naga/src/proc/mod.rs b/third_party/rust/naga/src/proc/mod.rs
new file mode 100644
index 0000000000..a775272a19
--- /dev/null
+++ b/third_party/rust/naga/src/proc/mod.rs
@@ -0,0 +1,493 @@
+/*!
+[`Module`](super::Module) processing functionality.
+*/
+
+pub mod index;
+mod layouter;
+mod namer;
+mod terminator;
+mod typifier;
+
+use std::cmp::PartialEq;
+
+pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
+pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
+pub use namer::{EntryPointIndex, NameKey, Namer};
+pub use terminator::ensure_block_returns;
+pub use typifier::{ResolveContext, ResolveError, TypeResolution};
+
+impl From<super::StorageFormat> for super::ScalarKind {
+ fn from(format: super::StorageFormat) -> Self {
+ use super::{ScalarKind as Sk, StorageFormat as Sf};
+ match format {
+ Sf::R8Unorm => Sk::Float,
+ Sf::R8Snorm => Sk::Float,
+ Sf::R8Uint => Sk::Uint,
+ Sf::R8Sint => Sk::Sint,
+ Sf::R16Uint => Sk::Uint,
+ Sf::R16Sint => Sk::Sint,
+ Sf::R16Float => Sk::Float,
+ Sf::Rg8Unorm => Sk::Float,
+ Sf::Rg8Snorm => Sk::Float,
+ Sf::Rg8Uint => Sk::Uint,
+ Sf::Rg8Sint => Sk::Sint,
+ Sf::R32Uint => Sk::Uint,
+ Sf::R32Sint => Sk::Sint,
+ Sf::R32Float => Sk::Float,
+ Sf::Rg16Uint => Sk::Uint,
+ Sf::Rg16Sint => Sk::Sint,
+ Sf::Rg16Float => Sk::Float,
+ Sf::Rgba8Unorm => Sk::Float,
+ Sf::Rgba8Snorm => Sk::Float,
+ Sf::Rgba8Uint => Sk::Uint,
+ Sf::Rgba8Sint => Sk::Sint,
+ Sf::Rgb10a2Unorm => Sk::Float,
+ Sf::Rg11b10Float => Sk::Float,
+ Sf::Rg32Uint => Sk::Uint,
+ Sf::Rg32Sint => Sk::Sint,
+ Sf::Rg32Float => Sk::Float,
+ Sf::Rgba16Uint => Sk::Uint,
+ Sf::Rgba16Sint => Sk::Sint,
+ Sf::Rgba16Float => Sk::Float,
+ Sf::Rgba32Uint => Sk::Uint,
+ Sf::Rgba32Sint => Sk::Sint,
+ Sf::Rgba32Float => Sk::Float,
+ Sf::R16Unorm => Sk::Float,
+ Sf::R16Snorm => Sk::Float,
+ Sf::Rg16Unorm => Sk::Float,
+ Sf::Rg16Snorm => Sk::Float,
+ Sf::Rgba16Unorm => Sk::Float,
+ Sf::Rgba16Snorm => Sk::Float,
+ }
+ }
+}
+
+impl super::ScalarValue {
+ pub const fn scalar_kind(&self) -> super::ScalarKind {
+ match *self {
+ Self::Uint(_) => super::ScalarKind::Uint,
+ Self::Sint(_) => super::ScalarKind::Sint,
+ Self::Float(_) => super::ScalarKind::Float,
+ Self::Bool(_) => super::ScalarKind::Bool,
+ }
+ }
+}
+
+impl super::ScalarKind {
+ pub const fn is_numeric(self) -> bool {
+ match self {
+ crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float => true,
+ crate::ScalarKind::Bool => false,
+ }
+ }
+}
+
+pub const POINTER_SPAN: u32 = 4;
+
+impl super::TypeInner {
+ pub const fn scalar_kind(&self) -> Option<super::ScalarKind> {
+ match *self {
+ super::TypeInner::Scalar { kind, .. } | super::TypeInner::Vector { kind, .. } => {
+ Some(kind)
+ }
+ super::TypeInner::Matrix { .. } => Some(super::ScalarKind::Float),
+ _ => None,
+ }
+ }
+
+ pub const fn pointer_space(&self) -> Option<crate::AddressSpace> {
+ match *self {
+ Self::Pointer { space, .. } => Some(space),
+ Self::ValuePointer { space, .. } => Some(space),
+ _ => None,
+ }
+ }
+
+ /// Get the size of this type.
+ pub fn size(&self, constants: &super::Arena<super::Constant>) -> u32 {
+ match *self {
+ Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32,
+ Self::Vector {
+ size,
+ kind: _,
+ width,
+ } => size as u32 * width as u32,
+ // matrices are treated as arrays of aligned columns
+ Self::Matrix {
+ columns,
+ rows,
+ width,
+ } => Alignment::from(rows) * width as u32 * columns as u32,
+ Self::Pointer { .. } | Self::ValuePointer { .. } => POINTER_SPAN,
+ Self::Array {
+ base: _,
+ size,
+ stride,
+ } => {
+ let count = match size {
+ super::ArraySize::Constant(handle) => {
+ constants[handle].to_array_length().unwrap_or(1)
+ }
+ // A dynamically-sized array has to have at least one element
+ super::ArraySize::Dynamic => 1,
+ };
+ count * stride
+ }
+ Self::Struct { span, .. } => span,
+ Self::Image { .. }
+ | Self::Sampler { .. }
+ | Self::AccelerationStructure
+ | Self::RayQuery
+ | Self::BindingArray { .. } => 0,
+ }
+ }
+
+ /// Return the canonical form of `self`, or `None` if it's already in
+ /// canonical form.
+ ///
+ /// Certain types have multiple representations in `TypeInner`. This
+ /// function converts all forms of equivalent types to a single
+ /// representative of their class, so that simply applying `Eq` to the
+ /// result indicates whether the types are equivalent, as far as Naga IR is
+ /// concerned.
+ pub fn canonical_form(
+ &self,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> Option<crate::TypeInner> {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Pointer { base, space } => match types[base].inner {
+ Ti::Scalar { kind, width } => Some(Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ }),
+ Ti::Vector { size, kind, width } => Some(Ti::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space,
+ }),
+ _ => None,
+ },
+ _ => None,
+ }
+ }
+
+ /// Compare `self` and `rhs` as types.
+ ///
+ /// This is mostly the same as `<TypeInner as Eq>::eq`, but it treats
+ /// `ValuePointer` and `Pointer` types as equivalent.
+ ///
+ /// When you know that one side of the comparison is never a pointer, it's
+ /// fine to not bother with canonicalization, and just compare `TypeInner`
+ /// values with `==`.
+ pub fn equivalent(
+ &self,
+ rhs: &crate::TypeInner,
+ types: &crate::UniqueArena<crate::Type>,
+ ) -> bool {
+ let left = self.canonical_form(types);
+ let right = rhs.canonical_form(types);
+ left.as_ref().unwrap_or(self) == right.as_ref().unwrap_or(rhs)
+ }
+
+ pub fn is_dynamically_sized(&self, types: &crate::UniqueArena<crate::Type>) -> bool {
+ use crate::TypeInner as Ti;
+ match *self {
+ Ti::Array { size, .. } => size == crate::ArraySize::Dynamic,
+ Ti::Struct { ref members, .. } => members
+ .last()
+ .map(|last| types[last.ty].inner.is_dynamically_sized(types))
+ .unwrap_or(false),
+ _ => false,
+ }
+ }
+}
+
+impl super::AddressSpace {
+ pub fn access(self) -> crate::StorageAccess {
+ use crate::StorageAccess as Sa;
+ match self {
+ crate::AddressSpace::Function
+ | crate::AddressSpace::Private
+ | crate::AddressSpace::WorkGroup => Sa::LOAD | Sa::STORE,
+ crate::AddressSpace::Uniform => Sa::LOAD,
+ crate::AddressSpace::Storage { access } => access,
+ crate::AddressSpace::Handle => Sa::LOAD,
+ crate::AddressSpace::PushConstant => Sa::LOAD,
+ }
+ }
+}
+
+impl super::MathFunction {
+ pub const fn argument_count(&self) -> usize {
+ match *self {
+ // comparison
+ Self::Abs => 1,
+ Self::Min => 2,
+ Self::Max => 2,
+ Self::Clamp => 3,
+ Self::Saturate => 1,
+ // trigonometry
+ Self::Cos => 1,
+ Self::Cosh => 1,
+ Self::Sin => 1,
+ Self::Sinh => 1,
+ Self::Tan => 1,
+ Self::Tanh => 1,
+ Self::Acos => 1,
+ Self::Asin => 1,
+ Self::Atan => 1,
+ Self::Atan2 => 2,
+ Self::Asinh => 1,
+ Self::Acosh => 1,
+ Self::Atanh => 1,
+ Self::Radians => 1,
+ Self::Degrees => 1,
+ // decomposition
+ Self::Ceil => 1,
+ Self::Floor => 1,
+ Self::Round => 1,
+ Self::Fract => 1,
+ Self::Trunc => 1,
+ Self::Modf => 2,
+ Self::Frexp => 2,
+ Self::Ldexp => 2,
+ // exponent
+ Self::Exp => 1,
+ Self::Exp2 => 1,
+ Self::Log => 1,
+ Self::Log2 => 1,
+ Self::Pow => 2,
+ // geometry
+ Self::Dot => 2,
+ Self::Outer => 2,
+ Self::Cross => 2,
+ Self::Distance => 2,
+ Self::Length => 1,
+ Self::Normalize => 1,
+ Self::FaceForward => 3,
+ Self::Reflect => 2,
+ Self::Refract => 3,
+ // computational
+ Self::Sign => 1,
+ Self::Fma => 3,
+ Self::Mix => 3,
+ Self::Step => 2,
+ Self::SmoothStep => 3,
+ Self::Sqrt => 1,
+ Self::InverseSqrt => 1,
+ Self::Inverse => 1,
+ Self::Transpose => 1,
+ Self::Determinant => 1,
+ // bits
+ Self::CountTrailingZeros => 1,
+ Self::CountLeadingZeros => 1,
+ Self::CountOneBits => 1,
+ Self::ReverseBits => 1,
+ Self::ExtractBits => 3,
+ Self::InsertBits => 4,
+ Self::FindLsb => 1,
+ Self::FindMsb => 1,
+ // data packing
+ Self::Pack4x8snorm => 1,
+ Self::Pack4x8unorm => 1,
+ Self::Pack2x16snorm => 1,
+ Self::Pack2x16unorm => 1,
+ Self::Pack2x16float => 1,
+ // data unpacking
+ Self::Unpack4x8snorm => 1,
+ Self::Unpack4x8unorm => 1,
+ Self::Unpack2x16snorm => 1,
+ Self::Unpack2x16unorm => 1,
+ Self::Unpack2x16float => 1,
+ }
+ }
+}
+
+impl crate::Expression {
+ /// Returns true if the expression is considered emitted at the start of a function.
+ pub const fn needs_pre_emit(&self) -> bool {
+ match *self {
+ Self::Constant(_)
+ | Self::FunctionArgument(_)
+ | Self::GlobalVariable(_)
+ | Self::LocalVariable(_) => true,
+ _ => false,
+ }
+ }
+
+ /// Return true if this expression is a dynamic array index, for [`Access`].
+ ///
+ /// This method returns true if this expression is a dynamically computed
+ /// index, and as such can only be used to index matrices and arrays when
+ /// they appear behind a pointer. See the documentation for [`Access`] for
+ /// details.
+ ///
+ /// Note, this does not check the _type_ of the given expression. It's up to
+ /// the caller to establish that the `Access` expression is well-typed
+ /// through other means, like [`ResolveContext`].
+ ///
+ /// [`Access`]: crate::Expression::Access
+ /// [`ResolveContext`]: crate::proc::ResolveContext
+ pub fn is_dynamic_index(&self, module: &crate::Module) -> bool {
+ if let Self::Constant(handle) = *self {
+ let constant = &module.constants[handle];
+ constant.specialization.is_some()
+ } else {
+ true
+ }
+ }
+}
+
+impl crate::Function {
+ /// Return the global variable being accessed by the expression `pointer`.
+ ///
+ /// Assuming that `pointer` is a series of `Access` and `AccessIndex`
+ /// expressions that ultimately access some part of a `GlobalVariable`,
+ /// return a handle for that global.
+ ///
+ /// If the expression does not ultimately access a global variable, return
+ /// `None`.
+ pub fn originating_global(
+ &self,
+ mut pointer: crate::Handle<crate::Expression>,
+ ) -> Option<crate::Handle<crate::GlobalVariable>> {
+ loop {
+ pointer = match self.expressions[pointer] {
+ crate::Expression::Access { base, .. } => base,
+ crate::Expression::AccessIndex { base, .. } => base,
+ crate::Expression::GlobalVariable(handle) => return Some(handle),
+ crate::Expression::LocalVariable(_) => return None,
+ crate::Expression::FunctionArgument(_) => return None,
+ // There are no other expressions that produce pointer values.
+ _ => unreachable!(),
+ }
+ }
+ }
+}
+
+impl crate::SampleLevel {
+ pub const fn implicit_derivatives(&self) -> bool {
+ match *self {
+ Self::Auto | Self::Bias(_) => true,
+ Self::Zero | Self::Exact(_) | Self::Gradient { .. } => false,
+ }
+ }
+}
+
+impl crate::Constant {
+ /// Interpret this constant as an array length, and return it as a `u32`.
+ ///
+ /// Ignore any specialization available for this constant; return its
+ /// unspecialized value.
+ ///
+ /// If the constant has an inappropriate kind (non-scalar or non-integer) or
+ /// value (negative, out of range for u32), return `None`. This usually
+ /// indicates an error, but only the caller has enough information to report
+ /// the error helpfully: in back ends, it's a validation error, but in front
+ /// ends, it may indicate ill-formed input (for example, a SPIR-V
+ /// `OpArrayType` referring to an inappropriate `OpConstant`). So we return
+ /// `Option` and let the caller sort things out.
+ pub(crate) fn to_array_length(&self) -> Option<u32> {
+ match self.inner {
+ crate::ConstantInner::Scalar { value, width: _ } => match value {
+ crate::ScalarValue::Uint(value) => value.try_into().ok(),
+ // Accept a signed integer size to avoid
+ // requiring an explicit uint
+ // literal. Type inference should make
+ // this unnecessary.
+ crate::ScalarValue::Sint(value) => value.try_into().ok(),
+ _ => None,
+ },
+ // caught by type validation
+ crate::ConstantInner::Composite { .. } => None,
+ }
+ }
+}
+
+impl crate::Binding {
+ pub const fn to_built_in(&self) -> Option<crate::BuiltIn> {
+ match *self {
+ crate::Binding::BuiltIn(built_in) => Some(built_in),
+ Self::Location { .. } => None,
+ }
+ }
+}
+
+//TODO: should we use an existing crate for hashable floats?
+impl PartialEq for crate::ScalarValue {
+ fn eq(&self, other: &Self) -> bool {
+ match (*self, *other) {
+ (Self::Uint(a), Self::Uint(b)) => a == b,
+ (Self::Sint(a), Self::Sint(b)) => a == b,
+ (Self::Float(a), Self::Float(b)) => a.to_bits() == b.to_bits(),
+ (Self::Bool(a), Self::Bool(b)) => a == b,
+ _ => false,
+ }
+ }
+}
+impl Eq for crate::ScalarValue {}
+impl std::hash::Hash for crate::ScalarValue {
+ fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
+ match *self {
+ Self::Sint(v) => v.hash(hasher),
+ Self::Uint(v) => v.hash(hasher),
+ Self::Float(v) => v.to_bits().hash(hasher),
+ Self::Bool(v) => v.hash(hasher),
+ }
+ }
+}
+
+impl super::SwizzleComponent {
+ pub const XYZW: [Self; 4] = [Self::X, Self::Y, Self::Z, Self::W];
+
+ pub const fn index(&self) -> u32 {
+ match *self {
+ Self::X => 0,
+ Self::Y => 1,
+ Self::Z => 2,
+ Self::W => 3,
+ }
+ }
+ pub const fn from_index(idx: u32) -> Self {
+ match idx {
+ 0 => Self::X,
+ 1 => Self::Y,
+ 2 => Self::Z,
+ _ => Self::W,
+ }
+ }
+}
+
+impl super::ImageClass {
+ pub const fn is_multisampled(self) -> bool {
+ match self {
+ crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => multi,
+ crate::ImageClass::Storage { .. } => false,
+ }
+ }
+
+ pub const fn is_mipmapped(self) -> bool {
+ match self {
+ crate::ImageClass::Sampled { multi, .. } | crate::ImageClass::Depth { multi } => !multi,
+ crate::ImageClass::Storage { .. } => false,
+ }
+ }
+}
+
+#[test]
+fn test_matrix_size() {
+ let constants = crate::Arena::new();
+ assert_eq!(
+ crate::TypeInner::Matrix {
+ columns: crate::VectorSize::Tri,
+ rows: crate::VectorSize::Tri,
+ width: 4
+ }
+ .size(&constants),
+ 48,
+ );
+}
diff --git a/third_party/rust/naga/src/proc/namer.rs b/third_party/rust/naga/src/proc/namer.rs
new file mode 100644
index 0000000000..053126b8ac
--- /dev/null
+++ b/third_party/rust/naga/src/proc/namer.rs
@@ -0,0 +1,261 @@
+use crate::{arena::Handle, FastHashMap, FastHashSet};
+use std::borrow::Cow;
+
+pub type EntryPointIndex = u16;
+const SEPARATOR: char = '_';
+
+#[derive(Debug, Eq, Hash, PartialEq)]
+pub enum NameKey {
+ Constant(Handle<crate::Constant>),
+ GlobalVariable(Handle<crate::GlobalVariable>),
+ Type(Handle<crate::Type>),
+ StructMember(Handle<crate::Type>, u32),
+ Function(Handle<crate::Function>),
+ FunctionArgument(Handle<crate::Function>, u32),
+ FunctionLocal(Handle<crate::Function>, Handle<crate::LocalVariable>),
+ EntryPoint(EntryPointIndex),
+ EntryPointLocal(EntryPointIndex, Handle<crate::LocalVariable>),
+ EntryPointArgument(EntryPointIndex, u32),
+}
+
+/// This processor assigns names to all the things in a module
+/// that may need identifiers in a textual backend.
+#[derive(Default)]
+pub struct Namer {
+ /// The last numeric suffix used for each base name. Zero means "no suffix".
+ unique: FastHashMap<String, u32>,
+ keywords: FastHashSet<String>,
+ reserved_prefixes: Vec<String>,
+}
+
+impl Namer {
+ /// Return a form of `string` suitable for use as the base of an identifier.
+ ///
+ /// - Drop leading digits.
+ /// - Retain only alphanumeric and `_` characters.
+ /// - Avoid prefixes in [`Namer::reserved_prefixes`].
+ ///
+ /// The return value is a valid identifier prefix in all of Naga's output languages,
+ /// and it never ends with a `SEPARATOR` character.
+ /// It is used as a key into the unique table.
+ fn sanitize<'s>(&self, string: &'s str) -> Cow<'s, str> {
+ let string = string
+ .trim_start_matches(|c: char| c.is_numeric())
+ .trim_end_matches(SEPARATOR);
+
+ let base = if !string.is_empty()
+ && string
+ .chars()
+ .all(|c: char| c.is_ascii_alphanumeric() || c == '_')
+ {
+ Cow::Borrowed(string)
+ } else {
+ let mut filtered = string
+ .chars()
+ .filter(|&c| c.is_ascii_alphanumeric() || c == '_')
+ .collect::<String>();
+ let stripped_len = filtered.trim_end_matches(SEPARATOR).len();
+ filtered.truncate(stripped_len);
+ if filtered.is_empty() {
+ filtered.push_str("unnamed");
+ }
+ Cow::Owned(filtered)
+ };
+
+ for prefix in &self.reserved_prefixes {
+ if base.starts_with(prefix) {
+ return format!("gen_{base}").into();
+ }
+ }
+
+ base
+ }
+
+ /// Return a new identifier based on `label_raw`.
+ ///
+ /// The result:
+ /// - is a valid identifier even if `label_raw` is not
+ /// - conflicts with no keywords listed in `Namer::keywords`, and
+ /// - is different from any identifier previously constructed by this
+ /// `Namer`.
+ ///
+ /// Guarantee uniqueness by applying a numeric suffix when necessary. If `label_raw`
+ /// itself ends with digits, separate them from the suffix with an underscore.
+ pub fn call(&mut self, label_raw: &str) -> String {
+ use std::fmt::Write as _; // for write!-ing to Strings
+
+ let base = self.sanitize(label_raw);
+ debug_assert!(!base.is_empty() && !base.ends_with(SEPARATOR));
+
+ // This would seem to be a natural place to use `HashMap::entry`. However, `entry`
+ // requires an owned key, and we'd like to avoid heap-allocating strings we're
+ // just going to throw away. The approach below double-hashes only when we create
+ // a new entry, in which case the heap allocation of the owned key was more
+ // expensive anyway.
+ match self.unique.get_mut(base.as_ref()) {
+ Some(count) => {
+ *count += 1;
+ // Add the suffix. This may fit in base's existing allocation.
+ let mut suffixed = base.into_owned();
+ write!(suffixed, "{}{}", SEPARATOR, *count).unwrap();
+ suffixed
+ }
+ None => {
+ let mut suffixed = base.to_string();
+ if base.ends_with(char::is_numeric) || self.keywords.contains(base.as_ref()) {
+ suffixed.push(SEPARATOR);
+ }
+ debug_assert!(!self.keywords.contains(&suffixed));
+ // `self.unique` wants to own its keys. This allocates only if we haven't
+ // already done so earlier.
+ self.unique.insert(base.into_owned(), 0);
+ suffixed
+ }
+ }
+ }
+
+ pub fn call_or(&mut self, label: &Option<String>, fallback: &str) -> String {
+ self.call(match *label {
+ Some(ref name) => name,
+ None => fallback,
+ })
+ }
+
+ /// Enter a local namespace for things like structs.
+ ///
+ /// Struct member names only need to be unique amongst themselves, not
+ /// globally. This function temporarily establishes a fresh, empty naming
+ /// context for the duration of the call to `body`.
+ fn namespace(&mut self, capacity: usize, body: impl FnOnce(&mut Self)) {
+ let fresh = FastHashMap::with_capacity_and_hasher(capacity, Default::default());
+ let outer = std::mem::replace(&mut self.unique, fresh);
+ body(self);
+ self.unique = outer;
+ }
+
+ pub fn reset(
+ &mut self,
+ module: &crate::Module,
+ reserved_keywords: &[&str],
+ reserved_prefixes: &[&str],
+ output: &mut FastHashMap<NameKey, String>,
+ ) {
+ self.reserved_prefixes.clear();
+ self.reserved_prefixes
+ .extend(reserved_prefixes.iter().map(|string| string.to_string()));
+
+ self.unique.clear();
+ self.keywords.clear();
+ self.keywords
+ .extend(reserved_keywords.iter().map(|string| (string.to_string())));
+ let mut temp = String::new();
+
+ for (ty_handle, ty) in module.types.iter() {
+ let ty_name = self.call_or(&ty.name, "type");
+ output.insert(NameKey::Type(ty_handle), ty_name);
+
+ if let crate::TypeInner::Struct { ref members, .. } = ty.inner {
+ // struct members have their own namespace, because access is always prefixed
+ self.namespace(members.len(), |namer| {
+ for (index, member) in members.iter().enumerate() {
+ let name = namer.call_or(&member.name, "member");
+ output.insert(NameKey::StructMember(ty_handle, index as u32), name);
+ }
+ })
+ }
+ }
+
+ for (ep_index, ep) in module.entry_points.iter().enumerate() {
+ let ep_name = self.call(&ep.name);
+ output.insert(NameKey::EntryPoint(ep_index as _), ep_name);
+ for (index, arg) in ep.function.arguments.iter().enumerate() {
+ let name = self.call_or(&arg.name, "param");
+ output.insert(
+ NameKey::EntryPointArgument(ep_index as _, index as u32),
+ name,
+ );
+ }
+ for (handle, var) in ep.function.local_variables.iter() {
+ let name = self.call_or(&var.name, "local");
+ output.insert(NameKey::EntryPointLocal(ep_index as _, handle), name);
+ }
+ }
+
+ for (fun_handle, fun) in module.functions.iter() {
+ let fun_name = self.call_or(&fun.name, "function");
+ output.insert(NameKey::Function(fun_handle), fun_name);
+ for (index, arg) in fun.arguments.iter().enumerate() {
+ let name = self.call_or(&arg.name, "param");
+ output.insert(NameKey::FunctionArgument(fun_handle, index as u32), name);
+ }
+ for (handle, var) in fun.local_variables.iter() {
+ let name = self.call_or(&var.name, "local");
+ output.insert(NameKey::FunctionLocal(fun_handle, handle), name);
+ }
+ }
+
+ for (handle, var) in module.global_variables.iter() {
+ let name = self.call_or(&var.name, "global");
+ output.insert(NameKey::GlobalVariable(handle), name);
+ }
+
+ for (handle, constant) in module.constants.iter() {
+ let label = match constant.name {
+ Some(ref name) => name,
+ None => {
+ use std::fmt::Write;
+ // Try to be more descriptive about the constant values
+ temp.clear();
+ match constant.inner {
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Sint(v),
+ } => write!(temp, "const_{v}i"),
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Uint(v),
+ } => write!(temp, "const_{v}u"),
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Float(v),
+ } => {
+ let abs = v.abs();
+ write!(
+ temp,
+ "const_{}{}",
+ if v < 0.0 { "n" } else { "" },
+ abs.trunc(),
+ )
+ .unwrap();
+ let fract = abs.fract();
+ if fract == 0.0 {
+ write!(temp, "f")
+ } else {
+ write!(temp, "_{:02}f", (fract * 100.0) as i8)
+ }
+ }
+ crate::ConstantInner::Scalar {
+ width: _,
+ value: crate::ScalarValue::Bool(v),
+ } => write!(temp, "const_{v}"),
+ crate::ConstantInner::Composite { ty, components: _ } => {
+ write!(temp, "const_{}", output[&NameKey::Type(ty)])
+ }
+ }
+ .unwrap();
+ &temp
+ }
+ };
+ let name = self.call(label);
+ output.insert(NameKey::Constant(handle), name);
+ }
+ }
+}
+
+#[test]
+fn test() {
+ let mut namer = Namer::default();
+ assert_eq!(namer.call("x"), "x");
+ assert_eq!(namer.call("x"), "x_1");
+ assert_eq!(namer.call("x1"), "x1_");
+}
diff --git a/third_party/rust/naga/src/proc/terminator.rs b/third_party/rust/naga/src/proc/terminator.rs
new file mode 100644
index 0000000000..ca0c3f10bc
--- /dev/null
+++ b/third_party/rust/naga/src/proc/terminator.rs
@@ -0,0 +1,43 @@
+/// Ensure that the given block has return statements
+/// at the end of its control flow.
+///
+/// Note: we don't want to blindly append a return statement
+/// to the end, because it may be either redundant or invalid,
+/// e.g. when the user already has returns in if/else branches.
+pub fn ensure_block_returns(block: &mut crate::Block) {
+ use crate::Statement as S;
+ match block.last_mut() {
+ Some(&mut S::Block(ref mut b)) => {
+ ensure_block_returns(b);
+ }
+ Some(&mut S::If {
+ condition: _,
+ ref mut accept,
+ ref mut reject,
+ }) => {
+ ensure_block_returns(accept);
+ ensure_block_returns(reject);
+ }
+ Some(&mut S::Switch {
+ selector: _,
+ ref mut cases,
+ }) => {
+ for case in cases.iter_mut() {
+ if !case.fall_through {
+ ensure_block_returns(&mut case.body);
+ }
+ }
+ }
+ Some(&mut (S::Emit(_) | S::Break | S::Continue | S::Return { .. } | S::Kill)) => (),
+ Some(
+ &mut (S::Loop { .. }
+ | S::Store { .. }
+ | S::ImageStore { .. }
+ | S::Call { .. }
+ | S::RayQuery { .. }
+ | S::Atomic { .. }
+ | S::Barrier(_)),
+ )
+ | None => block.push(S::Return { value: None }, Default::default()),
+ }
+}
diff --git a/third_party/rust/naga/src/proc/typifier.rs b/third_party/rust/naga/src/proc/typifier.rs
new file mode 100644
index 0000000000..0bb9019a29
--- /dev/null
+++ b/third_party/rust/naga/src/proc/typifier.rs
@@ -0,0 +1,909 @@
+use crate::arena::{Arena, Handle, UniqueArena};
+
+use thiserror::Error;
+
+/// The result of computing an expression's type.
+///
+/// This is the (Rust) type returned by [`ResolveContext::resolve`] to represent
+/// the (Naga) type it ascribes to some expression.
+///
+/// You might expect such a function to simply return a `Handle<Type>`. However,
+/// we want type resolution to be a read-only process, and that would limit the
+/// possible results to types already present in the expression's associated
+/// `UniqueArena<Type>`. Naga IR does have certain expressions whose types are
+/// not certain to be present.
+///
+/// So instead, type resolution returns a `TypeResolution` enum: either a
+/// [`Handle`], referencing some type in the arena, or a [`Value`], holding a
+/// free-floating [`TypeInner`]. This extends the range to cover anything that
+/// can be represented with a `TypeInner` referring to the existing arena.
+///
+/// What sorts of expressions can have types not available in the arena?
+///
+/// - An [`Access`] or [`AccessIndex`] expression applied to a [`Vector`] or
+/// [`Matrix`] must have a [`Scalar`] or [`Vector`] type. But since `Vector`
+/// and `Matrix` represent their element and column types implicitly, not
+/// via a handle, there may not be a suitable type in the expression's
+/// associated arena. Instead, resolving such an expression returns a
+/// `TypeResolution::Value(TypeInner::X { ... })`, where `X` is `Scalar` or
+/// `Vector`.
+///
+/// - Similarly, the type of an [`Access`] or [`AccessIndex`] expression
+/// applied to a *pointer to* a vector or matrix must produce a *pointer to*
+/// a scalar or vector type. These cannot be represented with a
+/// [`TypeInner::Pointer`], since the `Pointer`'s `base` must point into the
+/// arena, and as before, we cannot assume that a suitable scalar or vector
+/// type is there. So we take things one step further and provide
+/// [`TypeInner::ValuePointer`], specifically for the case of pointers to
+/// scalars or vectors. This type fits in a `TypeInner` and is exactly
+/// equivalent to a `Pointer` to a `Vector` or `Scalar`.
+///
+/// So, for example, the type of an `Access` expression applied to a value of type:
+///
+/// ```ignore
+/// TypeInner::Matrix { columns, rows, width }
+/// ```
+///
+/// might be:
+///
+/// ```ignore
+/// TypeResolution::Value(TypeInner::Vector {
+/// size: rows,
+/// kind: ScalarKind::Float,
+/// width,
+/// })
+/// ```
+///
+/// and the type of an access to a pointer of address space `space` to such a
+/// matrix might be:
+///
+/// ```ignore
+/// TypeResolution::Value(TypeInner::ValuePointer {
+/// size: Some(rows),
+/// kind: ScalarKind::Float,
+/// width,
+/// space,
+/// })
+/// ```
+///
+/// [`Handle`]: TypeResolution::Handle
+/// [`Value`]: TypeResolution::Value
+///
+/// [`Access`]: crate::Expression::Access
+/// [`AccessIndex`]: crate::Expression::AccessIndex
+///
+/// [`TypeInner`]: crate::TypeInner
+/// [`Matrix`]: crate::TypeInner::Matrix
+/// [`Pointer`]: crate::TypeInner::Pointer
+/// [`Scalar`]: crate::TypeInner::Scalar
+/// [`ValuePointer`]: crate::TypeInner::ValuePointer
+/// [`Vector`]: crate::TypeInner::Vector
+///
+/// [`TypeInner::Pointer`]: crate::TypeInner::Pointer
+/// [`TypeInner::ValuePointer`]: crate::TypeInner::ValuePointer
+#[derive(Debug, PartialEq)]
+#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
+#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
+pub enum TypeResolution {
+ /// A type stored in the associated arena.
+ Handle(Handle<crate::Type>),
+
+ /// A free-floating [`TypeInner`], representing a type that may not be
+ /// available in the associated arena. However, the `TypeInner` itself may
+ /// contain `Handle<Type>` values referring to types from the arena.
+ ///
+ /// [`TypeInner`]: crate::TypeInner
+ Value(crate::TypeInner),
+}
+
+impl TypeResolution {
+ pub const fn handle(&self) -> Option<Handle<crate::Type>> {
+ match *self {
+ Self::Handle(handle) => Some(handle),
+ Self::Value(_) => None,
+ }
+ }
+
+ pub fn inner_with<'a>(&'a self, arena: &'a UniqueArena<crate::Type>) -> &'a crate::TypeInner {
+ match *self {
+ Self::Handle(handle) => &arena[handle].inner,
+ Self::Value(ref inner) => inner,
+ }
+ }
+}
+
+// Clone is only implemented for numeric variants of `TypeInner`.
+impl Clone for TypeResolution {
+ fn clone(&self) -> Self {
+ use crate::TypeInner as Ti;
+ match *self {
+ Self::Handle(handle) => Self::Handle(handle),
+ Self::Value(ref v) => Self::Value(match *v {
+ Ti::Scalar { kind, width } => Ti::Scalar { kind, width },
+ Ti::Vector { size, kind, width } => Ti::Vector { size, kind, width },
+ Ti::Matrix {
+ rows,
+ columns,
+ width,
+ } => Ti::Matrix {
+ rows,
+ columns,
+ width,
+ },
+ Ti::Pointer { base, space } => Ti::Pointer { base, space },
+ Ti::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ } => Ti::ValuePointer {
+ size,
+ kind,
+ width,
+ space,
+ },
+ _ => unreachable!("Unexpected clone type: {:?}", v),
+ }),
+ }
+ }
+}
+
+impl crate::ConstantInner {
+ pub const fn resolve_type(&self) -> TypeResolution {
+ match *self {
+ Self::Scalar { width, ref value } => TypeResolution::Value(crate::TypeInner::Scalar {
+ kind: value.scalar_kind(),
+ width,
+ }),
+ Self::Composite { ty, components: _ } => TypeResolution::Handle(ty),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Error, PartialEq)]
+pub enum ResolveError {
+ #[error("Index {index} is out of bounds for expression {expr:?}")]
+ OutOfBoundsIndex {
+ expr: Handle<crate::Expression>,
+ index: u32,
+ },
+ #[error("Invalid access into expression {expr:?}, indexed: {indexed}")]
+ InvalidAccess {
+ expr: Handle<crate::Expression>,
+ indexed: bool,
+ },
+ #[error("Invalid sub-access into type {ty:?}, indexed: {indexed}")]
+ InvalidSubAccess {
+ ty: Handle<crate::Type>,
+ indexed: bool,
+ },
+ #[error("Invalid scalar {0:?}")]
+ InvalidScalar(Handle<crate::Expression>),
+ #[error("Invalid vector {0:?}")]
+ InvalidVector(Handle<crate::Expression>),
+ #[error("Invalid pointer {0:?}")]
+ InvalidPointer(Handle<crate::Expression>),
+ #[error("Invalid image {0:?}")]
+ InvalidImage(Handle<crate::Expression>),
+ #[error("Function {name} not defined")]
+ FunctionNotDefined { name: String },
+ #[error("Function without return type")]
+ FunctionReturnsVoid,
+ #[error("Incompatible operands: {0}")]
+ IncompatibleOperands(String),
+ #[error("Function argument {0} doesn't exist")]
+ FunctionArgumentNotFound(u32),
+ #[error("Special type is not registered within the module")]
+ MissingSpecialType,
+}
+
+pub struct ResolveContext<'a> {
+ pub constants: &'a Arena<crate::Constant>,
+ pub types: &'a UniqueArena<crate::Type>,
+ pub special_types: &'a crate::SpecialTypes,
+ pub global_vars: &'a Arena<crate::GlobalVariable>,
+ pub local_vars: &'a Arena<crate::LocalVariable>,
+ pub functions: &'a Arena<crate::Function>,
+ pub arguments: &'a [crate::FunctionArgument],
+}
+
+impl<'a> ResolveContext<'a> {
+ /// Initialize a resolve context from the module.
+ pub const fn with_locals(
+ module: &'a crate::Module,
+ local_vars: &'a Arena<crate::LocalVariable>,
+ arguments: &'a [crate::FunctionArgument],
+ ) -> Self {
+ Self {
+ constants: &module.constants,
+ types: &module.types,
+ special_types: &module.special_types,
+ global_vars: &module.global_variables,
+ local_vars,
+ functions: &module.functions,
+ arguments,
+ }
+ }
+
+ /// Determine the type of `expr`.
+ ///
+ /// The `past` argument must be a closure that can resolve the types of any
+ /// expressions that `expr` refers to. These can be gathered by caching the
+ /// results of prior calls to `resolve`, perhaps as done by the
+ /// [`front::Typifier`] utility type.
+ ///
+ /// Type resolution is a read-only process: this method takes `self` by
+ /// shared reference. However, this means that we cannot add anything to
+ /// `self.types` that we might need to describe `expr`. To work around this,
+ /// this method returns a [`TypeResolution`], rather than simply returning a
+ /// `Handle<Type>`; see the documentation for [`TypeResolution`] for
+ /// details.
+ ///
+ /// [`front::Typifier`]: crate::front::Typifier
+ pub fn resolve(
+ &self,
+ expr: &crate::Expression,
+ past: impl Fn(Handle<crate::Expression>) -> Result<&'a TypeResolution, ResolveError>,
+ ) -> Result<TypeResolution, ResolveError> {
+ use crate::TypeInner as Ti;
+ let types = self.types;
+ Ok(match *expr {
+ crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) {
+ // Arrays and matrices can only be indexed dynamically behind a
+ // pointer, but that's a validation error, not a type error, so
+ // go ahead provide a type here.
+ Ti::Array { base, .. } => TypeResolution::Handle(base),
+ Ti::Matrix { rows, width, .. } => TypeResolution::Value(Ti::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ Ti::Vector {
+ size: _,
+ kind,
+ width,
+ } => TypeResolution::Value(Ti::Scalar { kind, width }),
+ Ti::ValuePointer {
+ size: Some(_),
+ kind,
+ width,
+ space,
+ } => TypeResolution::Value(Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ }),
+ Ti::Pointer { base, space } => {
+ TypeResolution::Value(match types[base].inner {
+ Ti::Array { base, .. } => Ti::Pointer { base, space },
+ Ti::Vector {
+ size: _,
+ kind,
+ width,
+ } => Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ },
+ // Matrices are only dynamically indexed behind a pointer
+ Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ } => Ti::ValuePointer {
+ kind: crate::ScalarKind::Float,
+ size: Some(rows),
+ width,
+ space,
+ },
+ ref other => {
+ log::error!("Access sub-type {:?}", other);
+ return Err(ResolveError::InvalidSubAccess {
+ ty: base,
+ indexed: false,
+ });
+ }
+ })
+ }
+ Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
+ ref other => {
+ log::error!("Access type {:?}", other);
+ return Err(ResolveError::InvalidAccess {
+ expr: base,
+ indexed: false,
+ });
+ }
+ },
+ crate::Expression::AccessIndex { base, index } => {
+ match *past(base)?.inner_with(types) {
+ Ti::Vector { size, kind, width } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(Ti::Scalar { kind, width })
+ }
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(crate::TypeInner::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ })
+ }
+ Ti::Array { base, .. } => TypeResolution::Handle(base),
+ Ti::Struct { ref members, .. } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
+ TypeResolution::Handle(member.ty)
+ }
+ Ti::ValuePointer {
+ size: Some(size),
+ kind,
+ width,
+ space,
+ } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ TypeResolution::Value(Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ })
+ }
+ Ti::Pointer {
+ base: ty_base,
+ space,
+ } => TypeResolution::Value(match types[ty_base].inner {
+ Ti::Array { base, .. } => Ti::Pointer { base, space },
+ Ti::Vector { size, kind, width } => {
+ if index >= size as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ Ti::ValuePointer {
+ size: None,
+ kind,
+ width,
+ space,
+ }
+ }
+ Ti::Matrix {
+ rows,
+ columns,
+ width,
+ } => {
+ if index >= columns as u32 {
+ return Err(ResolveError::OutOfBoundsIndex { expr: base, index });
+ }
+ Ti::ValuePointer {
+ size: Some(rows),
+ kind: crate::ScalarKind::Float,
+ width,
+ space,
+ }
+ }
+ Ti::Struct { ref members, .. } => {
+ let member = members
+ .get(index as usize)
+ .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?;
+ Ti::Pointer {
+ base: member.ty,
+ space,
+ }
+ }
+ ref other => {
+ log::error!("Access index sub-type {:?}", other);
+ return Err(ResolveError::InvalidSubAccess {
+ ty: ty_base,
+ indexed: true,
+ });
+ }
+ }),
+ Ti::BindingArray { base, .. } => TypeResolution::Handle(base),
+ ref other => {
+ log::error!("Access index type {:?}", other);
+ return Err(ResolveError::InvalidAccess {
+ expr: base,
+ indexed: true,
+ });
+ }
+ }
+ }
+ crate::Expression::Constant(h) => match self.constants[h].inner {
+ crate::ConstantInner::Scalar { width, ref value } => {
+ TypeResolution::Value(Ti::Scalar {
+ kind: value.scalar_kind(),
+ width,
+ })
+ }
+ crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty),
+ },
+ crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) {
+ Ti::Scalar { kind, width } => {
+ TypeResolution::Value(Ti::Vector { size, kind, width })
+ }
+ ref other => {
+ log::error!("Scalar type {:?}", other);
+ return Err(ResolveError::InvalidScalar(value));
+ }
+ },
+ crate::Expression::Swizzle {
+ size,
+ vector,
+ pattern: _,
+ } => match *past(vector)?.inner_with(types) {
+ Ti::Vector {
+ size: _,
+ kind,
+ width,
+ } => TypeResolution::Value(Ti::Vector { size, kind, width }),
+ ref other => {
+ log::error!("Vector type {:?}", other);
+ return Err(ResolveError::InvalidVector(vector));
+ }
+ },
+ crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
+ crate::Expression::FunctionArgument(index) => {
+ let arg = self
+ .arguments
+ .get(index as usize)
+ .ok_or(ResolveError::FunctionArgumentNotFound(index))?;
+ TypeResolution::Handle(arg.ty)
+ }
+ crate::Expression::GlobalVariable(h) => {
+ let var = &self.global_vars[h];
+ if var.space == crate::AddressSpace::Handle {
+ TypeResolution::Handle(var.ty)
+ } else {
+ TypeResolution::Value(Ti::Pointer {
+ base: var.ty,
+ space: var.space,
+ })
+ }
+ }
+ crate::Expression::LocalVariable(h) => {
+ let var = &self.local_vars[h];
+ TypeResolution::Value(Ti::Pointer {
+ base: var.ty,
+ space: crate::AddressSpace::Function,
+ })
+ }
+ crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) {
+ Ti::Pointer { base, space: _ } => {
+ if let Ti::Atomic { kind, width } = types[base].inner {
+ TypeResolution::Value(Ti::Scalar { kind, width })
+ } else {
+ TypeResolution::Handle(base)
+ }
+ }
+ Ti::ValuePointer {
+ size,
+ kind,
+ width,
+ space: _,
+ } => TypeResolution::Value(match size {
+ Some(size) => Ti::Vector { size, kind, width },
+ None => Ti::Scalar { kind, width },
+ }),
+ ref other => {
+ log::error!("Pointer type {:?}", other);
+ return Err(ResolveError::InvalidPointer(pointer));
+ }
+ },
+ crate::Expression::ImageSample {
+ image,
+ gather: Some(_),
+ ..
+ } => match *past(image)?.inner_with(types) {
+ Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector {
+ kind: match class {
+ crate::ImageClass::Sampled { kind, multi: _ } => kind,
+ _ => crate::ScalarKind::Float,
+ },
+ width: 4,
+ size: crate::VectorSize::Quad,
+ }),
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::Expression::ImageSample { image, .. }
+ | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) {
+ Ti::Image { class, .. } => TypeResolution::Value(match class {
+ crate::ImageClass::Depth { multi: _ } => Ti::Scalar {
+ kind: crate::ScalarKind::Float,
+ width: 4,
+ },
+ crate::ImageClass::Sampled { kind, multi: _ } => Ti::Vector {
+ kind,
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ crate::ImageClass::Storage { format, .. } => Ti::Vector {
+ kind: format.into(),
+ width: 4,
+ size: crate::VectorSize::Quad,
+ },
+ }),
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query {
+ crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) {
+ Ti::Image { dim, .. } => match dim {
+ crate::ImageDimension::D1 => Ti::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ crate::ImageDimension::D2 | crate::ImageDimension::Cube => Ti::Vector {
+ size: crate::VectorSize::Bi,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ crate::ImageDimension::D3 => Ti::Vector {
+ size: crate::VectorSize::Tri,
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ },
+ ref other => {
+ log::error!("Image type {:?}", other);
+ return Err(ResolveError::InvalidImage(image));
+ }
+ },
+ crate::ImageQuery::NumLevels
+ | crate::ImageQuery::NumLayers
+ | crate::ImageQuery::NumSamples => Ti::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ },
+ }),
+ crate::Expression::Unary { expr, .. } => past(expr)?.clone(),
+ crate::Expression::Binary { op, left, right } => match op {
+ crate::BinaryOperator::Add
+ | crate::BinaryOperator::Subtract
+ | crate::BinaryOperator::Divide
+ | crate::BinaryOperator::Modulo => past(left)?.clone(),
+ crate::BinaryOperator::Multiply => {
+ let (res_left, res_right) = (past(left)?, past(right)?);
+ match (res_left.inner_with(types), res_right.inner_with(types)) {
+ (
+ &Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ },
+ &Ti::Matrix { columns, .. },
+ ) => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ width,
+ }),
+ (
+ &Ti::Matrix {
+ columns: _,
+ rows,
+ width,
+ },
+ &Ti::Vector { .. },
+ ) => TypeResolution::Value(Ti::Vector {
+ size: rows,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ (
+ &Ti::Vector { .. },
+ &Ti::Matrix {
+ columns,
+ rows: _,
+ width,
+ },
+ ) => TypeResolution::Value(Ti::Vector {
+ size: columns,
+ kind: crate::ScalarKind::Float,
+ width,
+ }),
+ (&Ti::Scalar { .. }, _) => res_right.clone(),
+ (_, &Ti::Scalar { .. }) => res_left.clone(),
+ (&Ti::Vector { .. }, &Ti::Vector { .. }) => res_left.clone(),
+ (tl, tr) => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{tl:?} * {tr:?}"
+ )))
+ }
+ }
+ }
+ crate::BinaryOperator::Equal
+ | crate::BinaryOperator::NotEqual
+ | crate::BinaryOperator::Less
+ | crate::BinaryOperator::LessEqual
+ | crate::BinaryOperator::Greater
+ | crate::BinaryOperator::GreaterEqual
+ | crate::BinaryOperator::LogicalAnd
+ | crate::BinaryOperator::LogicalOr => {
+ let kind = crate::ScalarKind::Bool;
+ let width = crate::BOOL_WIDTH;
+ let inner = match *past(left)?.inner_with(types) {
+ Ti::Scalar { .. } => Ti::Scalar { kind, width },
+ Ti::Vector { size, .. } => Ti::Vector { size, kind, width },
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{op:?}({other:?}, _)"
+ )))
+ }
+ };
+ TypeResolution::Value(inner)
+ }
+ crate::BinaryOperator::And
+ | crate::BinaryOperator::ExclusiveOr
+ | crate::BinaryOperator::InclusiveOr
+ | crate::BinaryOperator::ShiftLeft
+ | crate::BinaryOperator::ShiftRight => past(left)?.clone(),
+ },
+ crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
+ crate::Expression::Select { accept, .. } => past(accept)?.clone(),
+ crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
+ crate::Expression::Relational { fun, argument } => match fun {
+ crate::RelationalFunction::All | crate::RelationalFunction::Any => {
+ TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ })
+ }
+ crate::RelationalFunction::IsNan
+ | crate::RelationalFunction::IsInf
+ | crate::RelationalFunction::IsFinite
+ | crate::RelationalFunction::IsNormal => match *past(argument)?.inner_with(types) {
+ Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ }),
+ Ti::Vector { size, .. } => TypeResolution::Value(Ti::Vector {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ size,
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{fun:?}({other:?})"
+ )))
+ }
+ },
+ },
+ crate::Expression::Math {
+ fun,
+ arg,
+ arg1,
+ arg2: _,
+ arg3: _,
+ } => {
+ use crate::MathFunction as Mf;
+ let res_arg = past(arg)?;
+ match fun {
+ // comparison
+ Mf::Abs |
+ Mf::Min |
+ Mf::Max |
+ Mf::Clamp |
+ Mf::Saturate |
+ // trigonometry
+ Mf::Cos |
+ Mf::Cosh |
+ Mf::Sin |
+ Mf::Sinh |
+ Mf::Tan |
+ Mf::Tanh |
+ Mf::Acos |
+ Mf::Asin |
+ Mf::Atan |
+ Mf::Atan2 |
+ Mf::Asinh |
+ Mf::Acosh |
+ Mf::Atanh |
+ Mf::Radians |
+ Mf::Degrees |
+ // decomposition
+ Mf::Ceil |
+ Mf::Floor |
+ Mf::Round |
+ Mf::Fract |
+ Mf::Trunc |
+ Mf::Modf |
+ Mf::Frexp |
+ Mf::Ldexp |
+ // exponent
+ Mf::Exp |
+ Mf::Exp2 |
+ Mf::Log |
+ Mf::Log2 |
+ Mf::Pow => res_arg.clone(),
+ // geometry
+ Mf::Dot => match *res_arg.inner_with(types) {
+ Ti::Vector {
+ kind,
+ size: _,
+ width,
+ } => TypeResolution::Value(Ti::Scalar { kind, width }),
+ ref other =>
+ return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?}, _)")
+ )),
+ },
+ Mf::Outer => {
+ let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands(
+ format!("{fun:?}(_, None)")
+ ))?;
+ match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) {
+ (&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => TypeResolution::Value(Ti::Matrix { columns, rows, width }),
+ (left, right) =>
+ return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({left:?}, {right:?})")
+ )),
+ }
+ },
+ Mf::Cross => res_arg.clone(),
+ Mf::Distance |
+ Mf::Length => match *res_arg.inner_with(types) {
+ Ti::Scalar {width,kind} |
+ Ti::Vector {width,kind,size:_} => TypeResolution::Value(Ti::Scalar { kind, width }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Normalize |
+ Mf::FaceForward |
+ Mf::Reflect |
+ Mf::Refract => res_arg.clone(),
+ // computational
+ Mf::Sign |
+ Mf::Fma |
+ Mf::Mix |
+ Mf::Step |
+ Mf::SmoothStep |
+ Mf::Sqrt |
+ Mf::InverseSqrt => res_arg.clone(),
+ Mf::Transpose => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => TypeResolution::Value(Ti::Matrix {
+ columns: rows,
+ rows: columns,
+ width,
+ }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Inverse => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } if columns == rows => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ width,
+ }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ Mf::Determinant => match *res_arg.inner_with(types) {
+ Ti::Matrix {
+ width,
+ ..
+ } => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Float, width }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ // bits
+ Mf::CountTrailingZeros |
+ Mf::CountLeadingZeros |
+ Mf::CountOneBits |
+ Mf::ReverseBits |
+ Mf::ExtractBits |
+ Mf::InsertBits |
+ Mf::FindLsb |
+ Mf::FindMsb => match *res_arg.inner_with(types) {
+ Ti::Scalar { kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
+ TypeResolution::Value(Ti::Scalar { kind, width }),
+ Ti::Vector { size, kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
+ TypeResolution::Value(Ti::Vector { size, kind, width }),
+ ref other => return Err(ResolveError::IncompatibleOperands(
+ format!("{fun:?}({other:?})")
+ )),
+ },
+ // data packing
+ Mf::Pack4x8snorm |
+ Mf::Pack4x8unorm |
+ Mf::Pack2x16snorm |
+ Mf::Pack2x16unorm |
+ Mf::Pack2x16float => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Uint, width: 4 }),
+ // data unpacking
+ Mf::Unpack4x8snorm |
+ Mf::Unpack4x8unorm => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Quad, kind: crate::ScalarKind::Float, width: 4 }),
+ Mf::Unpack2x16snorm |
+ Mf::Unpack2x16unorm |
+ Mf::Unpack2x16float => TypeResolution::Value(Ti::Vector { size: crate::VectorSize::Bi, kind: crate::ScalarKind::Float, width: 4 }),
+ }
+ }
+ crate::Expression::As {
+ expr,
+ kind,
+ convert,
+ } => match *past(expr)?.inner_with(types) {
+ Ti::Scalar { kind: _, width } => TypeResolution::Value(Ti::Scalar {
+ kind,
+ width: convert.unwrap_or(width),
+ }),
+ Ti::Vector {
+ kind: _,
+ size,
+ width,
+ } => TypeResolution::Value(Ti::Vector {
+ kind,
+ size,
+ width: convert.unwrap_or(width),
+ }),
+ Ti::Matrix {
+ columns,
+ rows,
+ width,
+ } => TypeResolution::Value(Ti::Matrix {
+ columns,
+ rows,
+ width: convert.unwrap_or(width),
+ }),
+ ref other => {
+ return Err(ResolveError::IncompatibleOperands(format!(
+ "{other:?} as {kind:?}"
+ )))
+ }
+ },
+ crate::Expression::CallResult(function) => {
+ let result = self.functions[function]
+ .result
+ .as_ref()
+ .ok_or(ResolveError::FunctionReturnsVoid)?;
+ TypeResolution::Handle(result.ty)
+ }
+ crate::Expression::ArrayLength(_) => TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Uint,
+ width: 4,
+ }),
+ crate::Expression::RayQueryProceedResult => TypeResolution::Value(Ti::Scalar {
+ kind: crate::ScalarKind::Bool,
+ width: crate::BOOL_WIDTH,
+ }),
+ crate::Expression::RayQueryGetIntersection { .. } => {
+ let result = self
+ .special_types
+ .ray_intersection
+ .ok_or(ResolveError::MissingSpecialType)?;
+ TypeResolution::Handle(result)
+ }
+ })
+ }
+}
+
+#[test]
+fn test_error_size() {
+ use std::mem::size_of;
+ assert_eq!(size_of::<ResolveError>(), 32);
+}