diff options
Diffstat (limited to 'third_party/rust/naga/src/proc/layouter.rs')
-rw-r--r-- | third_party/rust/naga/src/proc/layouter.rs | 256 |
1 files changed, 256 insertions, 0 deletions
diff --git a/third_party/rust/naga/src/proc/layouter.rs b/third_party/rust/naga/src/proc/layouter.rs new file mode 100644 index 0000000000..65369d1cc8 --- /dev/null +++ b/third_party/rust/naga/src/proc/layouter.rs @@ -0,0 +1,256 @@ +use crate::arena::{Arena, Handle, UniqueArena}; +use std::{fmt::Display, num::NonZeroU32, ops}; + +/// A newtype struct where its only valid values are powers of 2 +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Alignment(NonZeroU32); + +impl Alignment { + pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) }); + pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) }); + pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) }); + pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) }); + pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) }); + + pub const MIN_UNIFORM: Self = Self::SIXTEEN; + + pub const fn new(n: u32) -> Option<Self> { + if n.is_power_of_two() { + // SAFETY: value can't be 0 since we just checked if it's a power of 2 + Some(Self(unsafe { NonZeroU32::new_unchecked(n) })) + } else { + None + } + } + + /// # Panics + /// If `width` is not a power of 2 + pub fn from_width(width: u8) -> Self { + Self::new(width as u32).unwrap() + } + + /// Returns whether or not `n` is a multiple of this alignment. + pub const fn is_aligned(&self, n: u32) -> bool { + // equivalent to: `n % self.0.get() == 0` but much faster + n & (self.0.get() - 1) == 0 + } + + /// Round `n` up to the nearest alignment boundary. + pub const fn round_up(&self, n: u32) -> u32 { + // equivalent to: + // match n % self.0.get() { + // 0 => n, + // rem => n + (self.0.get() - rem), + // } + let mask = self.0.get() - 1; + (n + mask) & !mask + } +} + +impl Display for Alignment { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.get().fmt(f) + } +} + +impl ops::Mul<u32> for Alignment { + type Output = u32; + + fn mul(self, rhs: u32) -> Self::Output { + self.0.get() * rhs + } +} + +impl ops::Mul for Alignment { + type Output = Alignment; + + fn mul(self, rhs: Alignment) -> Self::Output { + // SAFETY: both lhs and rhs are powers of 2, the result will be a power of 2 + Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) }) + } +} + +impl From<crate::VectorSize> for Alignment { + fn from(size: crate::VectorSize) -> Self { + match size { + crate::VectorSize::Bi => Alignment::TWO, + crate::VectorSize::Tri => Alignment::FOUR, + crate::VectorSize::Quad => Alignment::FOUR, + } + } +} + +/// Size and alignment information for a type. +#[derive(Clone, Copy, Debug, Hash, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct TypeLayout { + pub size: u32, + pub alignment: Alignment, +} + +impl TypeLayout { + /// Produce the stride as if this type is a base of an array. + pub const fn to_stride(&self) -> u32 { + self.alignment.round_up(self.size) + } +} + +/// Helper processor that derives the sizes of all types. +/// +/// `Layouter` uses the default layout algorithm/table, described in +/// [WGSL §4.3.7, "Memory Layout"] +/// +/// A `Layouter` may be indexed by `Handle<Type>` values: `layouter[handle]` is the +/// layout of the type whose handle is `handle`. +/// +/// [WGSL §4.3.7, "Memory Layout"](https://gpuweb.github.io/gpuweb/wgsl/#memory-layouts) +#[derive(Debug, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct Layouter { + /// Layouts for types in an arena, indexed by `Handle` index. + layouts: Vec<TypeLayout>, +} + +impl ops::Index<Handle<crate::Type>> for Layouter { + type Output = TypeLayout; + fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout { + &self.layouts[handle.index()] + } +} + +#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] +pub enum LayoutErrorInner { + #[error("Array element type {0:?} doesn't exist")] + InvalidArrayElementType(Handle<crate::Type>), + #[error("Struct member[{0}] type {1:?} doesn't exist")] + InvalidStructMemberType(u32, Handle<crate::Type>), + #[error("Type width must be a power of two")] + NonPowerOfTwoWidth, +} + +#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] +#[error("Error laying out type {ty:?}: {inner}")] +pub struct LayoutError { + pub ty: Handle<crate::Type>, + pub inner: LayoutErrorInner, +} + +impl LayoutErrorInner { + const fn with(self, ty: Handle<crate::Type>) -> LayoutError { + LayoutError { ty, inner: self } + } +} + +impl Layouter { + /// Remove all entries from this `Layouter`, retaining storage. + pub fn clear(&mut self) { + self.layouts.clear(); + } + + /// Extend this `Layouter` with layouts for any new entries in `types`. + /// + /// Ensure that every type in `types` has a corresponding [TypeLayout] in + /// [`self.layouts`]. + /// + /// Some front ends need to be able to compute layouts for existing types + /// while module construction is still in progress and new types are still + /// being added. This function assumes that the `TypeLayout` values already + /// present in `self.layouts` cover their corresponding entries in `types`, + /// and extends `self.layouts` as needed to cover the rest. Thus, a front + /// end can call this function at any time, passing its current type and + /// constant arenas, and then assume that layouts are available for all + /// types. + #[allow(clippy::or_fun_call)] + pub fn update( + &mut self, + types: &UniqueArena<crate::Type>, + constants: &Arena<crate::Constant>, + ) -> Result<(), LayoutError> { + use crate::TypeInner as Ti; + + for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { + let size = ty.inner.size(constants); + let layout = match ty.inner { + Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { + let alignment = Alignment::new(width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { size, alignment } + } + Ti::Vector { + size: vec_size, + width, + .. + } => { + let alignment = Alignment::new(width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(vec_size) * alignment, + } + } + Ti::Matrix { + columns: _, + rows, + width, + } => { + let alignment = Alignment::new(width as u32) + .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?; + TypeLayout { + size, + alignment: Alignment::from(rows) * alignment, + } + } + Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { + size, + alignment: Alignment::ONE, + }, + Ti::Array { + base, + stride: _, + size: _, + } => TypeLayout { + size, + alignment: if base < ty_handle { + self[base].alignment + } else { + return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle)); + }, + }, + Ti::Struct { span, ref members } => { + let mut alignment = Alignment::ONE; + for (index, member) in members.iter().enumerate() { + alignment = if member.ty < ty_handle { + alignment.max(self[member.ty].alignment) + } else { + return Err(LayoutErrorInner::InvalidStructMemberType( + index as u32, + member.ty, + ) + .with(ty_handle)); + }; + } + TypeLayout { + size: span, + alignment, + } + } + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => TypeLayout { + size, + alignment: Alignment::ONE, + }, + }; + debug_assert!(size <= layout.size); + self.layouts.push(layout); + } + + Ok(()) + } +} |